|
| 1 | +"""Script to migrate old pvnet-summation models (v0.3.7) which are hosted on huggingface to current |
| 2 | +version""" |
| 3 | +import datetime |
| 4 | +import os |
| 5 | +from importlib.metadata import version |
| 6 | +import tempfile |
| 7 | + |
| 8 | +import torch |
| 9 | +import yaml |
| 10 | +from huggingface_hub import CommitOperationAdd, CommitOperationDelete, HfApi |
| 11 | +from safetensors.torch import save_file |
| 12 | + |
| 13 | +from pvnet_summation.models import BaseModel |
| 14 | +from pvnet_summation.utils import ( |
| 15 | + MODEL_CARD_NAME, MODEL_CONFIG_NAME, PYTORCH_WEIGHTS_NAME, DATAMODULE_CONFIG_NAME, |
| 16 | +) |
| 17 | + |
# ------------------------------------------
# USER SETTINGS

# Huggingface repo, and the exact commit within it, which holds the model to be migrated
repo_id = "openclimatefix/pvnet_v2_summation"
revision = "175a71206cf89a2d8fcd180cfa60d132590f12cb"

# Local working directory which the model repo is downloaded into
local_dir = "/home/jamesfulton/tmp/sum_model_migration"

# Set to True to push the migrated model back to huggingface at the end of the script
upload = False

# ------------------------------------------
# SETUP

# Create the working directory. `exist_ok=False` makes this raise FileExistsError if the
# directory is already there, so files from a previous run are never silently reused
os.makedirs(local_dir, exist_ok=False)

# Client used for every huggingface download and upload in this script
api = HfApi()

# Fetch the pinned revision of the model repo into the working directory
api.snapshot_download(
    repo_id=repo_id,
    revision=revision,
    local_dir=local_dir,
    force_download=True,
)
| 46 | + |
# ------------------------------------------
# MIGRATION STEPS

# Load the summation model config which was downloaded into `local_dir`.
# NOTE(review): `yaml.FullLoader` is kept from the original script. The config comes from a
# remote repo, so `yaml.safe_load` would be the safer choice if these configs only ever hold
# plain YAML - confirm before switching.
model_config_path = f"{local_dir}/{MODEL_CONFIG_NAME}"
with open(model_config_path, encoding="utf-8") as cfg:
    model_config = yaml.load(cfg, Loader=yaml.FullLoader)

# Get the PVNet model it was trained on. These two keys are removed from the model config and
# written to the datamodule config further down instead
pvnet_model_id = model_config.pop("model_name")
pvnet_revision = model_config.pop("model_version")

# Only the config of the PVNet model is read, so restrict the download to that single file
# rather than pulling the whole repo into the temporary directory
with tempfile.TemporaryDirectory() as pvnet_dir:
    api.snapshot_download(
        repo_id=pvnet_model_id,
        revision=pvnet_revision,
        local_dir=pvnet_dir,
        allow_patterns=[MODEL_CONFIG_NAME],
        force_download=True,
    )

    with open(f"{pvnet_dir}/{MODEL_CONFIG_NAME}", encoding="utf-8") as cfg:
        pvnet_model_config = yaml.load(cfg, Loader=yaml.FullLoader)

# Get rid of the optimiser - we don't store this anymore. A config which has no optimiser entry
# is already in the desired state, so a missing key is not an error
model_config.pop("optimizer", None)

# Rename the top level model. Only the old FlatModel has a known replacement
if model_config["_target_"] != "pvnet_summation.models.flat_model.FlatModel":
    raise ValueError(f"Unknown model: {model_config['_target_']}")
model_config["_target_"] = "pvnet_summation.models.dense_model.DenseModel"

# Models which used this setting are not supported any more. The key is dropped either way
if model_config.pop("relative_scale_pvnet_outputs"):
    raise ValueError("Models with `relative_scale_pvnet_outputs=True` are no longer supported")

model_config["num_input_locations"] = model_config.pop("num_locations")

# Re-find the model components in the new PVNet package structure
model_config["output_network"]["_target_"] = (
    model_config["output_network"]["_target_"]
    .replace("multimodal", "late_fusion")
    .replace("ResFCNet2", "ResFCNet")
)

# Add entries from the PVNet model which are now required in the summation model.
# `interval_minutes` falls back to 30 when the PVNet config does not define it
model_config["history_minutes"] = pvnet_model_config["history_minutes"]
model_config["forecast_minutes"] = pvnet_model_config["forecast_minutes"]
model_config["interval_minutes"] = pvnet_model_config.get("interval_minutes", 30)
model_config["input_quantiles"] = pvnet_model_config["output_quantiles"]

# Save the model config, overwriting the downloaded one
with open(model_config_path, "w", encoding="utf-8") as f:
    yaml.dump(model_config, f, sort_keys=False, default_flow_style=False)

# Create a datamodule config which records the PVNet model the summation model was trained on
datamodule = {"pvnet_model": {"model_id": pvnet_model_id, "revision": pvnet_revision}}
with open(f"{local_dir}/{DATAMODULE_CONFIG_NAME}", "w", encoding="utf-8") as f:
    yaml.dump(datamodule, f, sort_keys=False, default_flow_style=False)

# Resave the model weights as safetensors and remove the PVNet weights which we no longer need
legacy_weights_path = f"{local_dir}/pytorch_model.bin"
state_dict = torch.load(legacy_weights_path, map_location="cpu", weights_only=True)
new_state_dict = {k: v for k, v in state_dict.items() if not k.startswith("pvnet_model")}
save_file(new_state_dict, f"{local_dir}/{PYTORCH_WEIGHTS_NAME}")
os.remove(legacy_weights_path)

# Add a note to the model card to say the model has been migrated. `summation_version` is
# also used in the commit message of the upload step
current_date = datetime.date.today().strftime("%Y-%m-%d")
summation_version = version("pvnet_summation")
with open(f"{local_dir}/{MODEL_CARD_NAME}", "a", encoding="utf-8") as f:
    f.write(
        f"\n\n---\n**Migration Note**: This model was migrated on {current_date} "
        f"to pvnet-summation version {summation_version}\n"
    )
| 127 | + |
# ------------------------------------------
# CHECKS

# Load the model back from the migrated files in the local directory. If loading raises, the
# script stops here, before the success message and before the upload step
model = BaseModel.from_pretrained(model_id=local_dir, revision=None)

print("Model checkpoint successfully migrated")
| 135 | + |
# ------------------------------------------
# UPLOAD TO HUGGINGFACE

if upload:
    print("Uploading migrated model to huggingface")

    # Stage each migrated file for upload, keeping the same file name in the repo as locally
    migrated_files = (
        MODEL_CARD_NAME,
        MODEL_CONFIG_NAME,
        PYTORCH_WEIGHTS_NAME,
        DATAMODULE_CONFIG_NAME,
    )
    operations = [
        CommitOperationAdd(path_in_repo=name, path_or_fileobj=f"{local_dir}/{name}")
        for name in migrated_files
    ]

    # The old pytorch weights file is deleted from the repo in the same commit
    operations.append(CommitOperationDelete(path_in_repo="pytorch_model.bin"))

    commit_info = api.create_commit(
        repo_id=repo_id,
        operations=operations,
        commit_message=f"Migrate model to pvnet-summation version {summation_version}",
    )

    # Report the first commit returned for the repo, which is expected to be the one just made
    latest_commit = api.list_repo_commits(repo_id=repo_id, repo_type="model")[0]

    print(
        f"\nThe latest commit is now: \n"
        f" date: {latest_commit.created_at} \n"
        f" commit hash: {latest_commit.commit_id}\n"
        f" by: {latest_commit.authors}\n"
        f" title: {latest_commit.title}\n"
    )
0 commit comments