@@ -14,6 +14,19 @@ import (
1414 "github.com/tmc/langchaingo/vectorstores"
1515)
1616
17+ const (
18+ // pgLockIDEmbeddingTable is used for advisor lock to fix issue arising from concurrent
19+ // creation of the embedding table.The same value represents the same lock.
20+ pgLockIDEmbeddingTable = 1573678846307946494
21+ // pgLockIDCollectionTable is used for advisor lock to fix issue arising from concurrent
22+ // creation of the collection table.The same value represents the same lock.
23+ pgLockIDCollectionTable = 1573678846307946495
24+ // pgLockIDExtension is used for advisor lock to fix issue arising from concurrent creation
25+ // of the vector extension. The value is deliberately set to the same as python langchain
26+ // https://github.com/langchain-ai/langchain/blob/v0.0.340/libs/langchain/langchain/vectorstores/pgvector.py#L167
27+ pgLockIDExtension = 1573678846307946496
28+ )
29+
1730var (
1831 ErrEmbedderWrongNumberVectors = errors .New ("number of vectors from embedder does not match number of documents" )
1932 ErrInvalidScoreThreshold = errors .New ("score threshold must be between 0 and 1" )
@@ -65,7 +78,7 @@ func New(ctx context.Context, opts ...Option) (Store, error) {
6578 return Store {}, err
6679 }
6780 }
68- if store . collectionUUID , err = store .createOrGetCollection (ctx ); err != nil {
81+ if err = store .createOrGetCollection (ctx ); err != nil {
6982 return Store {}, err
7083 }
7184 return store , nil
@@ -83,7 +96,7 @@ func (s Store) createVectorExtensionIfNotExists(ctx context.Context) error {
8396 // https://github.com/langchain-ai/langchain/issues/12933
8497 // For more information see:
8598 // https://www.postgresql.org/docs/16/explicit-locking.html#ADVISORY-LOCKS
86- if _ , err := tx .Exec (ctx , "SELECT pg_advisory_xact_lock(1573678846307946495)" ); err != nil {
99+ if _ , err := tx .Exec (ctx , "SELECT pg_advisory_xact_lock($1)" , pgLockIDExtension ); err != nil {
87100 return err
88101 }
89102 if _ , err := tx .Exec (ctx , "CREATE EXTENSION IF NOT EXISTS vector" ); err != nil {
@@ -104,7 +117,7 @@ func (s Store) createCollectionTableIfNotExists(ctx context.Context) error {
104117 // https://github.com/langchain-ai/langchain/issues/12933
105118 // For more information see:
106119 // https://www.postgresql.org/docs/16/explicit-locking.html#ADVISORY-LOCKS
107- if _ , err = tx .Exec (ctx , "SELECT pg_advisory_xact_lock(1573678846307946494)" ); err != nil {
120+ if _ , err = tx .Exec (ctx , "SELECT pg_advisory_xact_lock($1)" , pgLockIDCollectionTable ); err != nil {
108121 return err
109122 }
110123 sql := fmt .Sprintf (`CREATE TABLE IF NOT EXISTS %s (
@@ -130,7 +143,7 @@ func (s Store) createEmbeddingTableIfNotExists(ctx context.Context) error {
130143 // https://github.com/langchain-ai/langchain/issues/12933
131144 // For more information see:
132145 // https://www.postgresql.org/docs/16/explicit-locking.html#ADVISORY-LOCKS
133- if _ , err := tx .Exec (ctx , "SELECT pg_advisory_xact_lock(1573678846307946493)" ); err != nil {
146+ if _ , err := tx .Exec (ctx , "SELECT pg_advisory_xact_lock($1)" , pgLockIDEmbeddingTable ); err != nil {
134147 return err
135148 }
136149 sql := fmt .Sprintf (`CREATE TABLE IF NOT EXISTS %s (
@@ -237,11 +250,11 @@ FROM (
237250WHERE %s
238251ORDER BY
239252 data.distance
240- LIMIT %d ` , s .embeddingTableName ,
253+ LIMIT $2 ` , s .embeddingTableName ,
241254 s .embeddingTableName ,
242255 s .collectionTableName , s .embeddingTableName , s .collectionTableName , s .collectionTableName , collectionName ,
243- whereQuery , numDocuments )
244- rows , err := tx .Query (ctx , sql , pgvector .NewVector (embedderData ))
256+ whereQuery )
257+ rows , err := tx .Query (ctx , sql , pgvector .NewVector (embedderData ), numDocuments )
245258 if err != nil {
246259 return nil , err
247260 }
@@ -274,30 +287,21 @@ func (s Store) DropTables(ctx context.Context) error {
274287}
275288
276289func (s Store ) RemoveCollection (ctx context.Context ) error {
277- _ , err := s .conn .Exec (ctx , fmt .Sprintf (`DELETE FROM %s WHERE name = '%s' ` , s .collectionTableName , s .collectionName ) )
290+ _ , err := s .conn .Exec (ctx , fmt .Sprintf (`DELETE FROM %s WHERE name = $1 ` , s .collectionTableName ) , s .collectionName )
278291 return err
279292}
280293
281- func (s Store ) createOrGetCollection (ctx context.Context ) ( string , error ) {
294+ func (s * Store ) createOrGetCollection (ctx context.Context ) error {
282295 sql := fmt .Sprintf (`INSERT INTO %s (uuid, name, cmetadata)
283296 VALUES($1, $2, $3) ON CONFLICT DO NOTHING` , s .collectionTableName )
284- _ , err := s .conn .Exec (ctx , sql , uuid .New ().String (), s .collectionName , s .collectionMetadata )
285- if err != nil {
286- return "" , err
297+ if _ , err := s .conn .Exec (ctx , sql , uuid .New ().String (), s .collectionName , s .collectionMetadata ); err != nil {
298+ return err
287299 }
288300 sql = fmt .Sprintf (`SELECT uuid FROM %s WHERE name = $1 ORDER BY name limit 1` , s .collectionTableName )
289- rows , err := s .conn .Query (ctx , sql , s .collectionName )
290- if err != nil {
291- return "" , err
292- }
293- defer rows .Close ()
294- var collectionUUID string
295- for rows .Next () {
296- if err = rows .Scan (& collectionUUID ); err != nil {
297- return "" , err
298- }
301+ if err := s .conn .QueryRow (ctx , sql , s .collectionName ).Scan (& s .collectionUUID ); err != nil {
302+ return err
299303 }
300- return collectionUUID , nil
304+ return nil
301305}
302306
303307// getOptions applies given options to default Options and returns it
0 commit comments