Skip to content

Commit 94a5f74

Browse files
authored
Merge pull request #199 from databio/bug_fix_r2v_scembed
fix tokenizer creation issue with ScEmbed (#198)
2 parents 11ce521 + ded5b72 commit 94a5f74

File tree

4 files changed

+12
-12
lines changed

4 files changed

+12
-12
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,4 +197,5 @@ qdrant_storage/
197197

198198
local_cache
199199

200-
lightning_logs
200+
lightning_logs
201+
data/

geniml/region2vec/main.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -140,11 +140,12 @@ def _load_local_model(self, model_path: str, vocab_path: str, config_path: str):
140140
:param str model_path: Path to the model checkpoint.
141141
:param str vocab_path: Path to the vocabulary file.
142142
"""
143-
_model, tokenizer, config = load_local_region2vec_model(
144-
model_path, vocab_path, config_path
145-
)
143+
_model, config = load_local_region2vec_model(model_path, config_path)
144+
tokenizer = TreeTokenizer(vocab_path)
145+
146146
self._model = _model
147147
self.tokenizer = tokenizer
148+
148149
self.trained = True
149150
if POOLING_METHOD_KEY in config:
150151
self.pooling_method = config[POOLING_METHOD_KEY]

geniml/region2vec/utils.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -452,18 +452,15 @@ def export_region2vec_model(
452452

453453
def load_local_region2vec_model(
454454
model_path: str,
455-
vocab_path: str,
456455
config_path: str,
457-
) -> Tuple[Region2Vec, TreeTokenizer, dict]:
456+
) -> Tuple[Region2Vec, dict]:
458457
"""
459458
Load a region2vec model from a local directory
460459
461460
:param str model_path: The path to the model checkpoint file
462461
:param str config_path: The path to the model config file
463462
:param str vocab_path: The path to the model vocabulary file
464463
"""
465-
# init the tokenizer - only one option for now
466-
tokenizer = TreeTokenizer(vocab_path)
467464

468465
# load the model state dict (weights)
469466
params = torch.load(model_path)
@@ -491,7 +488,7 @@ def load_local_region2vec_model(
491488

492489
model.load_state_dict(params)
493490

494-
return model, tokenizer, config
491+
return model, config
495492

496493

497494
class Region2VecDataset:

geniml/scembed/main.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -131,11 +131,12 @@ def _load_local_model(self, model_path: str, vocab_path: str, config_path: str):
131131
:param str model_path: Path to the model checkpoint.
132132
:param str vocab_path: Path to the vocabulary file.
133133
"""
134-
_model, tokenizer, config = load_local_region2vec_model(
135-
model_path, vocab_path, config_path
136-
)
134+
_model, config = load_local_region2vec_model(model_path, config_path)
135+
tokenizer = AnnDataTokenizer(vocab_path, verbose=True)
136+
137137
self._model = _model
138138
self.tokenizer = tokenizer
139+
139140
if POOLING_METHOD_KEY in config:
140141
self.pooling_method = config[POOLING_METHOD_KEY]
141142

0 commit comments

Comments
 (0)