Skip to content

Commit f4493a3

Browse files
authored
[callback] Add MFU logging support (#6434)
1 parent 58a8ffb commit f4493a3

File tree

1 file changed

+131
-0
lines changed

1 file changed

+131
-0
lines changed

swift/plugin/callback.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
# Copyright (c) Alibaba, Inc. and its affiliates.
2+
import time
3+
24
import numpy as np
5+
import torch
36
from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments
47

58
from swift.utils import get_logger
@@ -28,6 +31,134 @@ def on_save(self, args: TrainingArguments, state: TrainerState, control: Trainer
2831
control.should_training_stop = True
2932

3033

34+
class PerfMetricsLogCallback(TrainerCallback):
    """A callback that logs performance metrics (currently MFU) during training.

    MFU is reported as ``state.total_flos / elapsed / (device_tflops * 1e12)``, where
    ``elapsed`` is the wall-clock time accumulated between ``on_step_begin`` and
    ``on_step_end`` and ``device_tflops`` is the per-device peak multiplied by the
    device count. The per-device peak is read from the ``DEVICE_TFLOPS`` environment
    variable when set; otherwise it is estimated with a matmul benchmark.
    """

    def __init__(self):
        # Aggregate peak throughput in TFLOPS (per-device peak * device count); set in `on_init_end`.
        self.device_tflops = None
        # Accumulated seconds spent inside training steps.
        self.elapsed = 0.0
        # Timestamp of the step currently in flight, or None when no step is being timed.
        self.step_start_time = None

    def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Determine the theoretical peak TFLOPS used as the MFU denominator.

        Expects ``kwargs['model']`` to expose a ``dtype`` attribute when ``DEVICE_TFLOPS``
        is not set (the benchmark is run in that dtype).
        """
        from swift.utils import get_current_device, get_device_count, get_env_args

        # Top priority: value specified through the environment.
        # Parsed as float so that fractional values such as `989.5` are accepted.
        tflops = get_env_args('DEVICE_TFLOPS', float, None)
        device_count = max(get_device_count(), 1)
        if tflops is not None:
            logger.info(f"Specify theoretical max TFLOPS through ENV 'DEVICE_TFLOPS'. [{tflops} TFLOPS]")
        else:
            # Fall back to an on-device estimate.
            dtype = kwargs.get('model').dtype
            device = torch.device(get_current_device())
            logger.info(f'Estimating device TFLOPS baseline. Device: [{device}] dtype: [{dtype}]')
            tflops = self._estimate_device_tflops_by_dtype(device, dtype)
            logger.info(f'Estimate test finished. [{tflops} TFLOPS] Device count: [{device_count}]')
            # TODO Collect comprehensive TFLOPS data. Then provide a fallback strategy based on lookup tables.

        self.device_tflops = tflops * device_count

    def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Record the start time of the current training step."""
        # perf_counter is monotonic, so the measured interval is immune to system clock adjustments.
        self.step_start_time = time.perf_counter()

    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Add the duration of the step that just finished to the accumulated elapsed time."""
        if self.step_start_time is None:
            # No matching `on_step_begin` was seen; nothing to measure.
            return
        self.elapsed += time.perf_counter() - self.step_start_time
        self.step_start_time = None

    def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, logs=None, **kwargs):
        """Insert the average MFU since training start into ``logs`` under the key ``'MFU'``.

        Does nothing when ``logs`` is None, when no step time has been accumulated yet,
        or when the peak TFLOPS is unknown or zero, since MFU is undefined in those cases.
        """
        if logs is None or self.elapsed <= 0 or not self.device_tflops:
            return
        total_flos = getattr(state, 'total_flos', 0)
        actual_flops = total_flos / self.elapsed
        theoretical_max_flops = self.device_tflops * 1e12
        mfu = actual_flops / theoretical_max_flops
        logger.debug(f'Total_flos[{total_flos}] elapsed_time[{self.elapsed}]sec Average MFU[{mfu}]')
        logs['MFU'] = round(mfu, 6)

    @staticmethod
    def _estimate_device_tflops_by_dtype(device: torch.device, dtype: torch.dtype, repeats: int = 60, dim: int = 8192):
        """Estimate achievable TFLOPS of ``device`` by timing square matmuls in ``dtype``.

        Args:
            device: Device on which the benchmark matrices are allocated.
            dtype: Data type of the benchmark matrices.
            repeats: Number of timed matmuls in the first pass (values below 1 are treated as 1).
            dim: Side length of the square matrices.

        Returns:
            Measured throughput in TFLOPS, counting ``2 * dim**3`` FLOPs per matmul.
        """
        from swift.utils.torch_utils import empty_cache

        backend = device.type
        if backend == 'npu':
            # Imported for its side effect only; presumably this registers `torch.npu` — TODO confirm.
            import torch_npu  # noqa: F401

        def device_synchronize(sync_device):
            # Wait for queued kernels so the timer covers the actual compute.
            # Backends other than cuda/npu/cpu are not synchronized.
            if backend == 'cuda':
                torch.cuda.synchronize(sync_device)
            elif backend == 'npu':
                torch.npu.synchronize(sync_device)
            elif backend == 'cpu':
                torch.cpu.synchronize(sync_device)

        def run_matmuls(lhs, rhs, count):
            """Run ``count`` matmuls followed by one synchronize; return the elapsed seconds."""
            begin = time.perf_counter()
            for _ in range(count):
                torch.matmul(lhs, rhs)
            device_synchronize(device)
            return time.perf_counter() - begin

        # Initialize matrices
        shape = (dim, dim)
        a = torch.randn(*shape, device=device, dtype=dtype)
        b = torch.randn(*shape, device=device, dtype=dtype)

        # Warm-up
        run_matmuls(a, b, 5)

        # Run benchmark test
        repeats = max(repeats, 1)
        total_time = run_matmuls(a, b, repeats)
        avg_time = total_time / repeats

        # Adjust repeat count and retest if test duration is too short (aim for about 6 seconds).
        if total_time < 3:
            repeats = max(int(6 / avg_time), 1)
            total_time = run_matmuls(a, b, repeats)
            avg_time = total_time / repeats

        del a, b
        empty_cache()

        tflops = (2 * dim**3 / avg_time) / 1e12
        logger.info(f'[Device {device}] Total time: {total_time:.4f}s, dtype: {dtype}, Perf: {tflops:.4f} TFLOPS')
        return tflops

    @staticmethod
    def _retrieve_flops_from_map(device):
        """Retrieve theoretical FLOPS from `device_flops_map`.

        Args:
            device: Object exposing ``get_device_name()``.
                NOTE(review): presumably a backend module such as ``torch.cuda`` rather than
                a ``torch.device`` instance — confirm against the intended caller.

        Returns:
            The value of the first map entry whose key is a substring of the device name,
            or None when no key matches.
        """
        device_name = device.get_device_name()
        return next((value for name, value in device_flops_map.items() if name in device_name), None)
141+
142+
143+
# Theoretical peak throughput in FLOPS, keyed by a substring of the device name.
# `_retrieve_flops_from_map` returns the first key that is contained in the device name,
# so a longer key must be listed before any shorter key it contains
# (e.g. 'GB200' before 'B200', 'H200' before 'H20', 'L40S' before 'L40').
# NOTE(review): values look like dense (non-sparse) 16-bit tensor-core peaks — confirm against vendor datasheets.
device_flops_map = {
    'GB200': 2.5e15,
    'B200': 2.25e15,
    'MI300X': 1336e12,
    # H100/H800 share the H200 compute peak; 312e12 is the A100-class figure.
    'H100': 989e12,
    'H800': 989e12,
    'H200': 989e12,
    'A100': 312e12,
    'A800': 312e12,
    'L40S': 362.05e12,
    'L40': 181.05e12,
    'A40': 149.7e12,
    'L20': 119.5e12,
    'H20': 148e12,
    '910B': 354e12,
    'Ascend910': 354e12,
    'RTX 3070 Ti': 21.75e12,
}
161+
31162
extra_callbacks = []
32163
# A simple example of the EarlyStop callback; uncomment the line below to use it.
33164
# extra_callbacks = [EarlyStopCallback()]

0 commit comments

Comments
 (0)