Skip to content
44 changes: 44 additions & 0 deletions collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,50 @@ func (c *Collection) queryEmbedding(ctx context.Context, queryEmbedding, negativ
return res, nil
}

func (c *Collection) GetAllDocuments(_ context.Context, fetchDeep bool) ([]Document, error) {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

WDYT about dropping the fetchDeep param, and always using the clone? That would match what Collection.GetByID currently does.

Or do you already use both in your app?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not an issue to refactor my code as this lib evolves. :-)

To match both the existing paradigm to avoid unnecessary params that change the behaviour of the function; but also provide a fetch without copying the whole metadata (can get bloaty on large collections), I refactored the shallow clone into a separate method ListDocumentsShort, so both is possible to do but with optimized performance cycles. Happy to hear what you think of this idea. :-)

c.documentsLock.RLock()
defer c.documentsLock.RUnlock()

results := make([]Document, 0, len(c.documents))
for _, doc := range c.documents {
// Clone the document to avoid concurrent modification by reading goroutine
docCopy := *doc
if fetchDeep {
docCopy.Metadata = maps.Clone(doc.Metadata)
docCopy.Embedding = slices.Clone(doc.Embedding)
} else {
docCopy.Metadata = nil
docCopy.Embedding = nil
}
results = append(results, docCopy)
}
return results, nil
}

func (c *Collection) GetDocumentsByMetadata(_ context.Context, where map[string]string) ([]Document, error) {
c.documentsLock.RLock()
defer c.documentsLock.RUnlock()

var results []Document
for _, doc := range c.documents {
match := true
for key, value := range where {
if docVal, ok := doc.Metadata[key]; !ok || docVal != value {
match = false
break
}
}
if match {
// Clone the document to avoid concurrent modification by reading goroutine
docCopy := *doc
docCopy.Metadata = maps.Clone(doc.Metadata)
docCopy.Embedding = slices.Clone(doc.Embedding)
results = append(results, docCopy)
}
}
return results, nil
}

// getDocPath generates the path to the document file.
func (c *Collection) getDocPath(docID string) string {
safeID := hash2hex(docID)
Expand Down
139 changes: 139 additions & 0 deletions collection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,145 @@ func TestCollection_Delete(t *testing.T) {
checkCount(0)
}

// TestCollection_GetAllDocuments verifies that GetAllDocuments returns all documents
// and that the returned documents are deep-copies (mutating them must not affect
// the collection’s internal state).
func TestCollection_GetAllDocuments(t *testing.T) {
ctx := context.Background()

// Fixed embedding so we can compare easily.
embedVec := []float32{0.0, 1.0, 0.0}
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
return embedVec, nil
}

// Create collection.
db := NewDB()
coll, err := db.CreateCollection("test", nil, embeddingFunc)
if err != nil {
t.Fatalf("unexpected error creating collection: %v", err)
}

// Add two documents (one with explicit embedding, one relying on embeddingFunc).
docs := []Document{
{ID: "1", Metadata: map[string]string{"foo": "bar"}, Embedding: embedVec, Content: "hello"},
{ID: "2", Metadata: map[string]string{"baz": "qux"}, Content: "world"},
}
for _, d := range docs {
if err := coll.AddDocument(ctx, d); err != nil {
t.Fatalf("unexpected error adding document %q: %v", d.ID, err)
}
}

// ------------------------------------------------------------------
// Deep fetch (fetchDeep = true)
// ------------------------------------------------------------------
gotDeep, err := coll.GetAllDocuments(ctx, true)
if err != nil {
t.Fatalf("unexpected error from GetAllDocuments (deep): %v", err)
}
if len(gotDeep) != len(docs) {
t.Fatalf("deep: expected %d docs, got %d", len(docs), len(gotDeep))
}

// Map for convenient lookup.
deepByID := make(map[string]Document, len(gotDeep))
for _, d := range gotDeep {
deepByID[d.ID] = d
}

for _, want := range docs {
got, ok := deepByID[want.ID]
if !ok {
t.Fatalf("deep: doc %q not found", want.ID)
}
if got.Content != want.Content {
t.Fatalf("deep: doc %q: expected content %q, got %q", want.ID, want.Content, got.Content)
}
if !slices.Equal(got.Embedding, embedVec) {
t.Fatalf("deep: doc %q: embeddings differ, expected %v got %v", want.ID, embedVec, got.Embedding)
}
for k, v := range want.Metadata {
if got.Metadata[k] != v {
t.Fatalf("deep: doc %q: expected metadata %q=%q, got %q", want.ID, k, v, got.Metadata[k])
}
}
}

// Mutate deep copy and ensure collection is untouched.
gotDeep[0].Metadata["foo"] = "mutated"
orig, _ := coll.GetByID(ctx, "1")
if orig.Metadata["foo"] != "bar" {
t.Fatalf("deep: mutation leaked into collection: expected \"bar\", got %q", orig.Metadata["foo"])
}

// ------------------------------------------------------------------
// Shallow fetch (fetchDeep = false)
// ------------------------------------------------------------------
gotShallow, err := coll.GetAllDocuments(ctx, false)
if err != nil {
t.Fatalf("unexpected error from GetAllDocuments (shallow): %v", err)
}
if len(gotShallow) != len(docs) {
t.Fatalf("shallow: expected %d docs, got %d", len(docs), len(gotShallow))
}
for _, d := range gotShallow {
if d.Metadata != nil {
t.Fatalf("shallow: expected Metadata to be nil, got %#v", d.Metadata)
}
if d.Embedding != nil {
t.Fatalf("shallow: expected Embedding to be nil, got %#v", d.Embedding)
}
// Content and ID must still be present.
if d.Content == "" || d.ID == "" {
t.Fatalf("shallow: expected ID and Content to be set, got %+v", d)
}
}
}

func TestCollection_GetDocumentsByMetadata(t *testing.T) {
ctx := context.Background()

// Create collection
db := NewDB()
name := "test"
metadata := map[string]string{"foo": "bar"}
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
return []float32{1.0, 2.0, 3.0}, nil
}
c, err := db.CreateCollection(name, metadata, embeddingFunc)
if err != nil {
t.Fatal("expected no error, got", err)
}

// Add documents
docs := []Document{
{ID: "1", Metadata: map[string]string{"type": "article", "lang": "en"}, Content: "Hello World"},
{ID: "2", Metadata: map[string]string{"type": "article", "lang": "fr"}, Content: "Bonjour le monde"},
{ID: "3", Metadata: map[string]string{"type": "blog", "lang": "en"}, Content: "My blog post"},
}
for _, doc := range docs {
err := c.AddDocument(ctx, doc)
if err != nil {
t.Fatal("expected no error, got", err)
}
}

// Filter by metadata
where := map[string]string{"type": "article", "lang": "en"}
results, err := c.GetDocumentsByMetadata(ctx, where)
if err != nil {
t.Fatal("expected no error, got", err)
}

if len(results) != 1 {
t.Fatalf("expected 1 result, got %d", len(results))
}
if results[0].ID != "1" {
t.Fatalf("expected document ID '1', got '%s'", results[0].ID)
}
}

// Global var for assignment in the benchmark to avoid compiler optimizations.
var globalRes []Result

Expand Down