Skip to content

Commit 36f3523

Browse files
committed
Fix migration script
1 parent cd7464d commit 36f3523

File tree

1 file changed

+41
-26
lines changed

1 file changed

+41
-26
lines changed

scripts/migrate_old_model.py

Lines changed: 41 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@
55
import os
66
import tempfile
77
from importlib.metadata import version
8+
import tempfile
89

910
import torch
1011
import yaml
11-
from huggingface_hub import CommitOperationAdd, CommitOperationDelete, HfApi
12+
from huggingface_hub import CommitOperationAdd, CommitOperationDelete, HfApi, file_exists
1213
from safetensors.torch import save_file
1314

1415
from pvnet_summation.models import BaseModel
@@ -23,19 +24,31 @@
2324
# USER SETTINGS
2425

2526
# The huggingface commit of the model you want to update
26-
repo_id = "openclimatefix/pvnet_v2_summation"
27-
revision = "175a71206cf89a2d8fcd180cfa60d132590f12cb"
27+
repo_id: str = "openclimatefix/pvnet_v2_summation"
28+
revision: str = "175a71206cf89a2d8fcd180cfa60d132590f12cb"
29+
30+
# If set to a string, this allows you to change the commit of the PVNet model which the summation
31+
# model goes with. Otherwise the commit hash is taken from what is already stored on Hugging Face.
32+
# This is useful if both PVNet and PVNet-summation are being migrated simultaneously
33+
pvnet_revision: str | None = None
2834

2935
# The local directory the model will be downloaded to
30-
local_dir = "/home/jamesfulton/tmp/sum_model_migration"
36+
# If set to None a temporary directory will be used
37+
local_dir: str | None = None
3138

3239
# Whether to upload the migrated model back to Hugging Face
33-
upload = False
40+
upload: bool = True
3441

3542
# ------------------------------------------
3643
# SETUP
3744

38-
os.makedirs(local_dir, exist_ok=False)
45+
if local_dir is None:
46+
temp_dir = tempfile.TemporaryDirectory()
47+
save_dir = temp_dir.name
48+
49+
else:
50+
os.makedirs(local_dir, exist_ok=False)
51+
save_dir = local_dir
3952

4053
# Set up huggingface API
4154
api = HfApi()
@@ -44,36 +57,37 @@
4457
_ = api.snapshot_download(
4558
repo_id=repo_id,
4659
revision=revision,
47-
local_dir=local_dir,
60+
local_dir=save_dir,
4861
force_download=True,
4962
)
5063

5164
# ------------------------------------------
5265
# MIGRATION STEPS
5366

5467
# Modify the model config
55-
with open(f"{local_dir}/{MODEL_CONFIG_NAME}") as cfg:
68+
with open(f"{save_dir}/{MODEL_CONFIG_NAME}") as cfg:
5669
model_config = yaml.load(cfg, Loader=yaml.FullLoader)
5770

5871
# Get the PVNet model it was trained on
5972
pvnet_model_id = model_config.pop("model_name")
60-
pvnet_revision = model_config.pop("model_version")
73+
if pvnet_revision is None:
74+
pvnet_revision = model_config["model_version"]
75+
del model_config["model_version"]
6176

6277

63-
with tempfile.TemporaryDirectory() as pvnet_dir_dir:
78+
with tempfile.TemporaryDirectory() as pvnet_dir:
6479

6580
# Download the model repo
6681
_ = api.snapshot_download(
6782
repo_id=pvnet_model_id,
6883
revision=pvnet_revision,
69-
local_dir=str(pvnet_dir_dir),
84+
local_dir=str(pvnet_dir),
7085
force_download=True,
7186
)
7287

73-
with open(f"{pvnet_dir_dir}/{MODEL_CONFIG_NAME}") as cfg:
88+
with open(f"{pvnet_dir}/{MODEL_CONFIG_NAME}") as cfg:
7489
pvnet_model_config = yaml.load(cfg, Loader=yaml.FullLoader)
7590

76-
7791
# Get rid of the optimiser - we don't store this anymore
7892
del model_config["optimizer"]
7993

@@ -106,22 +120,22 @@
106120
model_config["input_quantiles"] = pvnet_model_config["output_quantiles"]
107121

108122
# Save the model config
109-
with open(f"{local_dir}/{MODEL_CONFIG_NAME}", "w") as f:
123+
with open(f"{save_dir}/{MODEL_CONFIG_NAME}", "w") as f:
110124
yaml.dump(model_config, f, sort_keys=False, default_flow_style=False)
111125

112126
# Create a datamodule
113-
with open(f"{local_dir}/{DATAMODULE_CONFIG_NAME}", "w") as f:
127+
with open(f"{save_dir}/{DATAMODULE_CONFIG_NAME}", "w") as f:
114128
datamodule = {"pvnet_model": {"model_id": pvnet_model_id, "revision": pvnet_revision}}
115129
yaml.dump(datamodule, f, sort_keys=False, default_flow_style=False)
116130

117131
# Resave the model weights as safetensors and remove the PVNet weights which we no longer need
118-
state_dict = torch.load(f"{local_dir}/pytorch_model.bin", map_location="cpu", weights_only=True)
132+
state_dict = torch.load(f"{save_dir}/pytorch_model.bin", map_location="cpu", weights_only=True)
119133
new_state_dict = {k: v for k, v in state_dict.items() if not k.startswith("pvnet_model")}
120-
save_file(new_state_dict, f"{local_dir}/{PYTORCH_WEIGHTS_NAME}")
121-
os.remove(f"{local_dir}/pytorch_model.bin")
134+
save_file(new_state_dict, f"{save_dir}/{PYTORCH_WEIGHTS_NAME}")
135+
os.remove(f"{save_dir}/pytorch_model.bin")
122136

123137
# Add a note to the model card to say the model has been migrated
124-
with open(f"{local_dir}/{MODEL_CARD_NAME}", "a") as f:
138+
with open(f"{save_dir}/{MODEL_CARD_NAME}", "a") as f:
125139
current_date = datetime.date.today().strftime("%Y-%m-%d")
126140
summation_version = version("pvnet_summation")
127141
f.write(
@@ -133,7 +147,7 @@
133147
# CHECKS
134148

135149
# Check the model can be loaded
136-
model = BaseModel.from_pretrained(model_id=local_dir, revision=None)
150+
model = BaseModel.from_pretrained(model_id=save_dir, revision=None)
137151

138152
print("Model checkpoint successfully migrated")
139153

@@ -149,14 +163,15 @@
149163
operations.append(
150164
CommitOperationAdd(
151165
path_in_repo=file, # Name of the file in the repo
152-
path_or_fileobj=f"{local_dir}/{file}", # Local path to the file
166+
path_or_fileobj=f"{save_dir}/{file}", # Local path to the file
153167
),
154168
)
155-
156-
operations.append(
157-
# Remove old pytorch weights file
158-
CommitOperationDelete(path_in_repo="pytorch_model.bin")
159-
)
169+
170+
# Remove old pytorch weights file if it exists in the most recent commit
171+
if file_exists(repo_id, "pytorch_model.bin"):
172+
operations.append(
173+
CommitOperationDelete(path_in_repo="pytorch_model.bin")
174+
)
160175

161176
commit_info = api.create_commit(
162177
repo_id=repo_id,

0 commit comments

Comments
 (0)