diff --git a/pom.xml b/pom.xml index ce7cbbc5f5c..b4d92e9161c 100644 --- a/pom.xml +++ b/pom.xml @@ -216,7 +216,7 @@ 23.4.0.24.05 42.7.2 - 4.18.1 + 4.18.2-SNAPSHOT 8.13.3 2.0.9 2.16.1 diff --git a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStore.java b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStore.java index 3ad7f5a916d..d7603732a25 100644 --- a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStore.java +++ b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStore.java @@ -30,15 +30,19 @@ import com.datastax.oss.driver.api.core.cql.BoundStatement; import com.datastax.oss.driver.api.core.cql.BoundStatementBuilder; import com.datastax.oss.driver.api.core.cql.PreparedStatement; +import com.datastax.oss.driver.api.core.cql.ResultSet; import com.datastax.oss.driver.api.core.cql.Row; import com.datastax.oss.driver.api.core.cql.SimpleStatement; import com.datastax.oss.driver.api.core.data.CqlVector; import com.datastax.oss.driver.api.core.metadata.schema.TableMetadata; import com.datastax.oss.driver.api.querybuilder.QueryBuilder; +import static com.datastax.oss.driver.api.querybuilder.QueryBuilder.literal; import com.datastax.oss.driver.api.querybuilder.delete.Delete; import com.datastax.oss.driver.api.querybuilder.delete.DeleteSelection; import com.datastax.oss.driver.api.querybuilder.insert.InsertInto; import com.datastax.oss.driver.api.querybuilder.insert.RegularInsert; +import com.datastax.oss.driver.api.querybuilder.select.Select; +import com.datastax.oss.driver.api.querybuilder.select.Selector; import com.datastax.oss.driver.shaded.guava.common.base.Preconditions; import io.micrometer.observation.ObservationRegistry; import org.slf4j.Logger; @@ -112,8 +116,6 @@ public class CassandraVectorStore extends AbstractObservationVectorStore impleme public static final String DRIVER_PROFILE_SEARCH = "spring-ai-search"; - private static final String QUERY_FORMAT = "select %s,%s,%s%s from %s.%s ? order by %s ann of ? limit ?"; - private static final Logger logger = LoggerFactory.getLogger(CassandraVectorStore.class); private static Map SIMILARITY_TYPE_MAPPING = Map.of(Similarity.COSINE, @@ -130,8 +132,6 @@ public class CassandraVectorStore extends AbstractObservationVectorStore impleme private final PreparedStatement deleteStmt; - private final String similarityStmt; - private final Similarity similarity; private final BatchingStrategy batchingStrategy; @@ -162,7 +162,6 @@ public CassandraVectorStore(CassandraVectorStoreConfig conf, EmbeddingModel embe .get(); this.similarity = getIndexSimilarity(cassandraMetadata); - this.similarityStmt = similaritySearchStatement(); this.filterExpressionConverter = new CassandraFilterExpressionConverter( cassandraMetadata.getColumns().values()); @@ -232,21 +231,14 @@ public List doSimilaritySearch(SearchRequest request) { Preconditions.checkArgument(request.getTopK() <= 1000); var embedding = toFloatArray(this.embeddingModel.embed(request.getQuery())); CqlVector cqlVector = CqlVector.newInstance(embedding); - - String whereClause = ""; - if (request.hasFilterExpression()) { - String expression = this.filterExpressionConverter.convertExpression(request.getFilterExpression()); - if (!expression.isBlank()) { - whereClause = String.format("where %s", expression); - } - } - - String query = String.format(this.similarityStmt, cqlVector, whereClause, cqlVector, request.getTopK()); + String cql = createSimilaritySearchCql(request, cqlVector, request.getTopK()); List documents = new ArrayList<>(); - logger.trace("Executing {}", query); - SimpleStatement s = SimpleStatement.newInstance(query).setExecutionProfileName(DRIVER_PROFILE_SEARCH); + logger.trace("Executing {}", cql); - for (Row row : this.conf.session.execute(s)) { + ResultSet result = this.conf.session + .execute(SimpleStatement.newInstance(cql).setExecutionProfileName(DRIVER_PROFILE_SEARCH)); + + for (Row row : result) { float score = row.getFloat(0); if (score < request.getSimilarityThreshold()) { break; @@ -333,38 +325,36 @@ private PreparedStatement prepareAddStatement(Set metadataFields) { }); } - private String similaritySearchStatement() { - StringBuilder ids = new StringBuilder(); - for (var m : this.conf.schema.partitionKeys()) { - ids.append(m.name()).append(','); - } - for (var m : this.conf.schema.clusteringKeys()) { - ids.append(m.name()).append(','); - } - ids.deleteCharAt(ids.length() - 1); + private String createSimilaritySearchCql(SearchRequest request, CqlVector cqlVector, int topK) { - String similarityFunction = new StringBuilder("similarity_").append(this.similarity.toString().toLowerCase()) - .append('(') - .append(this.conf.schema.embedding()) - .append(",?)") - .toString(); + Select stmt = QueryBuilder.selectFrom(this.conf.schema.keyspace(), this.conf.schema.table()) + .function("similarity_" + this.similarity.toString().toLowerCase(), + Selector.column(this.conf.schema.embedding()), literal(cqlVector)); - StringBuilder extraSelectFields = new StringBuilder(); + for (var c : this.conf.schema.partitionKeys()) { + stmt = stmt.column(c.name()); + } + for (var c : this.conf.schema.clusteringKeys()) { + stmt = stmt.column(c.name()); + } + stmt = stmt.column(this.conf.schema.content()); for (var m : this.conf.schema.metadataColumns()) { - extraSelectFields.append(',').append(m.name()); + stmt = stmt.column(m.name()); } if (this.conf.returnEmbeddings) { - extraSelectFields.append(',').append(this.conf.schema.embedding()); + stmt = stmt.column(this.conf.schema.embedding()); } - // java-driver-query-builder doesn't support orderByAnnOf yet - String query = String.format(QUERY_FORMAT, similarityFunction, ids.toString(), this.conf.schema.content(), - extraSelectFields.toString(), this.conf.schema.keyspace(), this.conf.schema.table(), - this.conf.schema.embedding()); - - query = query.replace("?", "%s"); - logger.debug("preparing {}", query); - return query; + // the filterExpression is a string so we go back to building a CQL string + String whereClause = ""; + if (request.hasFilterExpression()) { + String expression = this.filterExpressionConverter.convertExpression(request.getFilterExpression()); + if (!expression.isBlank()) { + whereClause = String.format("WHERE %s", expression); + } + } + String cql = stmt.orderByAnnOf(this.conf.schema.embedding(), cqlVector).limit(topK).asCql(); + return cql.replace(" ORDER ", whereClause + " ORDER "); } private String getDocumentId(Row row) { diff --git a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStoreConfig.java b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStoreConfig.java index 007d2c08c3a..75c2eb866dd 100644 --- a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStoreConfig.java +++ b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStoreConfig.java @@ -36,7 +36,6 @@ import com.datastax.oss.driver.api.core.type.DataTypes; import com.datastax.oss.driver.api.core.type.codec.registry.CodecRegistry; import com.datastax.oss.driver.api.core.type.reflect.GenericType; -import com.datastax.oss.driver.api.querybuilder.BuildableQuery; import com.datastax.oss.driver.api.querybuilder.SchemaBuilder; import com.datastax.oss.driver.api.querybuilder.schema.AlterTableAddColumn; import com.datastax.oss.driver.api.querybuilder.schema.AlterTableAddColumnEnd; @@ -234,25 +233,15 @@ private void ensureTableExists(int vectorDimension) { createTable = createTable.withClusteringColumn(clusteringKey.name, clusteringKey.type); } - createTable = createTable.withColumn(this.schema.content, DataTypes.TEXT); + createTable = createTable.withColumn(this.schema.content, DataTypes.TEXT) + .withColumn(this.schema.embedding, DataTypes.vectorOf(DataTypes.FLOAT, vectorDimension)); for (SchemaColumn metadata : this.schema.metadataColumns) { createTable = createTable.withColumn(metadata.name(), metadata.type()); } - // https://datastax-oss.atlassian.net/browse/JAVA-3118 - // .withColumn(config.embedding, new DefaultVectorType(DataTypes.FLOAT, - // vectorDimension)); - - StringBuilder tableStmt = new StringBuilder(createTable.asCql()); - tableStmt.setLength(tableStmt.length() - 1); - tableStmt.append(',') - .append(this.schema.embedding) - .append(" vector)"); - logger.debug("Executing {}", tableStmt.toString()); - this.session.execute(tableStmt.toString()); + logger.debug("Executing {}", createTable.asCql()); + this.session.execute(createTable.build()); } } @@ -290,28 +279,12 @@ private void ensureTableColumnsExist(int vectorDimension) { alterTable = alterTable.addColumn(this.schema.content, DataTypes.TEXT); } if (addEmbedding) { - // special case for embedding column, bc JAVA-3118, as above - StringBuilder alterTableStmt = new StringBuilder(((BuildableQuery) alterTable).asCql()); - if (newColumns.isEmpty() && !addContent) { - alterTableStmt.append(" ADD ("); - } - else { - alterTableStmt.setLength(alterTableStmt.length() - 1); - alterTableStmt.append(','); - } - alterTableStmt.append(this.schema.embedding) - .append(" vector)"); - - logger.debug("Executing {}", alterTableStmt.toString()); - this.session.execute(alterTableStmt.toString()); - } - else { - SimpleStatement stmt = ((AlterTableAddColumnEnd) alterTable).build(); - logger.debug("Executing {}", stmt.getQuery()); - this.session.execute(stmt); + alterTable = alterTable.addColumn(this.schema.embedding, + DataTypes.vectorOf(DataTypes.FLOAT, vectorDimension)); } + SimpleStatement stmt = ((AlterTableAddColumnEnd) alterTable).build(); + logger.debug("Executing {}", stmt.getQuery()); + this.session.execute(stmt); } } diff --git a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/CassandraImage.java b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/CassandraImage.java index 9bfd6eb2060..242c467c98a 100644 --- a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/CassandraImage.java +++ b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/CassandraImage.java @@ -23,7 +23,7 @@ */ public final class CassandraImage { - public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("cassandra:5.0"); + public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("cassandra:5.0.2"); private CassandraImage() {