23
23
)
24
24
25
25
26
+ def santize_datamodule (config : dict ) -> dict :
27
+ """Create new datamodule config which only keeps the details required for inference"""
28
+ return {"pvnet_model" : config ["pvnet_model" ]}
29
+
30
+
26
31
def download_from_hf (
27
32
repo_id : str ,
28
33
filename : str | list [str ],
@@ -114,6 +119,32 @@ def from_pretrained(
114
119
model .eval () # type: ignore
115
120
116
121
return model
122
+
123
+ @classmethod
124
+ def get_datamodule_config (
125
+ cls ,
126
+ model_id : str ,
127
+ revision : str ,
128
+ cache_dir : str | None = None ,
129
+ force_download : bool = False ,
130
+ ) -> str :
131
+ """Load data config file."""
132
+ if os .path .isdir (model_id ):
133
+ print ("Loading datamodule config from local directory" )
134
+ datamodule_config_file = os .path .join (model_id , DATAMODULE_CONFIG_NAME )
135
+ else :
136
+ print ("Loading datamodule config from huggingface repo" )
137
+ datamodule_config_file = download_from_hf (
138
+ repo_id = model_id ,
139
+ filename = DATAMODULE_CONFIG_NAME ,
140
+ revision = revision ,
141
+ cache_dir = cache_dir ,
142
+ force_download = force_download ,
143
+ max_retries = 5 ,
144
+ wait_time = 10 ,
145
+ )
146
+
147
+ return datamodule_config_file
117
148
118
149
def _save_model_weights (self , save_directory : str ) -> None :
119
150
"""Save weights from a Pytorch model to a local directory."""
@@ -126,7 +157,7 @@ def save_pretrained(
126
157
wandb_repo : str ,
127
158
wandb_id : str ,
128
159
card_template_path : str ,
129
- datamodule_config_path : str | None = None ,
160
+ datamodule_config_path ,
130
161
experiment_config_path : str | None = None ,
131
162
hf_repo_id : str | None = None ,
132
163
push_to_hub : bool = False ,
@@ -142,16 +173,15 @@ def save_pretrained(
142
173
wandb_id: Identifier of the model on wandb.
143
174
datamodule_config_path:
144
175
The path to the datamodule config.
176
+ card_template_path: Path to the HuggingFace model card template. Defaults to card in
177
+ PVNet library if set to None.
145
178
experiment_config_path:
146
179
The path to the full experimental config.
147
180
hf_repo_id:
148
181
ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to
149
182
the folder name if not provided.
150
183
push_to_hub (`bool`, *optional*, defaults to `False`):
151
184
Whether or not to push your model to the HuggingFace Hub after saving it.
152
-
153
- card_template_path: Path to the HuggingFace model card template. Defaults to card in
154
- PVNet library if set to None.
155
185
"""
156
186
157
187
save_directory = Path (save_directory )
@@ -165,9 +195,14 @@ def save_pretrained(
165
195
with open (save_directory / MODEL_CONFIG_NAME , "w" ) as outfile :
166
196
yaml .dump (model_config , outfile , sort_keys = False , default_flow_style = False )
167
197
168
- # Save the datamodule config
169
- if datamodule_config_path is not None :
170
- shutil .copyfile (datamodule_config_path , save_directory / DATAMODULE_CONFIG_NAME )
198
+ # Sanitize and save the datamodule config
199
+ with open (datamodule_config_path ) as cfg :
200
+ datamodule_config = yaml .load (cfg , Loader = yaml .FullLoader )
201
+
202
+ datamodule_config = santize_datamodule (datamodule_config )
203
+
204
+ with open (save_directory / DATAMODULE_CONFIG_NAME , "w" ) as outfile :
205
+ yaml .dump (datamodule_config , outfile , sort_keys = False , default_flow_style = False )
171
206
172
207
# Save the full experimental config
173
208
if experiment_config_path is not None :
0 commit comments