-
Notifications
You must be signed in to change notification settings - Fork 998
Add MFU logging support #6434
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add MFU logging support #6434
Changes from 6 commits
a820b3c
655eedb
05ba547
190ee31
60d2c5f
f1423a0
83b24f8
83137cd
7d8aa7a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,5 +1,8 @@ | ||||||||||||||||||||||||||||||||||||||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | ||||||||||||||||||||||||||||||||||||||||
| import time | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| import numpy as np | ||||||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||||||
| from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| from swift.utils import get_logger | ||||||||||||||||||||||||||||||||||||||||
|
|
@@ -28,6 +31,134 @@ def on_save(self, args: TrainingArguments, state: TrainerState, control: Trainer | |||||||||||||||||||||||||||||||||||||||
| control.should_training_stop = True | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| extra_callbacks = [] | ||||||||||||||||||||||||||||||||||||||||
| class PerfMetricsLogCallback(TrainerCallback): | ||||||||||||||||||||||||||||||||||||||||
| """An callback for perf metrics (MFU etc) log implementation""" | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| def __init__(self): | ||||||||||||||||||||||||||||||||||||||||
| self.device_tflops = None | ||||||||||||||||||||||||||||||||||||||||
| self.elapsed = 0.0 | ||||||||||||||||||||||||||||||||||||||||
| self.step_start_time = None | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): | ||||||||||||||||||||||||||||||||||||||||
| from swift.utils import get_current_device, get_device_count, get_env_args | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| # Top priority. Specify by ENV | ||||||||||||||||||||||||||||||||||||||||
| tflops = get_env_args('DEVICE_TFLOPS', int, None) | ||||||||||||||||||||||||||||||||||||||||
| device_count = max(get_device_count(), 1) | ||||||||||||||||||||||||||||||||||||||||
| if tflops is not None: | ||||||||||||||||||||||||||||||||||||||||
| logger.info(f"Specify theoretical max TFLOPS through ENV 'DEVICE_TFLOPS'. [{tflops} TFLOPS]") | ||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||
| # Run a estimating test. | ||||||||||||||||||||||||||||||||||||||||
| dtype = kwargs.get('model').dtype | ||||||||||||||||||||||||||||||||||||||||
| device = torch.device(get_current_device()) | ||||||||||||||||||||||||||||||||||||||||
| logger.info(f'Estimating device TFLOPS baseline. Device: [{device}] dtype: [{dtype}]') | ||||||||||||||||||||||||||||||||||||||||
| tflops = self._estimate_device_tflops_by_dtype(device, dtype) | ||||||||||||||||||||||||||||||||||||||||
| logger.info(f'Estimate test finished. [{tflops} TFLOPS] Device count: [{device_count}]') | ||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+50
to
+56
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Suggested change
|
||||||||||||||||||||||||||||||||||||||||
| # TODO Collect comprehensive TFLOPS data. Then provide a fallback strategy based on lookup tables. | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| self.device_tflops = tflops * device_count | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): | ||||||||||||||||||||||||||||||||||||||||
| self.step_start_time = time.time() | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): | ||||||||||||||||||||||||||||||||||||||||
| self.elapsed += time.time() - self.step_start_time | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, logs=None, **kwargs): | ||||||||||||||||||||||||||||||||||||||||
| total_flos = getattr(state, 'total_flos', 0) | ||||||||||||||||||||||||||||||||||||||||
| actual_flops = total_flos / self.elapsed | ||||||||||||||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There is a potential
Suggested change
|
||||||||||||||||||||||||||||||||||||||||
| theoretical_max_flops = self.device_tflops * 1e12 | ||||||||||||||||||||||||||||||||||||||||
| mfu = actual_flops / theoretical_max_flops | ||||||||||||||||||||||||||||||||||||||||
| logger.debug(f'Total_flos[{total_flos}] elapsed_time[{self.elapsed}]sec Average MFU[{mfu}]') | ||||||||||||||||||||||||||||||||||||||||
| logs['MFU'] = round(mfu, 6) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| @staticmethod | ||||||||||||||||||||||||||||||||||||||||
| def _estimate_device_tflops_by_dtype(device: torch.device, dtype: torch.dtype, repeats: int = 60, dim: int = 8192): | ||||||||||||||||||||||||||||||||||||||||
| from swift.utils.torch_utils import empty_cache | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| def device_synchronize(sync_device): | ||||||||||||||||||||||||||||||||||||||||
| if backend == 'cuda': | ||||||||||||||||||||||||||||||||||||||||
| torch.cuda.synchronize(sync_device) | ||||||||||||||||||||||||||||||||||||||||
| elif backend == 'npu': | ||||||||||||||||||||||||||||||||||||||||
| torch.npu.synchronize(sync_device) | ||||||||||||||||||||||||||||||||||||||||
| elif backend == 'cpu': | ||||||||||||||||||||||||||||||||||||||||
| torch.cpu.synchronize(sync_device) | ||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+84
to
+85
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The helper function
Suggested change
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function |
||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| # 默认矩阵规模 | ||||||||||||||||||||||||||||||||||||||||
| shape = (dim, dim) | ||||||||||||||||||||||||||||||||||||||||
| backend = device.type | ||||||||||||||||||||||||||||||||||||||||
| if backend == 'npu': | ||||||||||||||||||||||||||||||||||||||||
| import torch_npu | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| # 创建矩阵 | ||||||||||||||||||||||||||||||||||||||||
| a = torch.randn(*shape, device=device, dtype=dtype) | ||||||||||||||||||||||||||||||||||||||||
| b = torch.randn(*shape, device=device, dtype=dtype) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| # 预热 | ||||||||||||||||||||||||||||||||||||||||
| for _ in range(5): | ||||||||||||||||||||||||||||||||||||||||
| c = torch.matmul(a, b) | ||||||||||||||||||||||||||||||||||||||||
| device_synchronize(device) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| # 进行测试 | ||||||||||||||||||||||||||||||||||||||||
| start = time.time() | ||||||||||||||||||||||||||||||||||||||||
| for _ in range(repeats): | ||||||||||||||||||||||||||||||||||||||||
| c = torch.matmul(a, b) | ||||||||||||||||||||||||||||||||||||||||
| device_synchronize(device) | ||||||||||||||||||||||||||||||||||||||||
| end = time.time() | ||||||||||||||||||||||||||||||||||||||||
| total_time = end - start | ||||||||||||||||||||||||||||||||||||||||
| avg_time = total_time / repeats | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| # 若测试时间过短,调整循环次数并重新测试 | ||||||||||||||||||||||||||||||||||||||||
| if total_time < 3: | ||||||||||||||||||||||||||||||||||||||||
| repeats = int(6 / avg_time) | ||||||||||||||||||||||||||||||||||||||||
| start = time.time() | ||||||||||||||||||||||||||||||||||||||||
| for _ in range(repeats): | ||||||||||||||||||||||||||||||||||||||||
| c = torch.matmul(a, b) | ||||||||||||||||||||||||||||||||||||||||
| device_synchronize(device) | ||||||||||||||||||||||||||||||||||||||||
| end = time.time() | ||||||||||||||||||||||||||||||||||||||||
| total_time = end - start | ||||||||||||||||||||||||||||||||||||||||
| avg_time = total_time / repeats | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| del a, b, c | ||||||||||||||||||||||||||||||||||||||||
| empty_cache() | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| tflops = (2 * dim**3 / avg_time) / 1e12 | ||||||||||||||||||||||||||||||||||||||||
| logger.info(f'[Device {device}] Total time: {total_time:.4f}s, dtype: {dtype}, Perf: {tflops:.4f} TFLOPS') | ||||||||||||||||||||||||||||||||||||||||
| return tflops | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| @staticmethod | ||||||||||||||||||||||||||||||||||||||||
| def _retrieve_flops_from_map(device): | ||||||||||||||||||||||||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I noticed this function is not being used. Just curious, what's the reason?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. More data should be collected, and hopefully the function |
||||||||||||||||||||||||||||||||||||||||
| """Retrieve theoretical FLOPS from Map. """ | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| device_name = device.get_device_name() | ||||||||||||||||||||||||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is the get_device_name function? |
||||||||||||||||||||||||||||||||||||||||
| flops = None | ||||||||||||||||||||||||||||||||||||||||
| for name, value in device_flops_map.items(): | ||||||||||||||||||||||||||||||||||||||||
| if name in device_name: | ||||||||||||||||||||||||||||||||||||||||
| flops = value | ||||||||||||||||||||||||||||||||||||||||
| break | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| return flops | ||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+130
to
+140
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This method has a few issues and is currently not used:
def _retrieve_flops_from_map(device):
"""Retrieve theoretical FLOPS from Map and convert to TFLOPS."""
if device.type != 'cuda':
# Add other supported device types like 'npu' if needed
return None
device_name = torch.cuda.get_device_name(device)
flops = None
for name, value in device_flops_map.items():
if name in device_name:
flops = value
break
if flops is not None:
return flops / 1e12
return None |
||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| device_flops_map = { | ||||||||||||||||||||||||||||||||||||||||
| 'GB200': 2.5e15, | ||||||||||||||||||||||||||||||||||||||||
| 'B200': 2.25e15, | ||||||||||||||||||||||||||||||||||||||||
| 'MI300X': 1336e12, | ||||||||||||||||||||||||||||||||||||||||
| 'H100': 312e12, | ||||||||||||||||||||||||||||||||||||||||
| 'H800': 312e12, | ||||||||||||||||||||||||||||||||||||||||
| 'H200': 989e12, | ||||||||||||||||||||||||||||||||||||||||
| 'A100': 312e12, | ||||||||||||||||||||||||||||||||||||||||
| 'A800': 312e12, | ||||||||||||||||||||||||||||||||||||||||
| 'L40S': 362.05e12, | ||||||||||||||||||||||||||||||||||||||||
| 'L40': 181.05e12, | ||||||||||||||||||||||||||||||||||||||||
| 'A40': 149.7e12, | ||||||||||||||||||||||||||||||||||||||||
| 'L20': 119.5e12, | ||||||||||||||||||||||||||||||||||||||||
| 'H20': 148e12, | ||||||||||||||||||||||||||||||||||||||||
| '910B': 354e12, | ||||||||||||||||||||||||||||||||||||||||
| 'Ascend910': 354e12, | ||||||||||||||||||||||||||||||||||||||||
| 'RTX 3070 Ti': 21.75e12 | ||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| extra_callbacks = [PerfMetricsLogCallback()] | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
| # This example shows a simple example of EarlyStop Callback, uncomment this to use | ||||||||||||||||||||||||||||||||||||||||
| # extra_callbacks = [EarlyStopCallback()] | ||||||||||||||||||||||||||||||||||||||||
| # extra_callbacks = [EarlyStopCallback()] | ||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
DEVICE_TFLOPSis parsed as an integer, which might be too restrictive as TFLOPS values are often floating-point numbers (e.g., from the estimation or lookup table). Usingfloatwould be more appropriate and consistent.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems float is correct.