diff --git a/dist_run.py b/dist_run.py
index bc3badbc4..e077ea9ca 100644
--- a/dist_run.py
+++ b/dist_run.py
@@ -22,10 +22,10 @@ from torchchat.distributed.logging_utils import SingletonLogger
 
 # TODO - these are not distributed specific, consider moving to new package
-from torchchat.distributed.safetensor_utils import (
+from torchchat.distributed.checkpoint_utils import (
     get_hf_config_file,
-    get_hf_weight_map_and_path,
-    load_safetensor_weights,
+    load_weights_from_hf_format,
+    load_weights_from_torchchat_format,
 )
 from torchchat.distributed.utils import (
     bytes_to_readable,
@@ -110,26 +110,33 @@ def _build_chat_tokenizer(
     return tokenizer
 
 
-def _load_model_weights(stage_module, distribution, device, model_config):
+def _load_model_weights(
+    stage_module: torch.nn.Module,
+    distribution: str,
+    device: torch.device,
+    model_config: ModelArgs,
+    chpt_from: str,
+):
     """Load the weights from the safetensor file(s) into the model stage.
     Model config is needed b/c we permute wq and wk weights based on attn heads.
-    """
-    weight_map, weight_path, key_map = get_hf_weight_map_and_path(distribution)
-
-    num_loaded_weights, num_missing_weights = load_safetensor_weights(
-        stage_module,
-        weight_map,
-        weight_path,
-        key_map,
-        device,
-        model_config=model_config,
-    )
-    logger.info(
-        f"Success - Loaded {num_loaded_weights} weights, {num_missing_weights} missing weights"
-    )
-    if num_missing_weights > 0:
-        raise ValueError(f"Missing {num_missing_weights} weights")
+    Args:
+        stage_module (torch.nn.Module): The model stage to load the weights into.
+        distribution (str): The distribution name, e.g. "meta-llama/Meta-Llama-3-8B-Instruct".
+        device (torch.device): The device to load the weights onto.
+        model_config (ModelArgs): The model config.
+        chpt_from (str): The checkpoint format to load the weights from, e.g. "torchchat" or "hf".
+    """
+    if chpt_from == "hf":
+        # This format refers to: an index file + multiple safetensor files
+        load_weights_from_hf_format(stage_module, distribution, device, model_config)
+    elif chpt_from == "torchchat":
+        # This format refers to either:
+        #   a single binary file, OR
+        #   multiple binary files without an index file
+        load_weights_from_torchchat_format(stage_module, distribution, device, model_config)
+    else:
+        raise ValueError(f"Unknown checkpoint format: {chpt_from}")
 
 
 def _encode_strings(
@@ -286,7 +293,7 @@ def main(args):
     logger.info(f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}")
 
     distribution, model_dtype = NAME_TO_DISTRIBUTION_AND_DTYPE[model_name]
-    logger.info(f"Using HF model weights from {distribution} and dtype {model_dtype}")
+    logger.info(f"Using model weights from {distribution} and dtype {model_dtype}")
 
     # Model-level config
     model_config = ModelArgs.from_name(distribution)
@@ -348,7 +355,7 @@ def main(args):
     # Load weights
     logger.info(f"Loading weights for {pp_rank=} on {device=}")
     with CUDATrackTime() as timer:
-        _load_model_weights(model, distribution, device=device, model_config=config)
+        _load_model_weights(model, distribution, device, config, args.chpt_from)
 
     logger.info(
         f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
@@ -579,6 +586,13 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         default=False,
         help="Whether to decode token into string in flight",
     )
+    parser.add_argument(
+        "--chpt-from",
+        type=str,
+        default="hf",  # TODO: change to torchchat once we support it well
+        help="Checkpoint format to load from",
+        choices=["hf", "torchchat"],
+    )
     args = parser.parse_args()
 
     main(args)
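For reference, a minimal sketch of how the new flag surfaces in `args` (the positional model name and its argument name are illustrative; the other dist_run.py arguments are omitted):

import argparse

parser = argparse.ArgumentParser()
# Placeholder for dist_run's existing positional model argument, e.g. "llama3".
parser.add_argument("model_name", type=str)
parser.add_argument(
    "--chpt-from",
    type=str,
    default="hf",
    help="Checkpoint format to load from",
    choices=["hf", "torchchat"],
)
args = parser.parse_args(["llama3", "--chpt-from", "torchchat"])
# argparse maps --chpt-from to args.chpt_from, which main() forwards to _load_model_weights.
assert args.chpt_from == "torchchat"
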
diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py
index 3abed339a..02b1545d0 100644
--- a/torchchat/cli/builder.py
+++ b/torchchat/cli/builder.py
@@ -335,11 +335,7 @@ def _load_model_gguf(builder_args: BuilderArgs) -> Model:
     return model
 
 
-def _load_model_default(builder_args: BuilderArgs) -> Model:
-    assert not builder_args.gguf_path
-
-    model: Model = _init_model_on_meta_device(builder_args)
-
+def _load_checkpoint(builder_args: BuilderArgs):
     if builder_args.params_table and builder_args.params_table.endswith("Tune"):
         print("Loading Tune checkpoint")
         meta_checkpoint = torch.load(
@@ -377,6 +373,16 @@ def _load_model_default(builder_args: BuilderArgs) -> Model:
             mmap=True,
             weights_only=True,
         )
+    return checkpoint
+
+
+def _load_model_default(builder_args: BuilderArgs) -> Model:
+    assert not builder_args.gguf_path
+
+    model: Model = _init_model_on_meta_device(builder_args)
+
+    # Load checkpoint from filesystem
+    checkpoint = _load_checkpoint(builder_args)
 
     if "model" in checkpoint and "stories" in str(builder_args.checkpoint_path):
         checkpoint = checkpoint["model"]
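The point of extracting `_load_checkpoint` is that the distributed loader below can reuse it. A minimal sketch of that reuse, mirroring `load_weights_from_torchchat_format` further down; the cache path is illustrative and must exist on disk:

from pathlib import Path

import torch

from torchchat.cli.builder import BuilderArgs, _load_checkpoint

checkpoint_path = Path(
    "~/.torchchat/model-cache/meta-llama/Meta-Llama-3-8B-Instruct/model.pth"
).expanduser()
builder_args = BuilderArgs(
    device=torch.device("cuda", 0),
    checkpoint_dir=None,        # single-file case: pass a path, not a dir
    checkpoint_path=checkpoint_path,
)
# Returns a plain state dict of full (unsharded) tensors.
state_dict = _load_checkpoint(builder_args)
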
""" - stage_state_dict, weight_map = prepare_state_dict( - stage_module, weight_map, purge_model_prefix - ) + stage_state_dict = stage_module.state_dict() + if purge_model_prefix: + stage_state_dict = purge_fqn_prefix(stage_state_dict, "model.") + weight_map = purge_fqn_prefix(weight_map, "model.") + needed_files = get_needed_files(stage_state_dict, weight_map) updated_states: Set[str] = set() @@ -175,17 +179,15 @@ def load_safetensor_weights( full_path = os.path.join(file_location, file) # logger.info(f"Loading checkpoint file: {full_path}") try: - checkpoint = load_checkpoint(full_path, "cpu") # device) + checkpoint = load_safetensor_file(full_path, "cpu") # device) update_state_dict( stage_state_dict, checkpoint, - weight_map, - new_to_old_keymap, - file, - updated_states, device, - model_config, + model_config=model_config, + new_to_old_keymap=new_to_old_keymap, + updated_states=updated_states, ) except FileNotFoundError: logger.error(f"File not found: {full_path}") @@ -215,14 +217,14 @@ def load_safetensor_weights( return len(updated_states), len(missing_keys) -def prepare_state_dict( - module: Module, weight_map: Dict[str, str], purge_model_prefix: bool -) -> Dict[str, torch.Tensor]: - state_dict = module.state_dict() - if purge_model_prefix: - state_dict = {k.removeprefix("model."): v for k, v in state_dict.items()} - weight_map = {k.removeprefix("model."): v for k, v in weight_map.items()} - return state_dict, weight_map +# TODO: clean this up together with `purge_fqn_prefix` when we switch +# from creating Transformer to creating model +def purge_fqn_prefix( + any_dict: Dict[str, Any], + prefix: str, +) -> Dict[str, Any]: + """Remove a prefix from all keys in a dictionary.""" + return {k.removeprefix(prefix): v for k, v in any_dict.items()} def get_needed_files( @@ -242,7 +244,7 @@ def get_needed_files( return needed_files -def load_checkpoint(full_path: str, device: torch.device) -> Dict[str, torch.Tensor]: +def load_safetensor_file(full_path: str, device: torch.device) -> Dict[str, torch.Tensor]: tensors = {} with safe_open(full_path, framework="pt", device=device) as f: for k in f.keys(): @@ -264,64 +266,66 @@ def permute_weight_to_attn_heads(w, n_heads, head_dim, model_dim): def update_state_dict( state_dict: Dict[str, torch.Tensor], checkpoint: Dict[str, torch.Tensor], - weight_map: Dict[str, str], - new_to_old_keymap: Dict[str, str], - file: str, - updated_states: Set[str], device: torch.device, model_config: Optional[Dict] = None, + new_to_old_keymap: Optional[Dict[str, str]] = None, + updated_states: Optional[Set[str]]= None, ): - count_dtensors_loaded = 0 + """ + Update the state dict with the checkpoint tensors. + Note: + - For HF format, `new_to_old_keymap` is a mapping from the new key to the old + key. + - For torchchat format, `new_to_old_keymap` is None (because FQN conversion + has been doen by torchchat download script). 
+ """ # for handling attn head permuting num_heads = model_config.n_heads dim = model_config.dim num_local_heads = model_config.n_local_heads head_dim = model_config.head_dim - for param, file_with_param in weight_map.items(): - if file_with_param == file and param in state_dict: + for param in state_dict.keys(): + if new_to_old_keymap is not None: + # TODO: clean the following manual prefix together with + # `purge_fqn_prefix` when we switch from creating Transformer to + # creating model model_param = ( "output.weight" if param == "output.weight" else f"model.{param}" ) - old_param = new_to_old_keymap.get(model_param) - - if old_param not in checkpoint: - logger.warning(f"Missing {old_param} in checkpoint") - continue - - checkpoint_tensor = checkpoint[old_param] - model_tensor = state_dict[param] - - if "wq" in param: - checkpoint_tensor = permute_weight_to_attn_heads( - checkpoint_tensor, num_heads, head_dim, dim - ) - elif "wk" in param: - checkpoint_tensor = permute_weight_to_attn_heads( - checkpoint_tensor, num_local_heads, head_dim, dim - ) - - # Move checkpoint tensor to desired device - checkpoint_tensor = checkpoint_tensor.to(device) - - # here we need to check if the tensor is a DTensor and if so, adjust the - # shape and placement to match the model DTensor. - if isinstance(model_tensor, DTensor): - state_dict[param] = convert_to_dtensor(checkpoint_tensor, model_tensor) - count_dtensors_loaded += 1 - else: - # regular tensor, just update directly - state_dict[param] = checkpoint_tensor - - # ensure matching dtypes - state_dict[param] = state_dict[param].to(checkpoint_tensor.dtype) - - assert state_dict[param].dtype == checkpoint_tensor.dtype - - # log_tensor_info(param, state_dict[param]) - # logger.info(f"Loaded {param} from {file}") + old_param = new_to_old_keymap[model_param] + else: + old_param = param + + if old_param not in checkpoint: + # Maybe this param is in other files + continue + + checkpoint_tensor = checkpoint[old_param] + model_tensor = state_dict[param] + + if "wq" in param: + checkpoint_tensor = permute_weight_to_attn_heads( + checkpoint_tensor, num_heads, head_dim, dim + ) + elif "wk" in param: + checkpoint_tensor = permute_weight_to_attn_heads( + checkpoint_tensor, num_local_heads, head_dim, dim + ) + + # Move checkpoint tensor to desired device + checkpoint_tensor = checkpoint_tensor.to(device) + + # here we need to check if the tensor is a DTensor and if so, adjust the + # shape and placement to match the model DTensor. + if isinstance(model_tensor, DTensor): + checkpoint_tensor = convert_to_dtensor(checkpoint_tensor, model_tensor) + + # Update model state dict with checkpoint tensor + state_dict[param] = checkpoint_tensor + + if updated_states is not None: updated_states.add(param) - # logger.info(f"Count of loaded DTensors: {count_dtensors_loaded}") def format_tensor_info(tensor: torch.Tensor) -> str: @@ -366,3 +370,83 @@ def log_loading_status(missing_keys: Set[str], updated_states: Set[str]): else: logger.info("Fully updated state dict.") logger.info(f"Successfully loaded {len(updated_states)} weights into stage module") + + +def load_weights_from_hf_format(stage_module, distribution, device, model_config): + """ + Load the weights from Hugging Face format (index file + multiple safetensor + files), and fill into `stage_module`. Model config is needed b/c we permute + wq and wk weights based on attn heads. 
+ """ + + weight_map, weight_path, key_map = get_hf_weight_map_and_path(distribution) + + num_loaded_weights, num_missing_weights = load_safetensor_weights( + stage_module, + weight_map, + weight_path, + key_map, + device, + model_config=model_config, + ) + logger.info( + f"Success - Loaded {num_loaded_weights} weights, {num_missing_weights} missing weights" + ) + if num_missing_weights > 0: + raise ValueError(f"Missing {num_missing_weights} weights") + + +# HACK: assuming single file for torchchat's converted checkpoints. We should +# remove this after converging to torchchat's model building process. +# In particular, +# builder_args = BuilderArgs.from_args(args) +# will tell us if there is a single file or a directory. +TORCHCHCAT_SINGLE_FILE_CHECKPOINT = True + +def load_weights_from_torchchat_format(stage_module, distribution, device, model_config): + """ + Load the weights from torchchat format (single binary file), and fill into + `stage_module`. Model config is needed b/c we permute wq and wk weights + based on attn heads. + """ + stage_state_dict = stage_module.state_dict() + # TODO: clean this up together with `purge_fqn_prefix` when we switch + stage_state_dict = purge_fqn_prefix(stage_state_dict, "model.") + + # Load checkpoint from torchchat cache + default_cache_dir = Path( + os.getenv("TORCHCHAT_MODELDIR", "~/.torchchat/model-cache") + ).expanduser() + # Distribution is like "meta-llama/Meta-Llama-3-8B-Instruct" + # Join it with the default cache dir to get the checkpoint dir + checkpoint_dir = default_cache_dir / distribution + # Provide path in single-file case, provide dir in multi-file case. See + # `_load_checkpoint`. + if TORCHCHCAT_SINGLE_FILE_CHECKPOINT: + checkpoint_path = checkpoint_dir / "model.pth" + checkpoint_dir = None + else: + checkpoint_path = None + # First, construct BuilderArgs + args_dict = { + "device": device, + "checkpoint_dir": checkpoint_dir, + "checkpoint_path": checkpoint_path, + } + builder_args = BuilderArgs(**args_dict) + # Then, load the checkpoint using torchchat util + checkpoint = _load_checkpoint(builder_args) + + updated_states: Set[str] = set() + # This step converts full tensor into DTensor + update_state_dict( + stage_state_dict, + checkpoint, + device, + model_config=model_config, + updated_states=updated_states, + ) + + # Fill state dict into stage module + stage_module.load_state_dict(stage_state_dict, strict=False, assign=True) + logger.info(f"Successfully loaded {len(updated_states)} weights into stage module") diff --git a/torchchat/distributed/dtensor_utils.py b/torchchat/distributed/dtensor_utils.py index 1a6704caa..d11f2a751 100644 --- a/torchchat/distributed/dtensor_utils.py +++ b/torchchat/distributed/dtensor_utils.py @@ -1,77 +1,69 @@ import torch -from torch.distributed._tensor import DTensor, Shard, Replicate - +from torch.distributed import DeviceMesh +from torch.distributed._tensor import DTensor, Shard, Replicate, Placement +from torch.distributed.tensor._utils import compute_local_shape_and_global_offset from collections import defaultdict +from typing import Optional, Sequence from torchchat.distributed.logging_utils import SingletonLogger logger = SingletonLogger.get_logger() -def convert_to_dtensor(weight_tensor, dtensor_template): - """Adjust a loaded tensor to match the shape/placement of the model DTensor and copy the data into it""" - - if weight_tensor.shape != dtensor_template.shape: +def convert_to_dtensor( + full_tensor: torch.Tensor, + dtensor_template: DTensor, +) -> DTensor: + """ + Converts a full 
diff --git a/torchchat/distributed/dtensor_utils.py b/torchchat/distributed/dtensor_utils.py
index 1a6704caa..d11f2a751 100644
--- a/torchchat/distributed/dtensor_utils.py
+++ b/torchchat/distributed/dtensor_utils.py
@@ -1,77 +1,69 @@
 import torch
-from torch.distributed._tensor import DTensor, Shard, Replicate
-
+from torch.distributed import DeviceMesh
+from torch.distributed._tensor import DTensor, Shard, Replicate, Placement
+from torch.distributed.device_mesh import _mesh_resources
+from torch.distributed.tensor._utils import compute_local_shape_and_global_offset
 
 from collections import defaultdict
+from typing import Optional, Sequence
 
 from torchchat.distributed.logging_utils import SingletonLogger
 logger = SingletonLogger.get_logger()
 
 
-def convert_to_dtensor(weight_tensor, dtensor_template):
-    """Adjust a loaded tensor to match the shape/placement of the model DTensor and copy the data into it"""
-
-    if weight_tensor.shape != dtensor_template.shape:
+def convert_to_dtensor(
+    full_tensor: torch.Tensor,
+    dtensor_template: DTensor,
+) -> DTensor:
+    """
+    Converts a full tensor to a DTensor with the same placements as the given
+    DTensor template.
+    """
+    if full_tensor.shape != dtensor_template.shape:
         raise ValueError(
-            f"Shape mismatch: weight tensor shape {weight_tensor.shape} "
+            f"Shape mismatch: weight tensor shape {full_tensor.shape} "
             f"doesn't match DTensor shape {dtensor_template.shape}"
         )
 
-    placements = dtensor_template.placements
-    mesh = dtensor_template.device_mesh
-    mesh_dims = mesh.ndim
-
-    for placement in placements:
-        if isinstance(placement, Shard):
-            shard_dim = placement.dim
-
-            if shard_dim >= weight_tensor.dim():
-                raise ValueError(
-                    f"Shard dimension {shard_dim} is out of range for tensor with {weight_tensor.dim()} dimensions."
-                )
-
-            num_shards = mesh.size(
-                0
-            )  # Assuming sharding is always along the first mesh dimension
-            shard_size = weight_tensor.size(shard_dim) // num_shards
-            shard_index = mesh.get_coordinate()[0]
-
-            start_idx = shard_index * shard_size
-            end_idx = start_idx + shard_size
-
-            slice_list = [slice(None)] * weight_tensor.dim()
-            slice_list[shard_dim] = slice(start_idx, end_idx)
-            weight_tensor = weight_tensor[tuple(slice_list)]
-
-        elif isinstance(placement, Replicate):
-            continue
-        else:
-            raise ValueError(f"Unsupported placement type: {type(placement)}")
-
-    new_dtensor = DTensor.from_local(weight_tensor, mesh, placements)
-
+    new_dtensor = shard(
+        full_tensor,
+        dtensor_template.placements,
+        dtensor_template.device_mesh
+    )
     return new_dtensor
 
 
-def inspect_dtensor_sharding(dtensor):
-    """hepful debug util for inspecting DTensor sharding"""
-    if not is_dtensor(dtensor):
-        logger.info(f"This tensor {dtensor} is not a DTensor")
-        return
-
-    placements = dtensor.placements
-    logger.info(f"DTensor shape: {dtensor.shape}")
-    logger.info(f"Number of dimensions: {len(placements)}")
-
-    for dim, placement in enumerate(placements):
-        logger.info(f"Dimension {dim}:")
-        logger.info(f"  Placement type: {placement.type}")
-        if placement.type == "shard":
-            logger.info(f"  Sharding spec: {placement.sharding_spec}")
-        elif placement.type == "replicate":
-            logger.info("  Replicated across devices")
-        else:
-            logger.info(f"  Other placement type: {placement.type}")
-
-    logger.info(f"Device mesh shape: {dtensor.device_mesh.shape}")
-    logger.info(f"Device mesh devices: {dtensor.device_mesh.device_type}")
+def shard(
+    full_tensor: torch.Tensor,
+    placements: Sequence[Placement],
+    device_mesh: Optional[DeviceMesh] = None,
+) -> DTensor:
+    """
+    Shards a full tensor based on the indicated placements, and returns a
+    DTensor containing the shard.
+
+    Args:
+        full_tensor (torch.Tensor): the full tensor to be sharded.
+        placements (Sequence[:class:`Placement`]): the placements that
+            describe how to place the local tensor on DeviceMesh.
+        device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to place the
+            DTensor. Must have the same dimension as the number of placements.
+            If not specified, it will be retrieved from the current context.
+
+    Returns:
+        A :class:`DTensor` object with the shard as its local tensor.
+
+    Examples:
+        >>> # xdoctest: +SKIP("need world_size and rank")
+        >>> device_mesh = dist.init_device_mesh("cuda", (world_size,))
+        >>> full_tensor = torch.arange(world_size, device=f"cuda:{rank}")
+        >>> placements = [Shard(0)]
+        >>> dtensor = shard(full_tensor, placements, device_mesh)
+    """
+    device_mesh = device_mesh or _mesh_resources.get_current_mesh()
+
+    shape, offset = compute_local_shape_and_global_offset(
+        full_tensor.shape, device_mesh, placements
+    )
+    slices = [
+        slice(cur_offset, cur_offset + cur_shape)
+        for cur_shape, cur_offset in zip(shape, offset)
+    ]
+    local_tensor = full_tensor[tuple(slices)]
+    return DTensor.from_local(local_tensor, device_mesh, placements)
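A sketch of how the rewritten helpers are exercised during weight loading (assumes an already-initialized distributed launch with one GPU per rank; the shapes and the sharding dim are illustrative):

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._tensor import Shard, distribute_tensor

from torchchat.distributed.dtensor_utils import convert_to_dtensor

dist.init_process_group("nccl")
mesh = init_device_mesh("cuda", (dist.get_world_size(),))

# A parameter that tensor-parallel setup left as a DTensor, sharded on dim 0.
template = distribute_tensor(
    torch.empty(4096, 4096, device="cuda"), mesh, [Shard(0)]
)

# The full tensor read from the checkpoint on this rank.
full = torch.randn(4096, 4096, device="cuda")

# convert_to_dtensor slices out this rank's shard (via `shard` above) and wraps
# it with DTensor.from_local, so the placements match the template.
dtensor = convert_to_dtensor(full, template)
assert dtensor.placements == template.placements
assert dtensor.to_local().shape[0] == 4096 // dist.get_world_size()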