Skip to content

Use cassandra-java-driver's QueryBuidler vector support in 4.18.2 #1817

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@
<!-- also managed by boot bom -->
<oracle.version>23.4.0.24.05</oracle.version>
<postgresql.version>42.7.2</postgresql.version>
<cassandra.java-driver.version>4.18.1</cassandra.java-driver.version>
<cassandra.java-driver.version>4.18.2-SNAPSHOT</cassandra.java-driver.version>
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

todo

<elasticsearch-java.version>8.13.3</elasticsearch-java.version>
<spring-retry.version>2.0.9</spring-retry.version>
<jackson.version>2.16.1</jackson.version>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,19 @@
import com.datastax.oss.driver.api.core.cql.BoundStatement;
import com.datastax.oss.driver.api.core.cql.BoundStatementBuilder;
import com.datastax.oss.driver.api.core.cql.PreparedStatement;
import com.datastax.oss.driver.api.core.cql.ResultSet;
import com.datastax.oss.driver.api.core.cql.Row;
import com.datastax.oss.driver.api.core.cql.SimpleStatement;
import com.datastax.oss.driver.api.core.data.CqlVector;
import com.datastax.oss.driver.api.core.metadata.schema.TableMetadata;
import com.datastax.oss.driver.api.querybuilder.QueryBuilder;
import static com.datastax.oss.driver.api.querybuilder.QueryBuilder.literal;
import com.datastax.oss.driver.api.querybuilder.delete.Delete;
import com.datastax.oss.driver.api.querybuilder.delete.DeleteSelection;
import com.datastax.oss.driver.api.querybuilder.insert.InsertInto;
import com.datastax.oss.driver.api.querybuilder.insert.RegularInsert;
import com.datastax.oss.driver.api.querybuilder.select.Select;
import com.datastax.oss.driver.api.querybuilder.select.Selector;
import com.datastax.oss.driver.shaded.guava.common.base.Preconditions;
import io.micrometer.observation.ObservationRegistry;
import org.slf4j.Logger;
Expand Down Expand Up @@ -112,8 +116,6 @@ public class CassandraVectorStore extends AbstractObservationVectorStore impleme

public static final String DRIVER_PROFILE_SEARCH = "spring-ai-search";

private static final String QUERY_FORMAT = "select %s,%s,%s%s from %s.%s ? order by %s ann of ? limit ?";

private static final Logger logger = LoggerFactory.getLogger(CassandraVectorStore.class);

private static Map<Similarity, VectorStoreSimilarityMetric> SIMILARITY_TYPE_MAPPING = Map.of(Similarity.COSINE,
Expand All @@ -130,8 +132,6 @@ public class CassandraVectorStore extends AbstractObservationVectorStore impleme

private final PreparedStatement deleteStmt;

private final String similarityStmt;

private final Similarity similarity;

private final BatchingStrategy batchingStrategy;
Expand Down Expand Up @@ -162,7 +162,6 @@ public CassandraVectorStore(CassandraVectorStoreConfig conf, EmbeddingModel embe
.get();

this.similarity = getIndexSimilarity(cassandraMetadata);
this.similarityStmt = similaritySearchStatement();

this.filterExpressionConverter = new CassandraFilterExpressionConverter(
cassandraMetadata.getColumns().values());
Expand Down Expand Up @@ -232,21 +231,14 @@ public List<Document> doSimilaritySearch(SearchRequest request) {
Preconditions.checkArgument(request.getTopK() <= 1000);
var embedding = toFloatArray(this.embeddingModel.embed(request.getQuery()));
CqlVector<Float> cqlVector = CqlVector.newInstance(embedding);

String whereClause = "";
if (request.hasFilterExpression()) {
String expression = this.filterExpressionConverter.convertExpression(request.getFilterExpression());
if (!expression.isBlank()) {
whereClause = String.format("where %s", expression);
}
}

String query = String.format(this.similarityStmt, cqlVector, whereClause, cqlVector, request.getTopK());
String cql = createSimilaritySearchCql(request, cqlVector, request.getTopK());
List<Document> documents = new ArrayList<>();
logger.trace("Executing {}", query);
SimpleStatement s = SimpleStatement.newInstance(query).setExecutionProfileName(DRIVER_PROFILE_SEARCH);
logger.trace("Executing {}", cql);

for (Row row : this.conf.session.execute(s)) {
ResultSet result = this.conf.session
.execute(SimpleStatement.newInstance(cql).setExecutionProfileName(DRIVER_PROFILE_SEARCH));

for (Row row : result) {
float score = row.getFloat(0);
if (score < request.getSimilarityThreshold()) {
break;
Expand Down Expand Up @@ -333,38 +325,36 @@ private PreparedStatement prepareAddStatement(Set<String> metadataFields) {
});
}

private String similaritySearchStatement() {
StringBuilder ids = new StringBuilder();
for (var m : this.conf.schema.partitionKeys()) {
ids.append(m.name()).append(',');
}
for (var m : this.conf.schema.clusteringKeys()) {
ids.append(m.name()).append(',');
}
ids.deleteCharAt(ids.length() - 1);
private String createSimilaritySearchCql(SearchRequest request, CqlVector<Float> cqlVector, int topK) {

String similarityFunction = new StringBuilder("similarity_").append(this.similarity.toString().toLowerCase())
.append('(')
.append(this.conf.schema.embedding())
.append(",?)")
.toString();
Select stmt = QueryBuilder.selectFrom(this.conf.schema.keyspace(), this.conf.schema.table())
.function("similarity_" + this.similarity.toString().toLowerCase(),
Selector.column(this.conf.schema.embedding()), literal(cqlVector));

StringBuilder extraSelectFields = new StringBuilder();
for (var c : this.conf.schema.partitionKeys()) {
stmt = stmt.column(c.name());
}
for (var c : this.conf.schema.clusteringKeys()) {
stmt = stmt.column(c.name());
}
stmt = stmt.column(this.conf.schema.content());
for (var m : this.conf.schema.metadataColumns()) {
extraSelectFields.append(',').append(m.name());
stmt = stmt.column(m.name());
}
if (this.conf.returnEmbeddings) {
extraSelectFields.append(',').append(this.conf.schema.embedding());
stmt = stmt.column(this.conf.schema.embedding());
}

// java-driver-query-builder doesn't support orderByAnnOf yet
String query = String.format(QUERY_FORMAT, similarityFunction, ids.toString(), this.conf.schema.content(),
extraSelectFields.toString(), this.conf.schema.keyspace(), this.conf.schema.table(),
this.conf.schema.embedding());

query = query.replace("?", "%s");
logger.debug("preparing {}", query);
return query;
// the filterExpression is a string so we go back to building a CQL string
String whereClause = "";
if (request.hasFilterExpression()) {
String expression = this.filterExpressionConverter.convertExpression(request.getFilterExpression());
if (!expression.isBlank()) {
whereClause = String.format("WHERE %s", expression);
}
}
String cql = stmt.orderByAnnOf(this.conf.schema.embedding(), cqlVector).limit(topK).asCql();
return cql.replace(" ORDER ", whereClause + " ORDER ");
}

private String getDocumentId(Row row) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
import com.datastax.oss.driver.api.core.type.DataTypes;
import com.datastax.oss.driver.api.core.type.codec.registry.CodecRegistry;
import com.datastax.oss.driver.api.core.type.reflect.GenericType;
import com.datastax.oss.driver.api.querybuilder.BuildableQuery;
import com.datastax.oss.driver.api.querybuilder.SchemaBuilder;
import com.datastax.oss.driver.api.querybuilder.schema.AlterTableAddColumn;
import com.datastax.oss.driver.api.querybuilder.schema.AlterTableAddColumnEnd;
Expand Down Expand Up @@ -234,25 +233,15 @@ private void ensureTableExists(int vectorDimension) {
createTable = createTable.withClusteringColumn(clusteringKey.name, clusteringKey.type);
}

createTable = createTable.withColumn(this.schema.content, DataTypes.TEXT);
createTable = createTable.withColumn(this.schema.content, DataTypes.TEXT)
.withColumn(this.schema.embedding, DataTypes.vectorOf(DataTypes.FLOAT, vectorDimension));

for (SchemaColumn metadata : this.schema.metadataColumns) {
createTable = createTable.withColumn(metadata.name(), metadata.type());
}

// https://datastax-oss.atlassian.net/browse/JAVA-3118
// .withColumn(config.embedding, new DefaultVectorType(DataTypes.FLOAT,
// vectorDimension));

StringBuilder tableStmt = new StringBuilder(createTable.asCql());
tableStmt.setLength(tableStmt.length() - 1);
tableStmt.append(',')
.append(this.schema.embedding)
.append(" vector<float,")
.append(vectorDimension)
.append(">)");
logger.debug("Executing {}", tableStmt.toString());
this.session.execute(tableStmt.toString());
logger.debug("Executing {}", createTable.asCql());
this.session.execute(createTable.build());
}
}

Expand Down Expand Up @@ -290,28 +279,12 @@ private void ensureTableColumnsExist(int vectorDimension) {
alterTable = alterTable.addColumn(this.schema.content, DataTypes.TEXT);
}
if (addEmbedding) {
// special case for embedding column, bc JAVA-3118, as above
StringBuilder alterTableStmt = new StringBuilder(((BuildableQuery) alterTable).asCql());
if (newColumns.isEmpty() && !addContent) {
alterTableStmt.append(" ADD (");
}
else {
alterTableStmt.setLength(alterTableStmt.length() - 1);
alterTableStmt.append(',');
}
alterTableStmt.append(this.schema.embedding)
.append(" vector<float,")
.append(vectorDimension)
.append(">)");

logger.debug("Executing {}", alterTableStmt.toString());
this.session.execute(alterTableStmt.toString());
}
else {
SimpleStatement stmt = ((AlterTableAddColumnEnd) alterTable).build();
logger.debug("Executing {}", stmt.getQuery());
this.session.execute(stmt);
alterTable = alterTable.addColumn(this.schema.embedding,
DataTypes.vectorOf(DataTypes.FLOAT, vectorDimension));
}
SimpleStatement stmt = ((AlterTableAddColumnEnd) alterTable).build();
logger.debug("Executing {}", stmt.getQuery());
this.session.execute(stmt);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
*/
public final class CassandraImage {

public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("cassandra:5.0");
public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("cassandra:5.0.2");

private CassandraImage() {

Expand Down