diff --git a/pretrain/scripts/v4-corpus-ratio-abci/merge/merge.py b/pretrain/scripts/v4-corpus-ratio-abci/merge/merge.py index a63c6d00..7f8e0852 100644 --- a/pretrain/scripts/v4-corpus-ratio-abci/merge/merge.py +++ b/pretrain/scripts/v4-corpus-ratio-abci/merge/merge.py @@ -1,9 +1,11 @@ # Script to apply average merging to Hugging Face models. # This script can work with the pretrain Python environment. -# +# # Usage: # python merge.py \ # --source-models /path/to/model1 /path/to/model2 \ +# --source-weights 1.0 2.0 \ # Optional +# --aggregation-method average \ # Optional # --output-model /path/to/output_model import argparse @@ -31,6 +33,25 @@ def parse_args(): " All models should be in the same format and have compatible parameters." ), ) + p.add_argument( + "--source-weights", + type=float, + nargs="+", + default=None, + help=( + "Weights for each source model. If not provided, " + "all models will be treated equally (weight = 1)." + ), + ) + p.add_argument( + "--aggregation-method", + choices=["average", "sum"], + default="average", + help=( + "Method to aggregate parameters. " + "'average' will average the parameters, while 'sum' will sum them up." + ), + ) p.add_argument( "--output-model", type=pathlib.Path, @@ -44,7 +65,7 @@ def parse_args(): def iter_params(model_path: pathlib.Path) -> tuple[str, torch.Tensor]: """ Iterate through the parameters of a model stored in a .safetensors file. - + Args: model_path (pathlib.Path): Path to the .safetensors file. 
@@ -62,17 +83,34 @@ def iter_params(model_path: pathlib.Path) -> tuple[str, torch.Tensor]: def main(): args = parse_args() + logging.info(f"Source models: {args.source_models}") + + if args.source_weights is not None: + if len(args.source_weights) != len(args.source_models): + raise ValueError("Number of source weights must match number of source models.") + else: + args.source_weights = [1.0] * len(args.source_models) + + logging.info(f"Source weights: {args.source_weights}") + + match args.aggregation_method: + case "average": + if any(x <= 0 for x in args.source_weights): + raise ValueError("All source weights must be positive for --aggregation-method=average.") + denominator = sum(args.source_weights) + case "sum": + denominator = 1.0 + case _: + raise ValueError(f"Unknown aggregation method: {args.aggregation_method}") + + logging.info(f"Aggregation method: {args.aggregation_method}, denominator: {denominator}") + # Initialize a dictionary to hold the sum of parameters param_sums = {} model_count = len(args.source_models) - - if model_count == 0: - raise ValueError("No input models provided for merging.") - - logging.info(f"Source models: {args.source_models}") # Iterate through each model and accumulate the parameters - for model_path in args.source_models: + for model_path, weight in zip(args.source_models, args.source_weights): if not model_path.exists(): raise FileNotFoundError(f"Model path {model_path} does not exist.") if not model_path.is_dir(): @@ -81,25 +119,25 @@ def main(): for key, tensor in iter_params(model_path): if key not in param_sums: - param_sums[key] = tensor + param_sums[key] = tensor * weight else: if param_sums[key].shape != tensor.shape: raise ValueError(f"Shape mismatch for key '{key}': " f"{param_sums[key].shape} vs {tensor.shape}") - param_sums[key] += tensor - - # Average the parameters + param_sums[key] += tensor * weight + + # Normalize the parameters for key in param_sums: - param_sums[key] /= model_count + param_sums[key] /= 
denominator logging.info("Merging completed. Saving the merged model...") args.output_model.mkdir(parents=True, exist_ok=True) - + # Copy original files other than .safetensors for file in args.source_models[0].iterdir(): if file.suffix != ".safetensors": shutil.copy(file, args.output_model / file.name) - + # There should be `model.safetensors.index.json` file # containing the mapping of parameter names to their destination file names. index_file = args.output_model / "model.safetensors.index.json" @@ -107,16 +145,16 @@ def main(): raise FileNotFoundError(f"Index file {index_file} does not exist.") with index_file.open("r") as f: weight_map = json.load(f)["weight_map"] - + # Check if the weight map is consistent with the parameters if set(weight_map.keys()) != set(param_sums.keys()): raise ValueError("Weight map keys do not match the parameter keys.") - + # Make inverse mapping for saving output_map = {k: [] for k in set(weight_map.values())} for k, v in weight_map.items(): output_map[v].append(k) - + metadata = {"format": "pt"} # Save all parameters @@ -125,7 +163,7 @@ def main(): output_path = args.output_model / file_name logging.info(f" Saving parameters to {output_path}") safetensors.torch.save_file(tensors, output_path, metadata=metadata) - + logging.info("Merged model saved successfully.")