import os
import tempfile
from importlib.metadata import version

import torch
import yaml
from huggingface_hub import CommitOperationAdd, CommitOperationDelete, HfApi, file_exists
from safetensors.torch import save_file

from pvnet_summation.models import BaseModel
|
|
# USER SETTINGS

# The huggingface commit of the model you want to update
repo_id: str = "openclimatefix/pvnet_v2_summation"
revision: str = "175a71206cf89a2d8fcd180cfa60d132590f12cb"

# If set to a string, this allows you to change the commit of the PVNet model which the summation
# model goes with. Else the commit hash is taken from what is already stored on huggingface.
# This is useful if both PVNet and PVNet-summation are being migrated simultaneously
pvnet_revision: str | None = None

# The local directory which will be downloaded to
# If set to None a temporary directory will be used
local_dir: str | None = None

# Whether to upload the migrated model back to the huggingface
# NOTE(review): this defaults to True, so running the script as-is pushes a commit
# to the repo above — double-check repo_id/revision before running
upload: bool = True
34 | 41 |
|
# ------------------------------------------
# SETUP

# Resolve the directory the snapshot is downloaded into: use the user-supplied
# path if given (failing loudly if it already exists), else a fresh temp dir.
save_dir = local_dir
if save_dir is None:
    # Keep a reference to the TemporaryDirectory object so it is not cleaned up
    # while the script is still running
    temp_dir = tempfile.TemporaryDirectory()
    save_dir = temp_dir.name
else:
    os.makedirs(save_dir, exist_ok=False)

# Set up huggingface API
api = HfApi()

# Pull the pinned revision of the summation model repo down to save_dir
_ = api.snapshot_download(
    repo_id=repo_id,
    revision=revision,
    local_dir=save_dir,
    force_download=True,
)
|
50 | 63 |
|
# ------------------------------------------
# MIGRATION STEPS

# Load the summation model config so it can be rewritten in the new format
with open(f"{save_dir}/{MODEL_CONFIG_NAME}") as cfg:
    model_config = yaml.load(cfg, Loader=yaml.FullLoader)

# Get the PVNet model it was trained on
pvnet_model_id = model_config.pop("model_name")

# The PVNet commit stored in the config is always removed; it is used as the
# revision only when the user did not pin one in USER SETTINGS above
stored_pvnet_revision = model_config.pop("model_version")
if pvnet_revision is None:
    pvnet_revision = stored_pvnet_revision
61 | 76 |
|
62 | 77 |
|
with tempfile.TemporaryDirectory() as pvnet_dir:

    # Download the model repo
    _ = api.snapshot_download(
        repo_id=pvnet_model_id,
        revision=pvnet_revision,
        local_dir=str(pvnet_dir),
        force_download=True,
    )

    # Load the PVNet model config so its values can be copied into the
    # summation model config below
    with open(f"{pvnet_dir}/{MODEL_CONFIG_NAME}") as cfg:
        pvnet_model_config = yaml.load(cfg, Loader=yaml.FullLoader)

    # Get rid of the optimiser - we don't store this anymore
    del model_config["optimizer"]
|
79 | 93 |
|
|
106 | 120 | model_config["input_quantiles"] = pvnet_model_config["output_quantiles"]
|
107 | 121 |
|
# Save the model config
with open(f"{save_dir}/{MODEL_CONFIG_NAME}", "w") as model_config_file:
    yaml.dump(model_config, model_config_file, sort_keys=False, default_flow_style=False)

# Create a datamodule config recording which PVNet model this summation model wraps
datamodule = {"pvnet_model": {"model_id": pvnet_model_id, "revision": pvnet_revision}}
with open(f"{save_dir}/{DATAMODULE_CONFIG_NAME}", "w") as datamodule_file:
    yaml.dump(datamodule, datamodule_file, sort_keys=False, default_flow_style=False)
116 | 130 |
|
# Resave the model weights as safetensors and remove the PVNet weights which we no longer need
full_state_dict = torch.load(
    f"{save_dir}/pytorch_model.bin", map_location="cpu", weights_only=True
)
# Keep only the summation model's own parameters; the wrapped PVNet weights
# now live in their own repo and are dropped here
summation_state_dict = {
    key: weight
    for key, weight in full_state_dict.items()
    if not key.startswith("pvnet_model")
}
save_file(summation_state_dict, f"{save_dir}/{PYTORCH_WEIGHTS_NAME}")
os.remove(f"{save_dir}/pytorch_model.bin")
122 | 136 |
|
123 | 137 | # Add a note to the model card to say the model has been migrated
|
124 |
| -with open(f"{local_dir}/{MODEL_CARD_NAME}", "a") as f: |
| 138 | +with open(f"{save_dir}/{MODEL_CARD_NAME}", "a") as f: |
125 | 139 | current_date = datetime.date.today().strftime("%Y-%m-%d")
|
126 | 140 | summation_version = version("pvnet_summation")
|
127 | 141 | f.write(
|
|
# CHECKS

# Check the model can be loaded
# revision=None because we are loading from the local save_dir, not from the hub
model = BaseModel.from_pretrained(model_id=save_dir, revision=None)

print("Model checkpoint successfully migrated")
|
139 | 153 |
|
|
149 | 163 | operations.append(
|
150 | 164 | CommitOperationAdd(
|
151 | 165 | path_in_repo=file, # Name of the file in the repo
|
152 |
| - path_or_fileobj=f"{local_dir}/{file}", # Local path to the file |
| 166 | + path_or_fileobj=f"{save_dir}/{file}", # Local path to the file |
153 | 167 | ),
|
154 | 168 | )
|
155 |
| - |
156 |
| - operations.append( |
157 |
| - # Remove old pytorch weights file |
158 |
| - CommitOperationDelete(path_in_repo="pytorch_model.bin") |
159 |
| - ) |
| 169 | + |
| 170 | + # Remove old pytorch weights file if it exists in the most recent commit |
| 171 | + if file_exists(repo_id, "pytorch_model.bin"): |
| 172 | + operations.append( |
| 173 | + CommitOperationDelete(path_in_repo="pytorch_model.bin") |
| 174 | + ) |
160 | 175 |
|
161 | 176 | commit_info = api.create_commit(
|
162 | 177 | repo_id=repo_id,
|
|
0 commit comments