Skip to content

Commit 036c2cb

Browse files
committed
Export datamodule to huggingface
1 parent e1b626f commit 036c2cb

File tree

2 files changed

+51
-8
lines changed

2 files changed

+51
-8
lines changed

pvnet_summation/models/base_model.py

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,11 @@
2323
)
2424

2525

26+
def santize_datamodule(config: dict) -> dict:
    """Create new datamodule config which only keeps the details required for inference.

    Only the ``"pvnet_model"`` entry of the input is carried over; every other key is
    dropped. The value is passed through as-is (it is not copied).

    Args:
        config: The loaded datamodule config. Must be a mapping containing a
            ``"pvnet_model"`` key.

    Returns:
        A new dict with the single key ``"pvnet_model"``.

    Raises:
        TypeError: If ``config`` does not support key lookup (e.g. it is ``None`` because
            the config file was empty).
        KeyError: If ``config`` has no ``"pvnet_model"`` key.
    """
    try:
        pvnet_model = config["pvnet_model"]
    except TypeError as err:
        # Same exception type as the bare subscript would raise, but with a message that
        # points at the datamodule config instead of "object is not subscriptable"
        raise TypeError(
            f"Datamodule config must be a mapping, got {type(config).__name__}"
        ) from err
    except KeyError as err:
        raise KeyError(
            "Datamodule config is missing the required 'pvnet_model' key"
        ) from err

    return {"pvnet_model": pvnet_model}
29+
30+
2631
def download_from_hf(
2732
repo_id: str,
2833
filename: str | list[str],
@@ -114,6 +119,32 @@ def from_pretrained(
114119
model.eval() # type: ignore
115120

116121
return model
122+
123+
@classmethod
def get_datamodule_config(
    cls,
    model_id: str,
    revision: str,
    cache_dir: str | None = None,
    force_download: bool = False,
) -> str:
    """Locate the datamodule config file for a model.

    Args:
        model_id: Either a path to an existing local directory, or an identifier which is
            passed to `download_from_hf` as the `repo_id`.
        revision: Forwarded to `download_from_hf`; unused when `model_id` is a local
            directory.
        cache_dir: Forwarded to `download_from_hf`; unused for a local directory.
        force_download: Forwarded to `download_from_hf`; unused for a local directory.

    Returns:
        For a local directory, `model_id` joined with `DATAMODULE_CONFIG_NAME` (the file
        is not checked for existence). Otherwise, whatever `download_from_hf` returns for
        that filename - presumably the path of the downloaded file.
    """
    # Anything that is an existing directory is treated as a local model folder
    if os.path.isdir(model_id):
        print("Loading datamodule config from local directory")
        return os.path.join(model_id, DATAMODULE_CONFIG_NAME)

    # Otherwise fetch the file from the hub
    print("Loading datamodule config from huggingface repo")
    return download_from_hf(
        repo_id=model_id,
        filename=DATAMODULE_CONFIG_NAME,
        revision=revision,
        cache_dir=cache_dir,
        force_download=force_download,
        max_retries=5,
        wait_time=10,
    )
117148

118149
def _save_model_weights(self, save_directory: str) -> None:
119150
"""Save weights from a Pytorch model to a local directory."""
@@ -126,7 +157,7 @@ def save_pretrained(
126157
wandb_repo: str,
127158
wandb_id: str,
128159
card_template_path: str,
129-
datamodule_config_path: str | None = None,
160+
datamodule_config_path,
130161
experiment_config_path: str | None = None,
131162
hf_repo_id: str | None = None,
132163
push_to_hub: bool = False,
@@ -142,16 +173,15 @@ def save_pretrained(
142173
wandb_id: Identifier of the model on wandb.
143174
datamodule_config_path:
144175
The path to the datamodule config.
176+
card_template_path: Path to the HuggingFace model card template. Defaults to card in
177+
PVNet library if set to None.
145178
experiment_config_path:
146179
The path to the full experimental config.
147180
hf_repo_id:
148181
ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to
149182
the folder name if not provided.
150183
push_to_hub (`bool`, *optional*, defaults to `False`):
151184
Whether or not to push your model to the HuggingFace Hub after saving it.
152-
153-
card_template_path: Path to the HuggingFace model card template. Defaults to card in
154-
PVNet library if set to None.
155185
"""
156186

157187
save_directory = Path(save_directory)
@@ -165,9 +195,14 @@ def save_pretrained(
165195
with open(save_directory / MODEL_CONFIG_NAME, "w") as outfile:
166196
yaml.dump(model_config, outfile, sort_keys=False, default_flow_style=False)
167197

168-
# Save the datamodule config
169-
if datamodule_config_path is not None:
170-
shutil.copyfile(datamodule_config_path, save_directory / DATAMODULE_CONFIG_NAME)
198+
# Sanitize and save the datamodule config
199+
with open(datamodule_config_path) as cfg:
200+
datamodule_config = yaml.load(cfg, Loader=yaml.FullLoader)
201+
202+
datamodule_config = santize_datamodule(datamodule_config)
203+
204+
with open(save_directory / DATAMODULE_CONFIG_NAME, "w") as outfile:
205+
yaml.dump(datamodule_config, outfile, sort_keys=False, default_flow_style=False)
171206

172207
# Save the full experimental config
173208
if experiment_config_path is not None:

pvnet_summation/training/train.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from tqdm import tqdm
1414

1515
from pvnet_summation.data.datamodule import PresavedDataModule, StreamedDataModule
16-
from pvnet_summation.utils import MODEL_CONFIG_NAME
16+
from pvnet_summation.utils import MODEL_CONFIG_NAME, FULL_CONFIG_NAME, DATAMODULE_CONFIG_NAME
1717

1818
log = logging.getLogger(__name__)
1919

@@ -163,6 +163,14 @@ def train(config: DictConfig) -> None:
163163
os.makedirs(save_dir, exist_ok=True)
164164
OmegaConf.save(config.model, f"{save_dir}/{MODEL_CONFIG_NAME}")
165165

166+
# Save the datamodule config
167+
OmegaConf.save(config.datamodule, f"{save_dir}/{DATAMODULE_CONFIG_NAME}")
168+
169+
# Save the full hydra config to the output directory and to wandb
170+
OmegaConf.save(config, f"{save_dir}/{FULL_CONFIG_NAME}")
171+
wandb_logger.experiment.save(f"{save_dir}/{FULL_CONFIG_NAME}", base_path=save_dir)
172+
173+
166174
# Init lightning model
167175
model = hydra.utils.instantiate(config.model)
168176

0 commit comments

Comments
 (0)