
Commit 9d27f5b

Add sparse vector support
Change-Id: I0aea36341d2f409b6921460252dfe9b50613a15a
Reviewed-on: https://review.couchbase.org/c/testrunner/+/237455
Tested-by: Pierre Regazzoni <[email protected]>
Reviewed-by: Ajay Bhullar <[email protected]>
1 parent c2ffa29 commit 9d27f5b

2 files changed: +187, -18 lines

lib/vector/vector.py

Lines changed: 171 additions & 12 deletions
@@ -107,6 +107,118 @@ def search_index(self, index, xq, normalize=False, k=100):
         distances, indices = index.search(faiss_query_vector, k)
         return distances, indices

+class SparseVector(object):
+    def __init__(self, version="small"):
+        """
+        version: one of {"small", "1M", "full"}
+        """
+        # Supported sparse dataset versions
+        versions = {
+            "small": (100000, "base_small.csr.gz", "base_small.dev.gt"),
+            "1M": (1000000, "base_1M.csr.gz", "base_1M.dev.gt"),
+            "full": (8841823, "base_full.csr.gz", "base_full.dev.gt"),
+        }
+        assert version in versions, f'version="{version}" is invalid. Please choose one of {list(versions.keys())}.'
+
+        self.version = version
+        self.nb = versions[version][0]
+        self.ds_fn = versions[version][1]
+        self.gt_fn = versions[version][2]
+        self.nq = 6980
+        self.dataset_location = os.path.abspath("data/sparse")
+        self.base_url = "https://storage.googleapis.com/ann-challenge-sparse-vectors/csr/"
+        self.qs_fn = "queries.dev.csr.gz"
+
+        # Make sure the dataset directory exists
+        if not os.path.exists(self.dataset_location):
+            os.makedirs(self.dataset_location)
+
+    def download_dataset(self):
+        """
+        Download and unzip the sparse dataset, query and groundtruth files if not already present.
+        """
+        import gzip
+
+        file_list = [self.ds_fn, self.qs_fn, self.gt_fn]
+        for fn in file_list:
+            # Download the gzipped file
+            url = os.path.join(self.base_url, fn)
+            gz_path = os.path.join(self.dataset_location, fn)
+            unzipped_path = gz_path[:-3] if gz_path.endswith(".gz") else gz_path
+
+            if os.path.exists(unzipped_path):
+                print(f"Unzipped file already exists: {unzipped_path}")
+                continue
+            if os.path.exists(gz_path):
+                print(f"Gzipped file already exists: {gz_path}")
+            else:
+                print(f"Downloading {url} -> {gz_path}")
+                request.urlretrieve(url, gz_path)
+            # Unzip if needed
+            if gz_path.endswith(".gz"):
+                print(f"Unzipping {gz_path} ...")
+                with gzip.open(gz_path, 'rb') as f_in, open(unzipped_path, 'wb') as f_out:
+                    shutil.copyfileobj(f_in, f_out)
+                os.remove(gz_path)
+                print(f"Done unzipping: {unzipped_path}")
+
+    def read_vector(self):
+        """
+        Read the base dataset vectors as a csr_matrix.
+        """
+        fname = os.path.join(self.dataset_location, self.ds_fn.replace(".gz", ""))
+        if not os.path.exists(fname):
+            raise FileNotFoundError(f"Dataset file not found: {fname}")
+        return self._read_sparse_matrix(fname)
+
+    def read_query(self):
+        """
+        Read the query vectors as a csr_matrix.
+        """
+        fname = os.path.join(self.dataset_location, self.qs_fn.replace(".gz", ""))
+        if not os.path.exists(fname):
+            raise FileNotFoundError(f"Query file not found: {fname}")
+        return self._read_sparse_matrix(fname)
+
+    def read_groundtruth(self):
+        """
+        Read the groundtruth file and return (I, D) as in knn_result_read.
+        """
+        fname = os.path.join(self.dataset_location, self.gt_fn)
+        if not os.path.exists(fname):
+            raise FileNotFoundError(f"Groundtruth file not found: {fname}")
+        return self._knn_result_read(fname)
+
+    def _read_sparse_matrix(self, fname):
+        """
+        Internal: read a CSR matrix stored in the 'spmat' sparse format.
+        """
+        from scipy.sparse import csr_matrix
+        with open(fname, "rb") as f:
+            sizes = np.fromfile(f, dtype='int64', count=3)
+            nrow, ncol, nnz = sizes
+            indptr = np.fromfile(f, dtype='int64', count=nrow + 1)
+            indices = np.fromfile(f, dtype='int32', count=nnz)
+            data = np.fromfile(f, dtype='float32', count=nnz)
+        return csr_matrix((data, indices, indptr), shape=(nrow, ncol))
+
+    def _knn_result_read(self, fname):
+        """
+        Internal: read the groundtruth as in scripts/sparse.py knn_result_read.
+        Returns (I, D) where I are the neighbor indices and D the distances.
+        """
+        n, d = np.fromfile(fname, dtype="uint32", count=2)
+        expected_size = 8 + n * d * (4 + 4)
+        file_size = os.stat(fname).st_size
+        assert file_size == expected_size, f"File size mismatch: expected {expected_size}, got {file_size}"
+        with open(fname, "rb") as f:
+            f.seek(8)
+            I = np.fromfile(f, dtype="int32", count=n * d).reshape(n, d)
+            D = np.fromfile(f, dtype="float32", count=n * d).reshape(n, d)
+        return I, D
+
 class SiftVector(object):
     def __init__(self):
         self.dataset = "siftsmall"
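
For reference, the .csr files above use the 'spmat' binary layout that _read_sparse_matrix expects: three int64 values (nrow, ncol, nnz), then nrow + 1 int64 indptr entries, nnz int32 column indices, and nnz float32 values; the .dev.gt groundtruth file is two uint32 counts (n, d) followed by n*d int32 neighbor ids and n*d float32 distances. The standalone sketch below is not part of the commit (the helper names and the toy matrix are illustrative), but it round-trips a tiny CSR matrix through the same layout, which is a quick way to sanity-check the parser without downloading a dataset.

    import numpy as np
    from scipy.sparse import csr_matrix

    def write_sparse_matrix(mat, fname):
        # Serialize a csr_matrix in the layout _read_sparse_matrix parses:
        # int64 header (nrow, ncol, nnz), int64 indptr, int32 indices, float32 data.
        with open(fname, "wb") as f:
            np.array([mat.shape[0], mat.shape[1], mat.nnz], dtype='int64').tofile(f)
            mat.indptr.astype('int64').tofile(f)
            mat.indices.astype('int32').tofile(f)
            mat.data.astype('float32').tofile(f)

    def read_sparse_matrix(fname):
        # Mirrors SparseVector._read_sparse_matrix for a quick round-trip check.
        with open(fname, "rb") as f:
            nrow, ncol, nnz = np.fromfile(f, dtype='int64', count=3)
            indptr = np.fromfile(f, dtype='int64', count=nrow + 1)
            indices = np.fromfile(f, dtype='int32', count=nnz)
            data = np.fromfile(f, dtype='float32', count=nnz)
        return csr_matrix((data, indices, indptr), shape=(nrow, ncol))

    # Toy 3 x 5 matrix with four non-zeros, written and read back.
    toy = csr_matrix(np.array([[0.0, 1.5, 0.0, 0.0, 2.0],
                               [0.0, 0.0, 0.0, 0.0, 0.0],
                               [3.0, 0.0, 0.0, 0.5, 0.0]], dtype='float32'))
    write_sparse_matrix(toy, "toy.csr")
    restored = read_sparse_matrix("toy.csr")
    assert (toy != restored).nnz == 0
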
@@ -147,36 +259,83 @@ def encode_vector(self, vector, is_bigendian=False):
             endian = '<'
         buf = struct.pack(f'{endian}%sf' % len(vector), *vector)
         return base64.b64encode(buf).decode()
-    def load_batch_documents(self, cluster, docs, batch, is_xattr=False, is_base64=False, is_bigendian=False, bucket='default', scope='_default', collection='_default', vector_field='vec'):
+
+    def encode_sparse_vector(self, indices, values, is_bigendian=False):
+        # encodes two lists: indices (int) and values (float)
+        if is_bigendian:
+            endian = '>'
+        else:
+            endian = '<'
+        idx_buf = struct.pack(f"{endian}{len(indices)}i", *indices)
+        val_buf = struct.pack(f"{endian}{len(values)}f", *values)
+        return [base64.b64encode(idx_buf).decode(), base64.b64encode(val_buf).decode()]
+
+    def load_batch_documents(
+            self,
+            cluster,
+            docs,
+            batch,
+            is_xattr=False,
+            is_base64=False,
+            is_bigendian=False,
+            vector_type='dense',
+            bucket='default',
+            scope='_default',
+            collection='_default',
+            vector_field='vec'
+    ):
+        """
+        Loads batch docs into Couchbase, supporting both dense (default) and sparse.
+        - If vector_type='dense', `docs` should be a sequence of ndarrays.
+        - If vector_type='sparse', `docs` should be a sequence of 2-tuples/lists: (indices, values)
+        """
         cb = cluster.bucket(bucket)
         cb_coll = cb.scope(scope).collection(collection)
         documents = {}
         for is1, size in enumerate(cfg["sizes"]):
             for ib, brand in enumerate(cfg["brands"]):
                 documents = {}
                 for idx, x in enumerate(docs):
-                    vector = x.tolist()
-                    if is_base64:
-                        vector = self.encode_vector(vector, is_bigendian)
+                    if vector_type == 'sparse':
+                        # Expect x to be (indices, values) or a scipy.sparse row-like object
+                        if hasattr(x, "indices") and hasattr(x, "data"):
+                            indices = x.indices.tolist()
+                            values = x.data.tolist()
+                        elif isinstance(x, (tuple, list)) and len(x) == 2:
+                            indices, values = x
+                        else:
+                            raise ValueError("When vector_type='sparse', each doc must be (indices, values)")
+                        if is_base64:
+                            vector = self.encode_sparse_vector(indices, values, is_bigendian)
+                        else:
+                            vector = [indices, values]
+                    else:
+                        vector = x.tolist()
+                        if is_base64:
+                            vector = self.encode_vector(vector, is_bigendian)
                     key = f"vec_{brand}_{size}_{idx+batch}"
                     doc = {
                         "id": idx + batch,
-                        "size":size,
-                        "sizeidx":is1,
-                        "brand":brand,
-                        "brandidx":ib,
+                        "size": size,
+                        "sizeidx": is1,
+                        "brand": brand,
+                        "brandidx": ib,
                         vector_field: vector
                     }
-                    # if is_xattr:
-                    #     del doc[vector_field]
+                    # if is_xattr: remove vector_field from doc and store as xattr below
                     documents[key] = doc
                 try:
-                    upsert = cb_coll.upsert_multi(documents)
+                    cb_coll.upsert_multi(documents)
                 except Exception as e:
                     print(e)
                 if is_xattr:
                     for key in documents:
-                        cb_coll.mutate_in(key, [SD.upsert(vector_field, documents[key][vector_field], xattr=is_xattr), SD.remove(vector_field)])
+                        cb_coll.mutate_in(
+                            key,
+                            [SD.upsert(vector_field, documents[key][vector_field], xattr=True),
+                             SD.remove(vector_field)]
+                        )
+
     def multi_upsert_document_into_cb(self, cb_coll, documents):
         try:
             cb_coll.upsert_multi(documents)
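
encode_sparse_vector above packs the indices as 32-bit ints and the values as 32-bit floats (little-endian by default) and base64-encodes each buffer separately, so a sparse vector is stored as a two-element list of strings. A minimal round-trip sketch, not part of the commit (the decode helper is illustrative only), showing how a consumer could recover the original lists:

    import base64
    import struct

    def encode_sparse_vector(indices, values, is_bigendian=False):
        # Same packing as LoadVector.encode_sparse_vector in the diff above.
        endian = '>' if is_bigendian else '<'
        idx_buf = struct.pack(f"{endian}{len(indices)}i", *indices)
        val_buf = struct.pack(f"{endian}{len(values)}f", *values)
        return [base64.b64encode(idx_buf).decode(), base64.b64encode(val_buf).decode()]

    def decode_sparse_vector(encoded, is_bigendian=False):
        # Hypothetical inverse: 4 bytes per int32 index and per float32 value.
        endian = '>' if is_bigendian else '<'
        idx_buf = base64.b64decode(encoded[0])
        val_buf = base64.b64decode(encoded[1])
        indices = list(struct.unpack(f"{endian}{len(idx_buf) // 4}i", idx_buf))
        values = list(struct.unpack(f"{endian}{len(val_buf) // 4}f", val_buf))
        return indices, values

    encoded = encode_sparse_vector([3, 17, 42], [0.5, 1.25, 2.0])
    assert decode_sparse_vector(encoded) == ([3, 17, 42], [0.5, 1.25, 2.0])
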

pytests/tuqquery/tuq_vectorsearch.py

Lines changed: 16 additions & 6 deletions
@@ -8,6 +8,7 @@
 from couchbase.auth import PasswordAuthenticator
 from couchbase.cluster import Cluster, ClusterOptions
 from lib.vector.vector import SiftVector as sift, FAISSVector as faiss
+from lib.vector.vector import SparseVector as sparse
 from lib.vector.vector import LoadVector, QueryVector, UtilVector, IndexVector

 class VectorSearchTests(QueryTests):
@@ -32,13 +33,22 @@ def setUp(self):
         self.use_bhive = self.input.param("use_bhive", False)
         self.rerank = self.input.param("rerank", False)
         self.train = self.input.param("train", 10000)
+        self.vector_type = self.input.param("vector_type", "dense")
         auth = PasswordAuthenticator(self.master.rest_username, self.master.rest_password)
         self.database = Cluster(f'couchbase://{self.master.ip}', ClusterOptions(auth))
         # Get dataset
-        sift().download_sift()
-        self.xb = sift().read_base()
-        self.xq = sift().read_query()
-        self.gt = sift().read_groundtruth()
+        if self.vector_type == "dense":
+            sift().download_sift()
+            self.xb = sift().read_base()
+            self.xq = sift().read_query()
+            self.gt = sift().read_groundtruth()
+        elif self.vector_type == "sparse":
+            sparse().download_dataset()
+            self.xb = sparse().read_vector()
+            self.xq = sparse().read_query()
+            self.gt = sparse().read_groundtruth()
+        else:
+            self.fail(f"Invalid vector type: {self.vector_type}")
         # Extend dimension beyond 128
         if self.dimension > 128:
             add_dimension = self.dimension - 128
@@ -52,8 +62,8 @@ def suite_setUp(self):
         super(VectorSearchTests, self).suite_setUp()
         threads = []
         self.log.info("Start loading vector data")
-        for i in range(0, len(self.xb), 1000):  # load in batches of 1000 docs per thread
-            thread = threading.Thread(target=LoadVector().load_batch_documents, args=(self.database, self.xb[i:i+1000], i, self.use_xattr, self.use_base64, self.use_bigendian))
+        for i in range(0, self.xb.shape[0], 1000):  # load in batches of 1000 docs per thread
+            thread = threading.Thread(target=LoadVector().load_batch_documents, args=(self.database, self.xb[i:i+1000], i, self.use_xattr, self.use_base64, self.use_bigendian, self.vector_type))
             threads.append(thread)
         # Start threads
         for i in range(len(threads)):
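
When vector_type is "sparse", self.xb is a scipy csr_matrix rather than a numpy array, so self.xb.shape[0] counts rows and self.xb[i:i+1000] slices a 1000-row batch; iterating that batch inside load_batch_documents yields one-row sparse matrices whose .indices and .data carry exactly the (indices, values) pair the loader expects. A standalone sketch of what each worker thread effectively processes (the synthetic matrix and its sizes are made up for illustration):

    import numpy as np
    from scipy.sparse import random as sparse_random

    # Synthetic stand-in for the downloaded base set: 2500 docs, 30000 dims, ~0.1% density.
    xb = sparse_random(2500, 30000, density=0.001, format='csr', dtype=np.float32)

    batch_size = 1000
    for i in range(0, xb.shape[0], batch_size):
        batch = xb[i:i + batch_size]          # still a csr_matrix holding rows i .. i+999
        for idx, row in enumerate(batch):     # each row is a 1 x 30000 csr_matrix
            indices = row.indices.tolist()    # column ids of the non-zero entries
            values = row.data.tolist()        # matching float values
            assert len(indices) == len(values)
            # load_batch_documents() stores [indices, values] (or their base64
            # encoding) under vector_field for the doc built from row idx + i.
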
