diff --git a/pom.xml b/pom.xml
index ce7cbbc5f5c..b4d92e9161c 100644
--- a/pom.xml
+++ b/pom.xml
@@ -216,7 +216,7 @@
23.4.0.24.05
42.7.2
- 4.18.1
+ 4.18.2-SNAPSHOT
8.13.3
2.0.9
2.16.1
diff --git a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStore.java b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStore.java
index 3ad7f5a916d..d7603732a25 100644
--- a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStore.java
+++ b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStore.java
@@ -30,15 +30,19 @@
import com.datastax.oss.driver.api.core.cql.BoundStatement;
import com.datastax.oss.driver.api.core.cql.BoundStatementBuilder;
import com.datastax.oss.driver.api.core.cql.PreparedStatement;
+import com.datastax.oss.driver.api.core.cql.ResultSet;
import com.datastax.oss.driver.api.core.cql.Row;
import com.datastax.oss.driver.api.core.cql.SimpleStatement;
import com.datastax.oss.driver.api.core.data.CqlVector;
import com.datastax.oss.driver.api.core.metadata.schema.TableMetadata;
import com.datastax.oss.driver.api.querybuilder.QueryBuilder;
+import static com.datastax.oss.driver.api.querybuilder.QueryBuilder.literal;
import com.datastax.oss.driver.api.querybuilder.delete.Delete;
import com.datastax.oss.driver.api.querybuilder.delete.DeleteSelection;
import com.datastax.oss.driver.api.querybuilder.insert.InsertInto;
import com.datastax.oss.driver.api.querybuilder.insert.RegularInsert;
+import com.datastax.oss.driver.api.querybuilder.select.Select;
+import com.datastax.oss.driver.api.querybuilder.select.Selector;
import com.datastax.oss.driver.shaded.guava.common.base.Preconditions;
import io.micrometer.observation.ObservationRegistry;
import org.slf4j.Logger;
@@ -112,8 +116,6 @@ public class CassandraVectorStore extends AbstractObservationVectorStore impleme
public static final String DRIVER_PROFILE_SEARCH = "spring-ai-search";
- private static final String QUERY_FORMAT = "select %s,%s,%s%s from %s.%s ? order by %s ann of ? limit ?";
-
private static final Logger logger = LoggerFactory.getLogger(CassandraVectorStore.class);
private static Map SIMILARITY_TYPE_MAPPING = Map.of(Similarity.COSINE,
@@ -130,8 +132,6 @@ public class CassandraVectorStore extends AbstractObservationVectorStore impleme
private final PreparedStatement deleteStmt;
- private final String similarityStmt;
-
private final Similarity similarity;
private final BatchingStrategy batchingStrategy;
@@ -162,7 +162,6 @@ public CassandraVectorStore(CassandraVectorStoreConfig conf, EmbeddingModel embe
.get();
this.similarity = getIndexSimilarity(cassandraMetadata);
- this.similarityStmt = similaritySearchStatement();
this.filterExpressionConverter = new CassandraFilterExpressionConverter(
cassandraMetadata.getColumns().values());
@@ -232,21 +231,14 @@ public List doSimilaritySearch(SearchRequest request) {
Preconditions.checkArgument(request.getTopK() <= 1000);
var embedding = toFloatArray(this.embeddingModel.embed(request.getQuery()));
CqlVector cqlVector = CqlVector.newInstance(embedding);
-
- String whereClause = "";
- if (request.hasFilterExpression()) {
- String expression = this.filterExpressionConverter.convertExpression(request.getFilterExpression());
- if (!expression.isBlank()) {
- whereClause = String.format("where %s", expression);
- }
- }
-
- String query = String.format(this.similarityStmt, cqlVector, whereClause, cqlVector, request.getTopK());
+ String cql = createSimilaritySearchCql(request, cqlVector, request.getTopK());
List documents = new ArrayList<>();
- logger.trace("Executing {}", query);
- SimpleStatement s = SimpleStatement.newInstance(query).setExecutionProfileName(DRIVER_PROFILE_SEARCH);
+ logger.trace("Executing {}", cql);
- for (Row row : this.conf.session.execute(s)) {
+ ResultSet result = this.conf.session
+ .execute(SimpleStatement.newInstance(cql).setExecutionProfileName(DRIVER_PROFILE_SEARCH));
+
+ for (Row row : result) {
float score = row.getFloat(0);
if (score < request.getSimilarityThreshold()) {
break;
@@ -333,38 +325,36 @@ private PreparedStatement prepareAddStatement(Set metadataFields) {
});
}
- private String similaritySearchStatement() {
- StringBuilder ids = new StringBuilder();
- for (var m : this.conf.schema.partitionKeys()) {
- ids.append(m.name()).append(',');
- }
- for (var m : this.conf.schema.clusteringKeys()) {
- ids.append(m.name()).append(',');
- }
- ids.deleteCharAt(ids.length() - 1);
+ private String createSimilaritySearchCql(SearchRequest request, CqlVector cqlVector, int topK) {
- String similarityFunction = new StringBuilder("similarity_").append(this.similarity.toString().toLowerCase())
- .append('(')
- .append(this.conf.schema.embedding())
- .append(",?)")
- .toString();
+ Select stmt = QueryBuilder.selectFrom(this.conf.schema.keyspace(), this.conf.schema.table())
+ .function("similarity_" + this.similarity.toString().toLowerCase(),
+ Selector.column(this.conf.schema.embedding()), literal(cqlVector));
- StringBuilder extraSelectFields = new StringBuilder();
+ for (var c : this.conf.schema.partitionKeys()) {
+ stmt = stmt.column(c.name());
+ }
+ for (var c : this.conf.schema.clusteringKeys()) {
+ stmt = stmt.column(c.name());
+ }
+ stmt = stmt.column(this.conf.schema.content());
for (var m : this.conf.schema.metadataColumns()) {
- extraSelectFields.append(',').append(m.name());
+ stmt = stmt.column(m.name());
}
if (this.conf.returnEmbeddings) {
- extraSelectFields.append(',').append(this.conf.schema.embedding());
+ stmt = stmt.column(this.conf.schema.embedding());
}
- // java-driver-query-builder doesn't support orderByAnnOf yet
- String query = String.format(QUERY_FORMAT, similarityFunction, ids.toString(), this.conf.schema.content(),
- extraSelectFields.toString(), this.conf.schema.keyspace(), this.conf.schema.table(),
- this.conf.schema.embedding());
-
- query = query.replace("?", "%s");
- logger.debug("preparing {}", query);
- return query;
+ // the filterExpression is a string so we go back to building a CQL string
+ String whereClause = "";
+ if (request.hasFilterExpression()) {
+ String expression = this.filterExpressionConverter.convertExpression(request.getFilterExpression());
+ if (!expression.isBlank()) {
+ whereClause = String.format("WHERE %s", expression);
+ }
+ }
+ String cql = stmt.orderByAnnOf(this.conf.schema.embedding(), cqlVector).limit(topK).asCql();
+ return cql.replace(" ORDER ", whereClause + " ORDER ");
}
private String getDocumentId(Row row) {
diff --git a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStoreConfig.java b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStoreConfig.java
index 007d2c08c3a..75c2eb866dd 100644
--- a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStoreConfig.java
+++ b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStoreConfig.java
@@ -36,7 +36,6 @@
import com.datastax.oss.driver.api.core.type.DataTypes;
import com.datastax.oss.driver.api.core.type.codec.registry.CodecRegistry;
import com.datastax.oss.driver.api.core.type.reflect.GenericType;
-import com.datastax.oss.driver.api.querybuilder.BuildableQuery;
import com.datastax.oss.driver.api.querybuilder.SchemaBuilder;
import com.datastax.oss.driver.api.querybuilder.schema.AlterTableAddColumn;
import com.datastax.oss.driver.api.querybuilder.schema.AlterTableAddColumnEnd;
@@ -234,25 +233,15 @@ private void ensureTableExists(int vectorDimension) {
createTable = createTable.withClusteringColumn(clusteringKey.name, clusteringKey.type);
}
- createTable = createTable.withColumn(this.schema.content, DataTypes.TEXT);
+ createTable = createTable.withColumn(this.schema.content, DataTypes.TEXT)
+ .withColumn(this.schema.embedding, DataTypes.vectorOf(DataTypes.FLOAT, vectorDimension));
for (SchemaColumn metadata : this.schema.metadataColumns) {
createTable = createTable.withColumn(metadata.name(), metadata.type());
}
- // https://datastax-oss.atlassian.net/browse/JAVA-3118
- // .withColumn(config.embedding, new DefaultVectorType(DataTypes.FLOAT,
- // vectorDimension));
-
- StringBuilder tableStmt = new StringBuilder(createTable.asCql());
- tableStmt.setLength(tableStmt.length() - 1);
- tableStmt.append(',')
- .append(this.schema.embedding)
- .append(" vector)");
- logger.debug("Executing {}", tableStmt.toString());
- this.session.execute(tableStmt.toString());
+ logger.debug("Executing {}", createTable.asCql());
+ this.session.execute(createTable.build());
}
}
@@ -290,28 +279,12 @@ private void ensureTableColumnsExist(int vectorDimension) {
alterTable = alterTable.addColumn(this.schema.content, DataTypes.TEXT);
}
if (addEmbedding) {
- // special case for embedding column, bc JAVA-3118, as above
- StringBuilder alterTableStmt = new StringBuilder(((BuildableQuery) alterTable).asCql());
- if (newColumns.isEmpty() && !addContent) {
- alterTableStmt.append(" ADD (");
- }
- else {
- alterTableStmt.setLength(alterTableStmt.length() - 1);
- alterTableStmt.append(',');
- }
- alterTableStmt.append(this.schema.embedding)
- .append(" vector)");
-
- logger.debug("Executing {}", alterTableStmt.toString());
- this.session.execute(alterTableStmt.toString());
- }
- else {
- SimpleStatement stmt = ((AlterTableAddColumnEnd) alterTable).build();
- logger.debug("Executing {}", stmt.getQuery());
- this.session.execute(stmt);
+ alterTable = alterTable.addColumn(this.schema.embedding,
+ DataTypes.vectorOf(DataTypes.FLOAT, vectorDimension));
}
+ SimpleStatement stmt = ((AlterTableAddColumnEnd) alterTable).build();
+ logger.debug("Executing {}", stmt.getQuery());
+ this.session.execute(stmt);
}
}
diff --git a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/CassandraImage.java b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/CassandraImage.java
index 9bfd6eb2060..242c467c98a 100644
--- a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/CassandraImage.java
+++ b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/CassandraImage.java
@@ -23,7 +23,7 @@
*/
public final class CassandraImage {
- public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("cassandra:5.0");
+ public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("cassandra:5.0.2");
private CassandraImage() {