
Conversation

@y2logic
Contributor

@y2logic y2logic commented Nov 5, 2025

PR type

  • New Feature

PR information

Implement a callback plugin that adds MFU metrics to the training log.
Related to issue Add MFU (Model FLOPs Utilization) logging support #5791

Details of this PR:

  • Specify device TFLOPS by setting the environment variable DEVICE_TFLOPS (highest priority; see the sketch below).
  • Provide a quick benchmark test for estimating device compute capability.
  • Provide a fallback strategy based on a lookup table; more data is needed (the community may have more reliable numbers).
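
For context, a minimal sketch of how these pieces combine into the logged MFU value, using the standard definition MFU = achieved FLOPS / theoretical peak FLOPS (function and parameter names here are illustrative, not the PR's actual identifiers):

import os

def resolve_peak_tflops(estimated_tflops: float) -> float:
    # The DEVICE_TFLOPS environment variable takes priority; otherwise
    # fall back to the benchmark estimate (or a lookup table).
    env = os.environ.get('DEVICE_TFLOPS')
    return float(env) if env is not None else estimated_tflops

def mfu(total_flos: float, elapsed_seconds: float,
        peak_tflops_per_device: float, device_count: int) -> float:
    # MFU = achieved FLOPS / peak FLOPS across all participating devices.
    achieved_flops = total_flos / elapsed_seconds
    return achieved_flops / (peak_tflops_per_device * 1e12 * device_count)

With the run below (7.9137 TFLOPS baseline, one device), an MFU of ~0.20 means roughly 20% of the measured peak was achieved.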

Experiment results

[INFO:swift] model_parameter_info: PeftModelForCausalLM: 498.4319M Params (4.3991M Trainable [0.8826%]), 0.0000M Buffers.
[INFO:swift] Setting DEVICE_TFLOPS: None. You can adjust this hyperparameter through the environment variable: `DEVICE_TFLOPS`.
[INFO:swift] Estimating device TFLOPS baseline. Device: [cuda:0] dtype: [torch.float16]
[Device cuda:0] Total test time: 8.3362 s, average: 0.1389 s, dtype: torch.float16, performance: 7.9137 TFLOPS
[INFO:swift] Estimate test finished. [7.9137353585254075 TFLOPS] Device count: [1]
[INFO:swift] use_reentrant: True
[INFO:swift] The logging file will be saved in: /home/jovyan/y2logic/config/output/v8-20251105-011248/logging.jsonl
Train:   0%|                                                                                                   | 0/8 [00:00<?, ?it/s][INFO:swift] use_logits_to_keep: True
/opt/conda/lib/python3.11/site-packages/torch/utils/checkpoint.py:295: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
{'loss': 1.92887628, 'grad_norm': 1.0251199, 'learning_rate': 1e-05, 'token_acc': 0.55102721, 'epoch': 0.13, 'MFU': 0.200113, 'global_step/max_steps': '1/8', 'percentage': '12.50%', 'elapsed_time': '23s', 'remaining_time': '2m 44s', 'memory(GiB)': 11.15, 'train_speed(iter/s)': 0.042512}
{'loss': 1.80279827, 'grad_norm': 0.94247454, 'learning_rate': 9.5e-06, 'token_acc': 0.5656051, 'epoch': 0.26, 'MFU': 0.202701, 'global_step/max_steps': '2/8', 'percentage': '25.00%', 'elapsed_time': '44s', 'remaining_time': '2m 14s', 'memory(GiB)': 11.15, 'train_speed(iter/s)': 0.044644}
{'loss': 1.72881317, 'grad_norm': 0.92661977, 'learning_rate': 8.12e-06, 'token_acc': 0.57771261, 'epoch': 0.38, 'MFU': 0.20362, 'global_step/max_steps': '3/8', 'percentage': '37.50%', 'elapsed_time': '1m 7s', 'remaining_time': '1m 51s', 'memory(GiB)': 11.15, 'train_speed(iter/s)': 0.044769}
{'loss': 1.74326658, 'grad_norm': 0.93214059, 'learning_rate': 6.11e-06, 'token_acc': 0.57969724, 'epoch': 0.51, 'MFU': 0.204517, 'global_step/max_steps': '4/8', 'percentage': '50.00%', 'elapsed_time': '1m 28s', 'remaining_time': '1m 28s', 'memory(GiB)': 11.15, 'train_speed(iter/s)': 0.045414}
{'loss': 1.88908231, 'grad_norm': 0.96761024, 'learning_rate': 3.89e-06, 'token_acc': 0.5600538, 'epoch': 0.64, 'MFU': 0.204593, 'global_step/max_steps': '5/8', 'percentage': '62.50%', 'elapsed_time': '1m 49s', 'remaining_time': '1m 5s', 'memory(GiB)': 11.15, 'train_speed(iter/s)': 0.045703}
{'loss': 1.7942878, 'grad_norm': 0.96311468, 'learning_rate': 1.88e-06, 'token_acc': 0.5789354, 'epoch': 0.77, 'MFU': 0.204409, 'global_step/max_steps': '6/8', 'percentage': '75.00%', 'elapsed_time': '2m 13s', 'remaining_time': '44s', 'memory(GiB)': 11.15, 'train_speed(iter/s)': 0.044814}
{'loss': 1.61809587, 'grad_norm': 0.84860581, 'learning_rate': 5e-07, 'token_acc': 0.60130132, 'epoch': 0.9, 'MFU': 0.203771, 'global_step/max_steps': '7/8', 'percentage': '87.50%', 'elapsed_time': '2m 37s', 'remaining_time': '22s', 'memory(GiB)': 11.15, 'train_speed(iter/s)': 0.044347}
{'loss': 1.79473865, 'grad_norm': 0.96563059, 'learning_rate': 0.0, 'token_acc': 0.57638475, 'epoch': 1.0, 'MFU': 0.203993, 'global_step/max_steps': '8/8', 'percentage': '100.00%', 'elapsed_time': '2m 56s', 'remaining_time': '0s', 'memory(GiB)': 11.15, 'train_speed(iter/s)': 0.045302}


@gemini-code-assist
Contributor

Summary of Changes

Hello @y2logic, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a significant enhancement to the training pipeline by integrating Model FLOPs Utilization (MFU) logging. This feature aims to provide developers and researchers with critical performance metrics, enabling them to better understand and optimize the efficiency of their hardware during model training. The implementation offers a flexible approach to determining theoretical device TFLOPS, either through direct user input via an environment variable or through an intelligent automatic estimation process, ensuring robust MFU reporting across different setups.

Highlights

  • Introduction of PerfMetricsLogCallback: A new TrainerCallback named PerfMetricsLogCallback has been added to automatically calculate and log Model FLOPs Utilization (MFU) during the training process, providing insights into hardware efficiency.
  • Flexible TFLOPS Configuration: Users can now specify the theoretical maximum TFLOPS of their device using the DEVICE_TFLOPS environment variable, which takes precedence over automatic estimation.
  • Automatic TFLOPS Estimation: If the DEVICE_TFLOPS environment variable is not explicitly set, the system will automatically estimate the device's TFLOPS by running a quick matrix multiplication benchmark tailored to the device and data type (sketched after this list).
  • Device TFLOPS Lookup Table: A device_flops_map has been included, containing theoretical FLOPS values for various common GPU devices (e.g., H100, A100, RTX 3070 Ti, Ascend910), intended for future fallback strategies or reference.
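
A rough illustration of that estimation approach (a sketch assuming a CUDA device; the matrix size, iteration count, and function name are ours, not the PR's):

import time

import torch

def estimate_device_tflops(device: torch.device, dtype: torch.dtype,
                           size: int = 4096, iters: int = 20) -> float:
    # A square n x n matmul performs ~2 * n^3 floating-point operations.
    a = torch.randn(size, size, device=device, dtype=dtype)
    b = torch.randn(size, size, device=device, dtype=dtype)
    for _ in range(3):  # warm-up to exclude one-off launch/compile costs
        torch.matmul(a, b)
    torch.cuda.synchronize(device)
    start = time.perf_counter()
    for _ in range(iters):
        torch.matmul(a, b)
    torch.cuda.synchronize(device)
    elapsed = time.perf_counter() - start
    return 2 * size ** 3 * iters / elapsed / 1e12
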
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

  • Code Review (/gemini review): Performs a code review for the current pull request in its current state.
  • Pull Request Summary (/gemini summary): Provides a summary of the current pull request in its current state.
  • Comment (@gemini-code-assist): Responds in comments when explicitly tagged, both in pull request comments and review comments.
  • Help (/gemini help): Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a PerfMetricsLogCallback to compute and log Model FLOPs Utilization (MFU), which is a valuable addition for performance monitoring. The implementation is well-structured, offering device TFLOPS estimation and configuration via environment variables. My review focuses on improving robustness, maintainability, and consistency. I've identified a bug in a helper function, suggested refactoring to reduce code duplication, recommended leveraging existing utilities, and pointed out some dead code. Addressing these points will enhance the quality and reliability of the new callback.

@y2logic y2logic marked this pull request as draft November 5, 2025 01:47
@y2logic y2logic marked this pull request as ready for review November 5, 2025 02:08
@Jintao-Huang
Collaborator

Thank you for the PR. Please change the Chinese comments to English comments.

'RTX 3070 Ti': 21.75e12
}

extra_callbacks = [PerfMetricsLogCallback()]
Collaborator

Please do not enable by default.

Contributor Author

Thanks for your comment. Is a switch argument required?

@y2logic
Contributor Author

y2logic commented Nov 21, 2025

Thank you for the PR. Please change the Chinese comments to English comments.

All comments are now in English, and I have left the extra_callbacks untouched.

@Jintao-Huang
Collaborator

Use Python < 3.12 and run the code below

pip install pre-commit
pre-commit run --all-files

@y2logic
Contributor Author

y2logic commented Nov 21, 2025

Use Python < 3.12 and run the code below

pip install pre-commit
pre-commit run --all-files

IDE formatting problem. It has been corrected.

@Jintao-Huang
Collaborator

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a PerfMetricsLogCallback to compute and log Model FLOPs Utilization (MFU), which is a valuable metric for performance monitoring. The overall approach is sound, but there are several areas for improvement regarding correctness, robustness, and efficiency. My review highlights a critical bug in the TFLOPS estimation for CPU environments, points out that the efficient lookup-table-based TFLOPS retrieval is implemented but unused and buggy, and suggests other improvements like handling potential division-by-zero errors and using more appropriate data types. Addressing these points will make the new callback more robust and reliable.

Comment on lines +84 to +85
elif backend == 'cpu':
    torch.cpu.synchronize(sync_device)
Contributor

critical

The helper function device_synchronize has a bug when the backend is 'cpu'. torch.cpu.synchronize does not exist in PyTorch, and calling it will cause an AttributeError. CPU operations are typically synchronous, so an explicit synchronization is not needed.

Suggested change
elif backend == 'cpu':
    torch.cpu.synchronize(sync_device)
elif backend == 'cpu':
    pass  # CPU operations are synchronous

Contributor Author

The function torch.cpu.synchronize exists and has been tested.
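
(For reference, a version-guarded helper would satisfy both readings; this is a sketch, assuming torch.cpu.synchronize is available in recent PyTorch releases and treating older releases defensively:)

import torch

def device_synchronize(device: torch.device) -> None:
    if device.type == 'cuda':
        torch.cuda.synchronize(device)
    elif device.type == 'cpu' and hasattr(torch.cpu, 'synchronize'):
        # Effectively a no-op on CPU; kept for interface symmetry.
        torch.cpu.synchronize(device)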

Comment on lines +130 to +140
def _retrieve_flops_from_map(device):
    """Retrieve theoretical FLOPS from Map. """

    device_name = device.get_device_name()
    flops = None
    for name, value in device_flops_map.items():
        if name in device_name:
            flops = value
            break

    return flops
Contributor

high

This method has a few issues and is currently not used:

  1. device.get_device_name() will raise an AttributeError as torch.device objects do not have this method. You should use torch.cuda.get_device_name(device) for CUDA devices and handle other device types gracefully.
  2. The values in device_flops_map are in FLOPS, but this function should return TFLOPS to be consistent with other parts of the callback. The result should be divided by 1e12.
  3. This function is currently unused. It should be called in on_init_end to provide an efficient way to determine TFLOPS.
def _retrieve_flops_from_map(device):
    """Retrieve theoretical FLOPS from Map and convert to TFLOPS."""
    if device.type != 'cuda':
        # Add other supported device types like 'npu' if needed
        return None

    device_name = torch.cuda.get_device_name(device)
    flops = None
    for name, value in device_flops_map.items():
        if name in device_name:
            flops = value
            break

    if flops is not None:
        return flops / 1e12
    return None

from swift.utils import get_current_device, get_device_count, get_env_args

# Top priority. Specify by ENV
tflops = get_env_args('DEVICE_TFLOPS', int, None)
Contributor

medium

DEVICE_TFLOPS is parsed as an integer, which might be too restrictive as TFLOPS values are often floating-point numbers (e.g., from the estimation or lookup table). Using float would be more appropriate and consistent.

Suggested change
tflops = get_env_args('DEVICE_TFLOPS', int, None)
tflops = get_env_args('DEVICE_TFLOPS', float, None)

Collaborator

It seems float is correct.

Comment on lines +50 to +56
else:
    # Run a estimating test.
    dtype = kwargs.get('model').dtype
    device = torch.device(get_current_device())
    logger.info(f'Estimating device TFLOPS baseline. Device: [{device}] dtype: [{dtype}]')
    tflops = self._estimate_device_tflops_by_dtype(device, dtype)
    logger.info(f'Estimate test finished. [{tflops} TFLOPS] Device count: [{device_count}]')
Contributor

medium

The _retrieve_flops_from_map function provides an efficient way to get device TFLOPS from a lookup table. It should be called as a fallback before running the performance estimation test. This avoids running a potentially time-consuming benchmark if the device's performance is already known.

Suggested change
else:
    # Run a estimating test.
    dtype = kwargs.get('model').dtype
    device = torch.device(get_current_device())
    logger.info(f'Estimating device TFLOPS baseline. Device: [{device}] dtype: [{dtype}]')
    tflops = self._estimate_device_tflops_by_dtype(device, dtype)
    logger.info(f'Estimate test finished. [{tflops} TFLOPS] Device count: [{device_count}]')
else:
    device = torch.device(get_current_device())
    tflops = self._retrieve_flops_from_map(device)
    if tflops is not None:
        device_name = torch.cuda.get_device_name(device) if device.type == 'cuda' else str(device)
        logger.info(f'Retrieved TFLOPS from lookup table for {device_name}: {tflops} TFLOPS')
    else:
        # Run an estimating test.
        dtype = kwargs.get('model').dtype
        logger.info(f'Estimating device TFLOPS baseline. Device: [{device}] dtype: [{dtype}]')
        tflops = self._estimate_device_tflops_by_dtype(device, dtype)
        logger.info(f'Estimate test finished. [{tflops} TFLOPS] Device count: [{device_count}]')


def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, logs=None, **kwargs):
    total_flos = getattr(state, 'total_flos', 0)
    actual_flops = total_flos / self.elapsed
Contributor

medium

There is a potential ZeroDivisionError here if self.elapsed is 0. This could happen if on_log is called before on_step_end has completed. It's safer to add a guard for this case.

Suggested change
actual_flops = total_flos / self.elapsed
actual_flops = total_flos / self.elapsed if self.elapsed > 0 else 0

return tflops

@staticmethod
def _retrieve_flops_from_map(device):
Collaborator

I noticed this function is not being used. Just curious, what's the reason?

Contributor Author

More data should be collected, and hopefully the function _retrieve_flops_from_map can become a more efficient and accurate way to obtain device TFLOPS. This avoids running a potentially time-consuming benchmark if the device's performance is already known.

def _retrieve_flops_from_map(device):
    """Retrieve theoretical FLOPS from Map. """

    device_name = device.get_device_name()
Collaborator

What is the get_device_name function?

@Jintao-Huang
Collaborator

Hi, have all the changes been committed? I'm ready to merge.

@y2logic
Contributor Author

y2logic commented Nov 24, 2025

Hi, have all the changes been committed? I'm ready to merge.

Yes. All changes committed.

@Jintao-Huang Jintao-Huang merged commit f4493a3 into modelscope:main Nov 24, 2025
1 of 2 checks passed
