Skip to content

Commit d9641b0

Browse files
committed
fix: Set pgLockID constant and Simplify createOrGetCollection func
Signed-off-by: Abirdcfly <[email protected]>
1 parent 1346747 commit d9641b0

File tree

2 files changed

+30
-25
lines changed

2 files changed

+30
-25
lines changed

vectorstores/pgvector/options.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"fmt"
66
"os"
77

8+
"github.com/jackc/pgx/v5"
89
"github.com/tmc/langchaingo/embeddings"
910
)
1011

@@ -52,14 +53,14 @@ func WithCollectionName(name string) Option {
5253
// WithEmbeddingTableName is an option for specifying the embedding table name.
5354
func WithEmbeddingTableName(name string) Option {
5455
return func(p *Store) {
55-
p.embeddingTableName = name
56+
p.embeddingTableName = pgx.Identifier{name}.Sanitize()
5657
}
5758
}
5859

5960
// WithCollectionTableName is an option for specifying the collection table name.
6061
func WithCollectionTableName(name string) Option {
6162
return func(p *Store) {
62-
p.collectionTableName = name
63+
p.collectionTableName = pgx.Identifier{name}.Sanitize()
6364
}
6465
}
6566

vectorstores/pgvector/pgvector.go

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,19 @@ import (
1414
"github.com/tmc/langchaingo/vectorstores"
1515
)
1616

17+
const (
18+
// pgLockIDEmbeddingTable is used for advisor lock to fix issue arising from concurrent
19+
// creation of the embedding table.The same value represents the same lock.
20+
pgLockIDEmbeddingTable = 1573678846307946494
21+
// pgLockIDCollectionTable is used for advisor lock to fix issue arising from concurrent
22+
// creation of the collection table.The same value represents the same lock.
23+
pgLockIDCollectionTable = 1573678846307946495
24+
// pgLockIDExtension is used for advisor lock to fix issue arising from concurrent creation
25+
// of the vector extension. The value is deliberately set to the same as python langchain
26+
// https://github.com/langchain-ai/langchain/blob/v0.0.340/libs/langchain/langchain/vectorstores/pgvector.py#L167
27+
pgLockIDExtension = 1573678846307946496
28+
)
29+
1730
var (
1831
ErrEmbedderWrongNumberVectors = errors.New("number of vectors from embedder does not match number of documents")
1932
ErrInvalidScoreThreshold = errors.New("score threshold must be between 0 and 1")
@@ -65,7 +78,7 @@ func New(ctx context.Context, opts ...Option) (Store, error) {
6578
return Store{}, err
6679
}
6780
}
68-
if store.collectionUUID, err = store.createOrGetCollection(ctx); err != nil {
81+
if err = store.createOrGetCollection(ctx); err != nil {
6982
return Store{}, err
7083
}
7184
return store, nil
@@ -83,7 +96,7 @@ func (s Store) createVectorExtensionIfNotExists(ctx context.Context) error {
8396
// https://github.com/langchain-ai/langchain/issues/12933
8497
// For more information see:
8598
// https://www.postgresql.org/docs/16/explicit-locking.html#ADVISORY-LOCKS
86-
if _, err := tx.Exec(ctx, "SELECT pg_advisory_xact_lock(1573678846307946495)"); err != nil {
99+
if _, err := tx.Exec(ctx, "SELECT pg_advisory_xact_lock($1)", pgLockIDExtension); err != nil {
87100
return err
88101
}
89102
if _, err := tx.Exec(ctx, "CREATE EXTENSION IF NOT EXISTS vector"); err != nil {
@@ -104,7 +117,7 @@ func (s Store) createCollectionTableIfNotExists(ctx context.Context) error {
104117
// https://github.com/langchain-ai/langchain/issues/12933
105118
// For more information see:
106119
// https://www.postgresql.org/docs/16/explicit-locking.html#ADVISORY-LOCKS
107-
if _, err = tx.Exec(ctx, "SELECT pg_advisory_xact_lock(1573678846307946494)"); err != nil {
120+
if _, err = tx.Exec(ctx, "SELECT pg_advisory_xact_lock($1)", pgLockIDCollectionTable); err != nil {
108121
return err
109122
}
110123
sql := fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s (
@@ -130,7 +143,7 @@ func (s Store) createEmbeddingTableIfNotExists(ctx context.Context) error {
130143
// https://github.com/langchain-ai/langchain/issues/12933
131144
// For more information see:
132145
// https://www.postgresql.org/docs/16/explicit-locking.html#ADVISORY-LOCKS
133-
if _, err := tx.Exec(ctx, "SELECT pg_advisory_xact_lock(1573678846307946493)"); err != nil {
146+
if _, err := tx.Exec(ctx, "SELECT pg_advisory_xact_lock($1)", pgLockIDEmbeddingTable); err != nil {
134147
return err
135148
}
136149
sql := fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s (
@@ -237,11 +250,11 @@ FROM (
237250
WHERE %s
238251
ORDER BY
239252
data.distance
240-
LIMIT %d`, s.embeddingTableName,
253+
LIMIT $2`, s.embeddingTableName,
241254
s.embeddingTableName,
242255
s.collectionTableName, s.embeddingTableName, s.collectionTableName, s.collectionTableName, collectionName,
243-
whereQuery, numDocuments)
244-
rows, err := tx.Query(ctx, sql, pgvector.NewVector(embedderData))
256+
whereQuery)
257+
rows, err := tx.Query(ctx, sql, pgvector.NewVector(embedderData), numDocuments)
245258
if err != nil {
246259
return nil, err
247260
}
@@ -274,30 +287,21 @@ func (s Store) DropTables(ctx context.Context) error {
274287
}
275288

276289
func (s Store) RemoveCollection(ctx context.Context) error {
277-
_, err := s.conn.Exec(ctx, fmt.Sprintf(`DELETE FROM %s WHERE name = '%s'`, s.collectionTableName, s.collectionName))
290+
_, err := s.conn.Exec(ctx, fmt.Sprintf(`DELETE FROM %s WHERE name = $1`, s.collectionTableName), s.collectionName)
278291
return err
279292
}
280293

281-
func (s Store) createOrGetCollection(ctx context.Context) (string, error) {
294+
func (s *Store) createOrGetCollection(ctx context.Context) error {
282295
sql := fmt.Sprintf(`INSERT INTO %s (uuid, name, cmetadata)
283296
VALUES($1, $2, $3) ON CONFLICT DO NOTHING`, s.collectionTableName)
284-
_, err := s.conn.Exec(ctx, sql, uuid.New().String(), s.collectionName, s.collectionMetadata)
285-
if err != nil {
286-
return "", err
297+
if _, err := s.conn.Exec(ctx, sql, uuid.New().String(), s.collectionName, s.collectionMetadata); err != nil {
298+
return err
287299
}
288300
sql = fmt.Sprintf(`SELECT uuid FROM %s WHERE name = $1 ORDER BY name limit 1`, s.collectionTableName)
289-
rows, err := s.conn.Query(ctx, sql, s.collectionName)
290-
if err != nil {
291-
return "", err
292-
}
293-
defer rows.Close()
294-
var collectionUUID string
295-
for rows.Next() {
296-
if err = rows.Scan(&collectionUUID); err != nil {
297-
return "", err
298-
}
301+
if err := s.conn.QueryRow(ctx, sql, s.collectionName).Scan(&s.collectionUUID); err != nil {
302+
return err
299303
}
300-
return collectionUUID, nil
304+
return nil
301305
}
302306

303307
// getOptions applies given options to default Options and returns it

0 commit comments

Comments
 (0)