Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 57 additions & 19 deletions pretrain/scripts/v4-corpus-ratio-abci/merge/merge.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# Script to apply average merging to Hugging Face models.
# This script can work with the pretrain Python environment.
#
#
# Usage:
# python merge.py \
# --source-models /path/to/model1 /path/to/model2 \
# --source-weights 1.0 2.0 \ # Optional
# --aggregation-method average \ # Optional,
# --output-model /path/to/output_model

import argparse
Expand Down Expand Up @@ -31,6 +33,25 @@ def parse_args():
" All models should be in the same format and have compatible parameters."
),
)
p.add_argument(
"--source-weights",
type=float,
nargs="+",
default=None,
help=(
"Weights for each source model. If not provided, "
"all models will be treated equally (weight = 1)."
),
)
p.add_argument(
"--aggregation-method",
choices=["average", "sum"],
default="average",
help=(
"Method to aggregate parameters. "
"'average' will average the parameters, while 'sum' will sum them up."
),
)
p.add_argument(
"--output-model",
type=pathlib.Path,
Expand All @@ -44,7 +65,7 @@ def parse_args():
def iter_params(model_path: pathlib.Path) -> tuple[str, torch.Tensor]:
"""
Iterate through the parameters of a model stored in a .safetensors file.

Args:
model_path (pathlib.Path): Path to the .safetensors file.

Expand All @@ -62,17 +83,34 @@ def iter_params(model_path: pathlib.Path) -> tuple[str, torch.Tensor]:
def main():
args = parse_args()

logging.info(f"Source models: {args.source_models}")

if args.source_weights is not None:
if len(args.source_weights) != len(args.source_models):
raise ValueError("Number of source weights must match number of source models.")
else:
args.source_weights = [1.0] * len(args.source_models)

logging.info(f"Source weights: {args.source_weights}")

match args.aggregation_method:
case "average":
if any(x <= 0 for x in args.source_weights):
raise ValueError("All source weights must be positive for --aggregation-method=average.")
denominator = sum(args.source_weights)
case "sum":
denominator = 1.0
case _:
raise ValueError(f"Unknown aggregation method: {args.aggregation_method}")

logging.info(f"Aggregation method: {args.aggregation_method}, denominator: {denominator}")

# Initialize a dictionary to hold the sum of parameters
param_sums = {}
model_count = len(args.source_models)

if model_count == 0:
raise ValueError("No input models provided for merging.")

logging.info(f"Source models: {args.source_models}")

# Iterate through each model and accumulate the parameters
for model_path in args.source_models:
for model_path, weight in zip(args.source_models, args.source_weights):
if not model_path.exists():
raise FileNotFoundError(f"Model path {model_path} does not exist.")
if not model_path.is_dir():
Expand All @@ -81,42 +119,42 @@ def main():

for key, tensor in iter_params(model_path):
if key not in param_sums:
param_sums[key] = tensor
param_sums[key] = tensor * weight
else:
if param_sums[key].shape != tensor.shape:
raise ValueError(f"Shape mismatch for key '{key}': "
f"{param_sums[key].shape} vs {tensor.shape}")
param_sums[key] += tensor
# Average the parameters
param_sums[key] += tensor * weight

# Normalize the parameters
for key in param_sums:
param_sums[key] /= model_count
param_sums[key] /= denominator

logging.info("Merging completed. Saving the merged model...")
args.output_model.mkdir(parents=True, exist_ok=True)

# Copy original files other than .safetensors
for file in args.source_models[0].iterdir():
if file.suffix != ".safetensors":
shutil.copy(file, args.output_model / file.name)

# There should be `model.safetensors.index.json` file
# containing the mapping of parameter names to their destination file names.
index_file = args.output_model / "model.safetensors.index.json"
if not index_file.exists():
raise FileNotFoundError(f"Index file {index_file} does not exist.")
with index_file.open("r") as f:
weight_map = json.load(f)["weight_map"]

# Check if the weight map is consistent with the parameters
if set(weight_map.keys()) != set(param_sums.keys()):
raise ValueError("Weight map keys do not match the parameter keys.")

# Make inverse mapping for saving
output_map = {k: [] for k in set(weight_map.values())}
for k, v in weight_map.items():
output_map[v].append(k)

metadata = {"format": "pt"}

# Save all parameters
Expand All @@ -125,7 +163,7 @@ def main():
output_path = args.output_model / file_name
logging.info(f" Saving parameters to {output_path}")
safetensors.torch.save_file(tensors, output_path, metadata=metadata)

logging.info("Merged model saved successfully.")


Expand Down