@@ -107,6 +107,118 @@ def search_index(self, index, xq, normalize=False, k=100):
         distances, indices = index.search(faiss_query_vector, k)
         return distances, indices
 
+class SparseVector(object):
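+    """
+    Helper for a sparse (CSR) ANN benchmark dataset: downloads the base
+    vectors, queries, and groundtruth files, and reads them back as
+    scipy.sparse matrices / numpy arrays.
+    """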
+    def __init__(self, version="small"):
+        """
+        version: one of {"small", "1M", "full"}
+        """
+        # Supported sparse dataset versions: (num_vectors, base filename, groundtruth filename)
+        versions = {
+            "small": (100000, "base_small.csr.gz", "base_small.dev.gt"),
+            "1M": (1000000, "base_1M.csr.gz", "base_1M.dev.gt"),
+            "full": (8841823, "base_full.csr.gz", "base_full.dev.gt"),
+        }
+        assert version in versions, f'version="{version}" is invalid. Please choose one of {list(versions.keys())}.'
+
+        self.version = version
+        self.nb = versions[version][0]
+        self.ds_fn = versions[version][1]
+        self.gt_fn = versions[version][2]
+        self.nq = 6980
+        self.dataset_location = os.path.abspath("data/sparse")
+        self.base_url = "https://storage.googleapis.com/ann-challenge-sparse-vectors/csr/"
+        self.qs_fn = "queries.dev.csr.gz"
+
+        # Make sure the data directory exists
+        if not os.path.exists(self.dataset_location):
+            os.makedirs(self.dataset_location)
+
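+    # Typical usage (illustrative):
+    #   sv = SparseVector("small")
+    #   sv.download_dataset()
+    #   X = sv.read_vector()   # csr_matrix with sv.nb rows
+    #   Q = sv.read_query()    # csr_matrix with sv.nq rows
+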
+    def download_dataset(self):
+        """
+        Download and unzip the sparse dataset, query, and groundtruth files, if not present.
+        """
+        import gzip
+
+        file_list = [self.ds_fn, self.qs_fn, self.gt_fn]
+        for fn in file_list:
+            # Download gzipped file; base_url ends with "/", so simple
+            # concatenation yields the full URL
+            url = self.base_url + fn
+            gz_path = os.path.join(self.dataset_location, fn)
+            unzipped_path = gz_path[:-3] if gz_path.endswith(".gz") else gz_path
+
+            if os.path.exists(unzipped_path):
+                print(f"Unzipped file already exists: {unzipped_path}")
+                continue
+            if os.path.exists(gz_path):
+                print(f"Gzipped file already exists: {gz_path}")
+            else:
+                print(f"Downloading {url} -> {gz_path}")
+                request.urlretrieve(url, gz_path)
+            # Unzip if needed
+            if gz_path.endswith(".gz"):
+                print(f"Unzipping {gz_path}...")
+                with gzip.open(gz_path, 'rb') as f_in, open(unzipped_path, 'wb') as f_out:
+                    shutil.copyfileobj(f_in, f_out)
+                os.remove(gz_path)
+                print(f"Done unzipping: {unzipped_path}")
+
+    def read_vector(self):
+        """
+        Read the base dataset vectors as a scipy.sparse csr_matrix.
+        """
+        fname = os.path.join(self.dataset_location, self.ds_fn.replace(".gz", ""))
+        if not os.path.exists(fname):
+            raise FileNotFoundError(f"Dataset file not found: {fname}")
+        return self._read_sparse_matrix(fname)
+
+    def read_query(self):
+        """
+        Read the queries as a scipy.sparse csr_matrix.
+        """
+        fname = os.path.join(self.dataset_location, self.qs_fn.replace(".gz", ""))
+        if not os.path.exists(fname):
+            raise FileNotFoundError(f"Query file not found: {fname}")
+        return self._read_sparse_matrix(fname)
+
+    def read_groundtruth(self):
+        """
+        Read the groundtruth file and return (I, D) as in knn_result_read.
+        """
+        fname = os.path.join(self.dataset_location, self.gt_fn)
+        if not os.path.exists(fname):
+            raise FileNotFoundError(f"Groundtruth file not found: {fname}")
+        return self._knn_result_read(fname)
+
+    def _read_sparse_matrix(self, fname):
+        """
+        Internal: read a CSR matrix in 'spmat' sparse format.
+        """
+        from scipy.sparse import csr_matrix
+        with open(fname, "rb") as f:
+            sizes = np.fromfile(f, dtype='int64', count=3)
+            nrow, ncol, nnz = sizes
+            indptr = np.fromfile(f, dtype='int64', count=nrow + 1)
+            indices = np.fromfile(f, dtype='int32', count=nnz)
+            data = np.fromfile(f, dtype='float32', count=nnz)
+        return csr_matrix((data, indices, indptr), shape=(nrow, ncol))
+
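+    # On-disk CSR layout assumed by _read_sparse_matrix (little-endian):
+    #   int64[3]       nrow, ncol, nnz
+    #   int64[nrow+1]  indptr
+    #   int32[nnz]     column indices
+    #   float32[nnz]   values
+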
+    def _knn_result_read(self, fname):
+        """
+        Internal: read the groundtruth as in scripts/sparse.py:knn_result_read.
+        Returns (I, D) where I are the indices and D the distances.
+        """
+        n, d = np.fromfile(fname, dtype="uint32", count=2)
+        expected_size = 8 + n * d * (4 + 4)
+        file_size = os.stat(fname).st_size
+        assert file_size == expected_size, f"File size mismatch: expected {expected_size}, got {file_size}"
+        with open(fname, "rb") as f:
+            f.seek(8)
+            I = np.fromfile(f, dtype="int32", count=n * d).reshape(n, d)
+            D = np.fromfile(f, dtype="float32", count=n * d).reshape(n, d)
+        return I, D
+
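+# The groundtruth layout is: uint32 n, uint32 d, then int32[n*d] neighbor ids
+# followed by float32[n*d] distances. A minimal recall sketch (illustrative
+# only; `res_I` is assumed to be an (nq, k) array of ids from some search):
+#   gt_I, gt_D = SparseVector("small").read_groundtruth()
+#   recall = np.mean([len(set(res_I[i]) & set(gt_I[i, :k])) / k for i in range(len(res_I))])
+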
 class SiftVector(object):
     def __init__(self):
         self.dataset = "siftsmall"
@@ -147,36 +259,83 @@ def encode_vector(self, vector, is_bigendian=False):
             endian = '<'
         buf = struct.pack(f'{endian}%sf' % len(vector), *vector)
         return base64.b64encode(buf).decode()
-    def load_batch_documents(self, cluster, docs, batch, is_xattr=False, is_base64=False, is_bigendian=False, bucket='default', scope='_default', collection='_default', vector_field='vec'):
+
+    def encode_sparse_vector(self, indices, values, is_bigendian=False):
+        # Encodes two parallel lists: indices (int) and values (float)
+        if is_bigendian:
+            endian = '>'
+        else:
+            endian = '<'
+        idx_buf = struct.pack(f"{endian}{len(indices)}i", *indices)
+        val_buf = struct.pack(f"{endian}{len(values)}f", *values)
+        return [base64.b64encode(idx_buf).decode(), base64.b64encode(val_buf).decode()]
+
+    def load_batch_documents(
+        self,
+        cluster,
+        docs,
+        batch,
+        is_xattr=False,
+        is_base64=False,
+        is_bigendian=False,
+        vector_type='dense',
+        bucket='default',
+        scope='_default',
+        collection='_default',
+        vector_field='vec'
+    ):
+        """
+        Load a batch of documents into Couchbase, supporting both dense (default) and sparse vectors.
+        - If vector_type='dense', `docs` should be a sequence of ndarrays.
+        - If vector_type='sparse', `docs` should be a sequence of 2-tuples/lists (indices, values)
+          or scipy.sparse row-like objects exposing `.indices` and `.data`.
+        """
         cb = cluster.bucket(bucket)
         cb_coll = cb.scope(scope).collection(collection)
         documents = {}
         for is1, size in enumerate(cfg["sizes"]):
             for ib, brand in enumerate(cfg["brands"]):
                 documents = {}
                 for idx, x in enumerate(docs):
-                    vector = x.tolist()
-                    if is_base64:
-                        vector = self.encode_vector(vector, is_bigendian)
+                    if vector_type == 'sparse':
+                        # Expect x to be (indices, values) or a scipy.sparse row-like object
+                        if hasattr(x, "indices") and hasattr(x, "data"):
+                            indices = x.indices.tolist()
+                            values = x.data.tolist()
+                        elif isinstance(x, (tuple, list)) and len(x) == 2:
+                            indices, values = x
+                        else:
+                            raise ValueError("When vector_type='sparse', each doc must be (indices, values)")
+                        if is_base64:
+                            vector = self.encode_sparse_vector(indices, values, is_bigendian)
+                        else:
+                            vector = [indices, values]
+                    else:
+                        vector = x.tolist()
+                        if is_base64:
+                            vector = self.encode_vector(vector, is_bigendian)
                     key = f"vec_{brand}_{size}_{idx + batch}"
                     doc = {
                         "id": idx + batch,
-                        "size":size,
-                        "sizeidx":is1,
-                        "brand":brand,
-                        "brandidx":ib,
+                        "size": size,
+                        "sizeidx": is1,
+                        "brand": brand,
+                        "brandidx": ib,
                         vector_field: vector
                     }
-                    # if is_xattr:
-                    #     del doc[vector_field]
+                    # if is_xattr: remove vector_field from doc and store as xattr below
                     documents[key] = doc
                 try:
-                    upsert = cb_coll.upsert_multi(documents)
+                    cb_coll.upsert_multi(documents)
                 except Exception as e:
                     print(e)
                 if is_xattr:
                     for key in documents:
-                        cb_coll.mutate_in(key, [SD.upsert(vector_field, documents[key][vector_field], xattr=is_xattr), SD.remove(vector_field)])
+                        cb_coll.mutate_in(
+                            key,
+                            [SD.upsert(vector_field, documents[key][vector_field], xattr=True),
+                             SD.remove(vector_field)]
+                        )
+
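+    # Hypothetical call site tying the pieces together (`cluster` and the
+    # module-level `cfg` are assumed to exist; `loader` is an instance of
+    # this class):
+    #   sv = SparseVector("small")
+    #   X = sv.read_vector()
+    #   rows = [X.getrow(i) for i in range(1000)]
+    #   loader.load_batch_documents(cluster, rows, batch=0, vector_type='sparse', is_base64=True)
+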
     def multi_upsert_document_into_cb(self, cb_coll, documents):
         try:
             cb_coll.upsert_multi(documents)