Skip to content

Commit 6b0a4c6

Browse files
committed
add migration script
1 parent 63ffb81 commit 6b0a4c6

File tree

1 file changed

+172
-0
lines changed

1 file changed

+172
-0
lines changed

scripts/migrate_old_model.py

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
"""Script to migrate old pvnet-summation models (v0.3.7) which are hosted on huggingface to current
2+
version"""
3+
import datetime
4+
import os
5+
from importlib.metadata import version
6+
import tempfile
7+
8+
import torch
9+
import yaml
10+
from huggingface_hub import CommitOperationAdd, CommitOperationDelete, HfApi
11+
from safetensors.torch import save_file
12+
13+
from pvnet_summation.models import BaseModel
14+
from pvnet_summation.utils import (
15+
MODEL_CARD_NAME, MODEL_CONFIG_NAME, PYTORCH_WEIGHTS_NAME, DATAMODULE_CONFIG_NAME,
16+
)
17+
# ------------------------------------------
# USER SETTINGS

# The Hugging Face repo, and the exact commit in it, of the old-format model to migrate
repo_id = "openclimatefix/pvnet_v2_summation"
revision = "175a71206cf89a2d8fcd180cfa60d132590f12cb"

# The local directory the model repo will be downloaded to. It must not exist yet: the script
# creates it with `os.makedirs(..., exist_ok=False)` and fails if it is already there
local_dir = "/home/jamesfulton/tmp/sum_model_migration"

# Whether to upload the migrated model back to Hugging Face
upload = False
30+
# ------------------------------------------
# SETUP

# Create the working directory. With `exist_ok=False` this raises FileExistsError if the
# directory already exists, so the script stops before downloading anything into it
os.makedirs(local_dir, exist_ok=False)

# Set up huggingface API
api = HfApi()

# Download the model repo at the pinned revision into the working directory
_ = api.snapshot_download(
    repo_id=repo_id,
    revision=revision,
    local_dir=local_dir,
    force_download=True,
)
46+
# ------------------------------------------
# MIGRATION STEPS

# Load the old model config
# NOTE(review): these configs are downloaded from the hub and parsed with `yaml.FullLoader`.
# If they only ever contain plain data, `yaml.safe_load` would be the safer choice - confirm
with open(f"{local_dir}/{MODEL_CONFIG_NAME}") as cfg:
    model_config = yaml.load(cfg, Loader=yaml.FullLoader)

# Get the PVNet model it was trained on. Popping also removes these keys from the config
pvnet_model_id = model_config.pop("model_name")
pvnet_revision = model_config.pop("model_version")

# Read the config of that PVNet model. Only the parsed config is kept - the temporary
# directory and everything downloaded into it is deleted when the `with` block exits
with tempfile.TemporaryDirectory() as pvnet_dir:

    # Only the config file is read below, so restrict the download to it rather than
    # pulling the whole repo
    _ = api.snapshot_download(
        repo_id=pvnet_model_id,
        revision=pvnet_revision,
        local_dir=pvnet_dir,
        allow_patterns=[MODEL_CONFIG_NAME],
        force_download=True,
    )

    with open(f"{pvnet_dir}/{MODEL_CONFIG_NAME}") as cfg:
        pvnet_model_config = yaml.load(cfg, Loader=yaml.FullLoader)
71+
72+
# Get rid of the optimiser - we don't store this anymore
del model_config["optimizer"]

# Rename the top level model. The flat model is the only one this script knows how to migrate
if model_config["_target_"] != "pvnet_summation.models.flat_model.FlatModel":
    raise ValueError(f"Unknown model: {model_config['_target_']}")
model_config["_target_"] = "pvnet_summation.models.dense_model.DenseModel"

# Models which used this setting are not supported any more. The key is dropped either way
if model_config.pop("relative_scale_pvnet_outputs"):
    raise ValueError("Models with `relative_scale_pvnet_outputs=True` are no longer supported")

# The number of locations is now stored under a different key
model_config["num_input_locations"] = model_config.pop("num_locations")

# Re-find the model components in the new PVNet package structure
output_network = model_config["output_network"]
output_network["_target_"] = (
    output_network["_target_"]
    .replace("multimodal", "late_fusion")
    .replace("ResFCNet2", "ResFCNet")
)

# Add entries from the PVNet model which are now required in the summation model
model_config["history_minutes"] = pvnet_model_config["history_minutes"]
model_config["forecast_minutes"] = pvnet_model_config["forecast_minutes"]
# Falls back to 30 when the PVNet config has no `interval_minutes` entry
model_config["interval_minutes"] = pvnet_model_config.get("interval_minutes", 30)
model_config["input_quantiles"] = pvnet_model_config["output_quantiles"]
103+
# Save the model config, overwriting the downloaded one and keeping the key order as built
with open(f"{local_dir}/{MODEL_CONFIG_NAME}", "w") as f:
    yaml.dump(model_config, f, sort_keys=False, default_flow_style=False)

# Create a datamodule config which records the PVNet model id and revision
datamodule = {"pvnet_model": {"model_id": pvnet_model_id, "revision": pvnet_revision}}
with open(f"{local_dir}/{DATAMODULE_CONFIG_NAME}", "w") as f:
    yaml.dump(datamodule, f, sort_keys=False, default_flow_style=False)
112+
# Resave the model weights as safetensors and remove the PVNet weights which we no longer need
old_weights_path = f"{local_dir}/pytorch_model.bin"
state_dict = torch.load(old_weights_path, map_location="cpu", weights_only=True)

# Keep every entry except those whose key starts with "pvnet_model"
new_state_dict = {k: v for k, v in state_dict.items() if not k.startswith("pvnet_model")}
save_file(new_state_dict, f"{local_dir}/{PYTORCH_WEIGHTS_NAME}")

# The old-format weights file is deleted only after the new file has been written
os.remove(old_weights_path)
118+
# Add a note to the model card to say the model has been migrated.
# The date and version are looked up before the card is opened, so a failed lookup cannot
# leave the card opened-for-append (and created, if it was missing) without a note.
# `summation_version` is also used in the upload commit message
current_date = datetime.date.today().strftime("%Y-%m-%d")
summation_version = version("pvnet_summation")

with open(f"{local_dir}/{MODEL_CARD_NAME}", "a") as f:
    f.write(
        f"\n\n---\n**Migration Note**: This model was migrated on {current_date} "
        f"to pvnet-summation version {summation_version}\n"
    )
127+
# ------------------------------------------
# CHECKS

# Check the model can be loaded from the migrated files in the local directory
model = BaseModel.from_pretrained(model_id=local_dir, revision=None)

print("Model checkpoint successfully migrated")
135+
# ------------------------------------------
# UPLOAD TO HUGGINGFACE

if upload:
    print("Uploading migrated model to huggingface")

    # Stage modified files for upload
    migrated_files = [
        MODEL_CARD_NAME,
        MODEL_CONFIG_NAME,
        PYTORCH_WEIGHTS_NAME,
        DATAMODULE_CONFIG_NAME,
    ]
    operations = [
        CommitOperationAdd(
            path_in_repo=file,  # Name of the file in the repo
            path_or_fileobj=f"{local_dir}/{file}",  # Local path to the file
        )
        for file in migrated_files
    ]

    # Remove old pytorch weights file
    operations.append(CommitOperationDelete(path_in_repo="pytorch_model.bin"))

    # All additions and the deletion go into a single commit.
    # NOTE(review): no `revision` is passed, so per the huggingface_hub docs the commit is
    # made on the head of the default branch - confirm this is intended when the pinned
    # `revision` that was migrated is not that head
    api.create_commit(
        repo_id=repo_id,
        operations=operations,
        commit_message=f"Migrate model to pvnet-summation version {summation_version}",
    )

    # Print the most recent commit hash
    c = api.list_repo_commits(repo_id=repo_id, repo_type="model")[0]

    print(
        f"\nThe latest commit is now: \n"
        f" date: {c.created_at} \n"
        f" commit hash: {c.commit_id}\n"
        f" by: {c.authors}\n"
        f" title: {c.title}\n"
    )

0 commit comments

Comments
 (0)