-
Notifications
You must be signed in to change notification settings - Fork 54
2 additional methods to help with managing of collections: GetAllDocuments & GetDocumentsByMetadata #118
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
philippgille
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi 👋 , Thank you very much for the contribution! I think these methods are very useful and a good feature addition to the library. And nice to hear that the library is useful in your gaming agent project!
I did a first review pass on the implementation, and will check the tests in more detail later.
Regarding performance, the search method might benefit from parallelization, but that can be added later without affecting the method signatures, so no need to optimize yet.
collection.go
Outdated
| return res, nil | ||
| } | ||
|
|
||
| func (c *Collection) GetAllDocuments(_ context.Context, fetchDeep bool) ([]Document, error) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
WDYT about dropping the fetchDeep param, and always using the clone? That would match what Collection.GetByID currently does.
Or do you already use both in your app?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's not an issue to refactor my code as this lib evolves. :-)
To match both the existing paradigm to avoid unnecessary params that change the behaviour of the function; but also provide a fetch without copying the whole metadata (can get bloaty on large collections), I refactored the shallow clone into a separate method ListDocumentsShort, so both is possible to do but with optimized performance cycles. Happy to hear what you think of this idea. :-)
…d .idea folder to .gitignore
Thanks for you quick reply. See my comments on the individual points you mentioned in your review. I figured I can do further optimization and I put the clone operations for deep and shallow clone into separate methods. I also created benchmarks for it, and on my machine I got the following results: CloneDocumentShort seems to yield a significant improvement in execution time, so I'd consider it worthwhile to be kept and added to the repo for use cases when metadata or embeddings aren't needed. |
collection.go
Outdated
| // The metadata tags must match the params specified in the where argument in both key and value | ||
| // The returned documents are a copy of the original document, so they can be safely | ||
| // modified without affecting the collection. | ||
| func (c *Collection) GetByMetadata(_ context.Context, where map[string]string) ([]Document, error) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi, thanks so much for this PR! Should this method also be renamed to List, since it returns a list of documents?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi, thanks for your review. I have no hard feelings on this one, but since we had GetByID already, which fetches collection items by ID, and this one fetches items by metadata, I thought it makes most sense to name it GetByMetadata.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
On one hand the existing GetByID only returns one item, whereas this new method returns a slice, so it can be argued that the List... naming makes more sense.
On the other hand the existing List... methods always return all items and not a subset like this new method.
This new method maybe most similar to the existing Query... methods, just not using the embeddings for querying.
I don't want this to block the PR, so I'll think a bit more about the naming after merging, and maybe rename before the next release. Feel free to still provide your input in this comment thread.
|
Hey @philippgille , just wanted to check in with you whether this PR is fine for merging now? I applied all the changes you requested; but I'm not sure what the PR process for this repo is like. Am I supposed to resolve the conversations myself or should a maintainer do that? :-) |
|
Sorry I didn't get back to this yet, having a look now |
|
I'd change some things, but if I ask for changes and you push new commits, I might be slow again to review, and I don't want to let you wait much longer. I have prepared a diff, and can apply that to this PR, if you're okay with that (as your PR originates from your main branch, it will affect that one if I'm not mistaken!). The changes:
I'm pasting the diff here. For easier review you can copy the diff text into your clipboard, and then in your terminal do something like the following to apply it locally and review in your IDE or so: git apply <<'EOF'
[paste diff here]
EOFdiffdiff --git a/collection.go b/collection.go
index e80ebf8..eb7fd4e 100644
--- a/collection.go
+++ b/collection.go
@@ -305,31 +305,46 @@ func (c *Collection) ListIDs(_ context.Context) []string {
return ids
}
-// ListDocuments returns all documents in the collection.
-func (c *Collection) ListDocuments(_ context.Context) ([]Document, error) {
+// ListDocuments returns all documents in the collection. The returned documents
+// are a deep copy of the original ones, so you can modify them without affecting
+// the collection.
+func (c *Collection) ListDocuments(_ context.Context) ([]*Document, error) {
c.documentsLock.RLock()
defer c.documentsLock.RUnlock()
- results := make([]Document, 0, len(c.documents))
+ results := make([]*Document, 0, len(c.documents))
for _, doc := range c.documents {
- // Clone the document to avoid concurrent modification by reading goroutine
- docCopy := cloneDocument(doc)
+ docCopy := cloneDocument(doc) // Deep copy
results = append(results, docCopy)
}
return results, nil
}
-// ListDocumentsShort performs a shallow fetch on all documents in the collection,
-// returning only the document IDs and content, but not the embedding or metadata values.
-func (c *Collection) ListDocumentsShort(_ context.Context) ([]Document, error) {
+// ListDocumentsShallow returns all documents in the collection. The returned documents'
+// metadata and embeddings point to the original data, so modifying them will be
+// reflected in the collection.
+func (c *Collection) ListDocumentsShallow(_ context.Context) ([]*Document, error) {
c.documentsLock.RLock()
defer c.documentsLock.RUnlock()
- results := make([]Document, 0, len(c.documents))
+ results := make([]*Document, 0, len(c.documents))
for _, doc := range c.documents {
- // Clone the document to avoid concurrent modification by reading goroutine
- docCopy := cloneDocumentShort(doc)
- results = append(results, docCopy)
+ docCopy := *doc // Shallow copy
+ results = append(results, &docCopy)
+ }
+ return results, nil
+}
+
+// ListDocumentsPartial returns a partial version of all documents in the collection,
+// containing only the ID and content, but not the embedding or metadata values.
+func (c *Collection) ListDocumentsPartial(_ context.Context) ([]*Document, error) {
+ c.documentsLock.RLock()
+ defer c.documentsLock.RUnlock()
+
+ results := make([]*Document, 0, len(c.documents))
+ for _, doc := range c.documents {
+ partialDoc := makePartialDocument(doc) // Shallow copy
+ results = append(results, partialDoc)
}
return results, nil
}
@@ -347,23 +362,23 @@ func (c *Collection) GetByID(_ context.Context, id string) (Document, error) {
doc, ok := c.documents[id]
if ok {
- // Clone the document to avoid concurrent modification by reading goroutine
res := cloneDocument(doc)
- return res, nil
+ return *res, nil
}
return Document{}, fmt.Errorf("document with ID '%v' not found", id)
}
-// GetByMetadata returns a set of documents by their metadata.
-// The metadata tags must match the params specified in the where argument in both key and value
-// The returned documents are a copy of the original document, so they can be safely
-// modified without affecting the collection.
-func (c *Collection) GetByMetadata(_ context.Context, where map[string]string) ([]Document, error) {
+// GetByMetadata returns a set of documents, filtered by their metadata.
+// The metadata tags must match the params specified in the where argument in both
+// key and value.
+// The returned documents are a deep copy of the original document, so they can
+// be safely modified without affecting the collection.
+func (c *Collection) GetByMetadata(_ context.Context, where map[string]string) ([]*Document, error) {
c.documentsLock.RLock()
defer c.documentsLock.RUnlock()
- var results []Document
+ var results []*Document
for _, doc := range c.documents {
match := true
for key, value := range where {
@@ -373,8 +388,7 @@ func (c *Collection) GetByMetadata(_ context.Context, where map[string]string) (
}
}
if match {
- // Clone the document to avoid concurrent modification by reading goroutine
- docCopy := cloneDocument(doc)
+ docCopy := cloneDocument(doc) // Deep copy
results = append(results, docCopy)
}
}
@@ -644,17 +658,17 @@ func (c *Collection) persistMetadata() error {
}
// cloneDocument creates a deep copy of the given Document, including its Metadata and Embedding slices.
-func cloneDocument(doc *Document) Document {
+func cloneDocument(doc *Document) *Document {
docCopy := *doc
docCopy.Metadata = maps.Clone(doc.Metadata)
docCopy.Embedding = slices.Clone(doc.Embedding)
- return docCopy
+ return &docCopy
}
-// cloneDocumentShort creates a shallow copy of the given Document without its Metadata and Embedding slices.
-func cloneDocumentShort(doc *Document) Document {
+// makePartialDocument creates a copy of the given Document without its Metadata and Embedding slices.
+func makePartialDocument(doc *Document) *Document {
docCopy := *doc
docCopy.Metadata = nil
docCopy.Embedding = nil
- return docCopy
+ return &docCopy
}
diff --git a/collection_test.go b/collection_test.go
index 5a29114..f3ba301 100644
--- a/collection_test.go
+++ b/collection_test.go
@@ -442,36 +442,39 @@ func TestCollection_ListIDs(t *testing.T) {
}
// TestCollection_ListDocuments verifies that ListDocuments returns all documents
-// and that the returned documents are deep-copies (mutating them must not affect
+// and that the returned documents are deep copies (mutating them must not affect
// the collection’s internal state).
func TestCollection_ListDocuments(t *testing.T) {
ctx := context.Background()
- // Fixed embedding so we can compare easily.
- embedVec := []float32{0.0, 1.0, 0.0}
+ // Create collection
+ db := NewDB()
+ name := "test"
+ metadata := map[string]string{"foo": "bar"}
+ vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
- return embedVec, nil
+ return vectors, nil
}
-
- // Create collection.
- db := NewDB()
- coll, err := db.CreateCollection("test", nil, embeddingFunc)
+ c, err := db.CreateCollection(name, metadata, embeddingFunc)
if err != nil {
- t.Fatalf("unexpected error creating collection: %v", err)
+ t.Fatal("expected no error, got", err)
+ }
+ if c == nil {
+ t.Fatal("expected collection, got nil")
}
// Add two documents (one with explicit embedding, one relying on embeddingFunc).
docs := []Document{
- {ID: "1", Metadata: map[string]string{"foo": "bar"}, Embedding: embedVec, Content: "hello"},
+ {ID: "1", Metadata: map[string]string{"foo": "bar"}, Embedding: vectors, Content: "hello"},
{ID: "2", Metadata: map[string]string{"baz": "qux"}, Content: "world"},
}
for _, d := range docs {
- if err := coll.AddDocument(ctx, d); err != nil {
+ if err := c.AddDocument(ctx, d); err != nil {
t.Fatalf("unexpected error adding document %q: %v", d.ID, err)
}
}
- got, err := coll.ListDocuments(ctx)
+ got, err := c.ListDocuments(ctx)
if err != nil {
t.Fatalf("unexpected error from ListDocuments: %v", err)
}
@@ -480,21 +483,21 @@ func TestCollection_ListDocuments(t *testing.T) {
}
// Map for convenient lookup.
- deepByID := make(map[string]Document, len(got))
+ gotDocMap := make(map[string]*Document, len(got))
for _, d := range got {
- deepByID[d.ID] = d
+ gotDocMap[d.ID] = d
}
for _, want := range docs {
- got, ok := deepByID[want.ID]
+ got, ok := gotDocMap[want.ID]
if !ok {
t.Fatalf("doc %q not found", want.ID)
}
if got.Content != want.Content {
t.Fatalf("doc %q: expected content %q, got %q", want.ID, want.Content, got.Content)
}
- if !slices.Equal(got.Embedding, embedVec) {
- t.Fatalf("doc %q: embeddings differ, expected %v got %v", want.ID, embedVec, got.Embedding)
+ if !slices.Equal(got.Embedding, vectors) {
+ t.Fatalf("doc %q: embeddings differ, expected %v got %v", want.ID, vectors, got.Embedding)
}
for k, v := range want.Metadata {
if got.Metadata[k] != v {
@@ -503,47 +506,53 @@ func TestCollection_ListDocuments(t *testing.T) {
}
}
- // Mutate deep copy and ensure collection is untouched.
- got[0].Metadata["foo"] = "mutated"
- orig, _ := coll.GetByID(ctx, "1")
+ // Mutate returned document and ensure the collection's document is untouched.
+ gotDocMap["1"].Metadata["foo"] = "mutated"
+ orig, err := c.GetByID(ctx, "1")
+ if err != nil {
+ t.Fatalf("unexpected error getting document by ID: %v", err)
+ }
if orig.Metadata["foo"] != "bar" {
t.Fatalf("mutation leaked into collection: expected \"bar\", got %q", orig.Metadata["foo"])
}
}
-// TestCollection_ListDocumentsShort verifies that ListDocumentsShort returns all documents
-// and that the returned documents are deep-copies (mutating them must not affect
-// the collection’s internal state).
-func TestCollection_ListDocumentsShort(t *testing.T) {
+// TestCollection_ListDocumentsPartial verifies that ListDocumentsPartial returns
+// all documents and that the returned documents can be mutated without affecting
+// the collection’s internal state.
+func TestCollection_ListDocumentsPartial(t *testing.T) {
ctx := context.Background()
- // Fixed embedding so we can compare easily.
- embedVec := []float32{0.0, 1.0, 0.0}
+ // Create collection
+ db := NewDB()
+ name := "test"
+ metadata := map[string]string{"foo": "bar"}
+ vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
- return embedVec, nil
+ return vectors, nil
}
-
- // Create collection.
- db := NewDB()
- coll, err := db.CreateCollection("test", nil, embeddingFunc)
+ c, err := db.CreateCollection(name, metadata, embeddingFunc)
if err != nil {
- t.Fatalf("unexpected error creating collection: %v", err)
+ t.Fatal("expected no error, got", err)
+ }
+ if c == nil {
+ t.Fatal("expected collection, got nil")
}
// Add two documents (one with explicit embedding, one relying on embeddingFunc).
docs := []Document{
- {ID: "1", Metadata: map[string]string{"foo": "bar"}, Embedding: embedVec, Content: "hello"},
+ {ID: "1", Metadata: map[string]string{"foo": "bar"}, Embedding: vectors, Content: "hello"},
{ID: "2", Metadata: map[string]string{"baz": "qux"}, Content: "world"},
}
for _, d := range docs {
- if err := coll.AddDocument(ctx, d); err != nil {
+ if err := c.AddDocument(ctx, d); err != nil {
t.Fatalf("unexpected error adding document %q: %v", d.ID, err)
}
}
- got, err := coll.ListDocumentsShort(ctx)
+ got, err := c.ListDocumentsPartial(ctx)
if err != nil {
- t.Fatalf("unexpected error from ListDocumentsShort: %v", err)
+ t.Fatalf("unexpected error from ListDocumentsPartial: %v", err)
}
if len(got) != len(docs) {
t.Fatalf("expected %d docs, got %d", len(docs), len(got))
@@ -561,9 +570,15 @@ func TestCollection_ListDocumentsShort(t *testing.T) {
}
}
+ // Map for convenient lookup.
+ gotDocMap := make(map[string]*Document, len(got))
+ for _, d := range got {
+ gotDocMap[d.ID] = d
+ }
+
// Mutate deep copy and ensure collection is untouched.
- got[0].Content = "mutated"
- orig, _ := coll.GetByID(ctx, "1")
+ gotDocMap["1"].Content = "mutated"
+ orig, _ := c.GetByID(ctx, "1")
if orig.Content != "hello" {
t.Fatalf("mutation leaked into collection: expected \"hello\", got %q", orig.Content)
}
@@ -812,25 +827,33 @@ func TestCloneDocument(t *testing.T) {
}
// Mutate clone and ensure original is not affected
+ clone.ID = "doc1_clone"
clone.Metadata["foo"] = "baz"
clone.Embedding[0] = 42.0
+ clone.Content = "changed"
+ if orig.ID != "doc1" {
+ t.Fatalf("mutation leaked into original ID: expected \"doc1\", got %q", orig.ID)
+ }
if orig.Metadata["foo"] != "bar" {
t.Fatalf("mutation leaked into original Metadata: expected \"bar\", got %q", orig.Metadata["foo"])
}
if orig.Embedding[0] != 1.0 {
t.Fatalf("mutation leaked into original Embedding: expected 1.0, got %v", orig.Embedding[0])
}
+ if orig.Content != "hello" {
+ t.Fatalf("mutation leaked into original Content: expected \"hello\", got %q", orig.Content)
+ }
}
-// TestCloneDocumentShort verifies that cloneDocumentShort creates a shallow copy with nil Metadata and Embedding.
-func TestCloneDocumentShort(t *testing.T) {
+// TestMakePartialDocument verifies that makePartialDocument creates a copy with nil Metadata and Embedding.
+func TestMakePartialDocument(t *testing.T) {
orig := &Document{
ID: "doc2",
Metadata: map[string]string{"foo": "bar"},
Embedding: []float32{1.0, 2.0, 3.0},
Content: "world",
}
- clone := cloneDocumentShort(orig)
+ clone := makePartialDocument(orig)
// Check ID and Content are copied
if clone.ID != orig.ID {
@@ -961,6 +984,8 @@ func benchmarkCollection_Query(b *testing.B, n int, withContent bool) {
globalRes = res
}
+var globalDoc *Document
+
// BenchmarkCloneDocument_1 performs 1 clone per iteration.
func BenchmarkCloneDocument_1(b *testing.B) {
benchmarkCloneDocumentN(b, 1)
@@ -981,24 +1006,24 @@ func BenchmarkCloneDocument_1000(b *testing.B) {
benchmarkCloneDocumentN(b, 1000)
}
-// BenchmarkCloneDocumentShort_1 performs 1 shallow clone per iteration.
-func BenchmarkCloneDocumentShort_1(b *testing.B) {
- benchmarkCloneDocumentShortN(b, 1)
+// BenchmarkMakePartialDocument_1 performs 1 shallow clone per iteration.
+func BenchmarkMakePartialDocument_1(b *testing.B) {
+ benchmarkMakePartialDocumentN(b, 1)
}
-// BenchmarkCloneDocumentShort_10 performs 10 shallow clones per iteration.
-func BenchmarkCloneDocumentShort_10(b *testing.B) {
- benchmarkCloneDocumentShortN(b, 10)
+// BenchmarkMakePartialDocument_10 performs 10 shallow clones per iteration.
+func BenchmarkMakePartialDocument_10(b *testing.B) {
+ benchmarkMakePartialDocumentN(b, 10)
}
-// BenchmarkCloneDocumentShort_100 performs 100 shallow clones per iteration.
-func BenchmarkCloneDocumentShort_100(b *testing.B) {
- benchmarkCloneDocumentShortN(b, 100)
+// BenchmarkMakePartialDocument_100 performs 100 shallow clones per iteration.
+func BenchmarkMakePartialDocument_100(b *testing.B) {
+ benchmarkMakePartialDocumentN(b, 100)
}
-// BenchmarkCloneDocumentShort_1000 performs 1000 shallow clones per iteration.
-func BenchmarkCloneDocumentShort_1000(b *testing.B) {
- benchmarkCloneDocumentShortN(b, 1000)
+// BenchmarkMakePartialDocument_1000 performs 1000 shallow clones per iteration.
+func BenchmarkMakePartialDocument_1000(b *testing.B) {
+ benchmarkMakePartialDocumentN(b, 1000)
}
// Helper for benchmarking cloneDocument with n clones per iteration.
@@ -1009,32 +1034,32 @@ func benchmarkCloneDocumentN(b *testing.B, n int) {
Embedding: []float32{1.0, 2.0, 3.0},
Content: "benchmark content",
}
- var res Document
+ var res *Document
b.ResetTimer()
for i := 0; i < b.N; i++ {
for j := 0; j < n; j++ {
res = cloneDocument(doc)
}
}
- _ = res // prevent compiler optimization
+ globalDoc = res // prevent compiler optimization
}
-// Helper for benchmarking cloneDocumentShort with n clones per iteration.
-func benchmarkCloneDocumentShortN(b *testing.B, n int) {
+// Helper for benchmarking makePartialDocument with n clones per iteration.
+func benchmarkMakePartialDocumentN(b *testing.B, n int) {
doc := &Document{
ID: "bench",
Metadata: map[string]string{"foo": "bar", "baz": "qux"},
Embedding: []float32{1.0, 2.0, 3.0},
Content: "benchmark content",
}
- var res Document
+ var res *Document
b.ResetTimer()
for i := 0; i < b.N; i++ {
for j := 0; j < n; j++ {
- res = cloneDocumentShort(doc)
+ res = makePartialDocument(doc)
}
}
- _ = res // prevent compiler optimization
+ globalDoc = res // prevent compiler optimization
}
// randomString returns a random string of length n using lowercase letters and space.
|
|
I pushed the diff now. It appears above the above comment, as if it was pushed before, but that's only from my local commit date being from that time. As I warned above, this push was to your To avoid that you can create separate feature branches and create PRs to upstream from those. Then the GitHub feature "allow edits by maintainers" only allows changes to those branches, and not to your main. Merging this now. Thanks again for your contribution! 🙇♂️ |
|
PS: Now after the merge, feel free to remove my commit from your |
Hi, I am using chromem-go as the RAG engine in my gaming agent project; to enable dynamic actions, character memory and basic management of RAG data for users.
I added 2 methods to my fork which might be useful to be added to the main branch.
GetAllDocuments: Simple as it sounds. Returns all documents in a collection. Has an option for deep fetching, which includes metadata and embeddings. If it's set to false, it will do shallow fetching with only ID and content string.GetDocumentsByMetadata: Fetches all documents in a collection which match one or more tags without performing a similarity search likeQuery. I use it to sync with an external source of truth, e.g. if actions get added / dropped or changed on game side, so my application always is in sync when performing queries.I didn't do any performance testing (yet), so I don't know how this behaves on a larger scale.