|
1 | 1 | # Copyright (c) Alibaba, Inc. and its affiliates. |
| 2 | +import time |
| 3 | + |
2 | 4 | import numpy as np |
| 5 | +import torch |
3 | 6 | from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments |
4 | 7 |
|
5 | 8 | from swift.utils import get_logger |
@@ -28,6 +31,134 @@ def on_save(self, args: TrainingArguments, state: TrainerState, control: Trainer |
28 | 31 | control.should_training_stop = True |
29 | 32 |
|
30 | 33 |
|
| 34 | +class PerfMetricsLogCallback(TrainerCallback): |
| 35 | + """An callback for perf metrics (MFU etc) log implementation""" |
| 36 | + |
| 37 | + def __init__(self): |
| 38 | + self.device_tflops = None |
| 39 | + self.elapsed = 0.0 |
| 40 | + self.step_start_time = None |
| 41 | + |
| 42 | + def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): |
| 43 | + from swift.utils import get_current_device, get_device_count, get_env_args |
| 44 | + |
| 45 | + # Top priority. Specify by ENV |
| 46 | + tflops = get_env_args('DEVICE_TFLOPS', int, None) |
| 47 | + device_count = max(get_device_count(), 1) |
| 48 | + if tflops is not None: |
| 49 | + logger.info(f"Specify theoretical max TFLOPS through ENV 'DEVICE_TFLOPS'. [{tflops} TFLOPS]") |
| 50 | + else: |
| 51 | + # Run a estimating test. |
| 52 | + dtype = kwargs.get('model').dtype |
| 53 | + device = torch.device(get_current_device()) |
| 54 | + logger.info(f'Estimating device TFLOPS baseline. Device: [{device}] dtype: [{dtype}]') |
| 55 | + tflops = self._estimate_device_tflops_by_dtype(device, dtype) |
| 56 | + logger.info(f'Estimate test finished. [{tflops} TFLOPS] Device count: [{device_count}]') |
| 57 | + # TODO Collect comprehensive TFLOPS data. Then provide a fallback strategy based on lookup tables. |
| 58 | + |
| 59 | + self.device_tflops = tflops * device_count |
| 60 | + |
| 61 | + def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): |
| 62 | + self.step_start_time = time.time() |
| 63 | + |
| 64 | + def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): |
| 65 | + self.elapsed += time.time() - self.step_start_time |
| 66 | + |
| 67 | + def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, logs=None, **kwargs): |
| 68 | + total_flos = getattr(state, 'total_flos', 0) |
| 69 | + actual_flops = total_flos / self.elapsed |
| 70 | + theoretical_max_flops = self.device_tflops * 1e12 |
| 71 | + mfu = actual_flops / theoretical_max_flops |
| 72 | + logger.debug(f'Total_flos[{total_flos}] elapsed_time[{self.elapsed}]sec Average MFU[{mfu}]') |
| 73 | + logs['MFU'] = round(mfu, 6) |
| 74 | + |
| 75 | + @staticmethod |
| 76 | + def _estimate_device_tflops_by_dtype(device: torch.device, dtype: torch.dtype, repeats: int = 60, dim: int = 8192): |
| 77 | + from swift.utils.torch_utils import empty_cache |
| 78 | + |
| 79 | + def device_synchronize(sync_device): |
| 80 | + if backend == 'cuda': |
| 81 | + torch.cuda.synchronize(sync_device) |
| 82 | + elif backend == 'npu': |
| 83 | + torch.npu.synchronize(sync_device) |
| 84 | + elif backend == 'cpu': |
| 85 | + torch.cpu.synchronize(sync_device) |
| 86 | + |
| 87 | + # Set matrix dimension |
| 88 | + shape = (dim, dim) |
| 89 | + backend = device.type |
| 90 | + if backend == 'npu': |
| 91 | + import torch_npu |
| 92 | + |
| 93 | + # Initialize matrices |
| 94 | + a = torch.randn(*shape, device=device, dtype=dtype) |
| 95 | + b = torch.randn(*shape, device=device, dtype=dtype) |
| 96 | + |
| 97 | + # Warm-up |
| 98 | + for _ in range(5): |
| 99 | + c = torch.matmul(a, b) |
| 100 | + device_synchronize(device) |
| 101 | + |
| 102 | + # Run benchmark test |
| 103 | + start = time.time() |
| 104 | + for _ in range(repeats): |
| 105 | + c = torch.matmul(a, b) |
| 106 | + device_synchronize(device) |
| 107 | + end = time.time() |
| 108 | + total_time = end - start |
| 109 | + avg_time = total_time / repeats |
| 110 | + |
| 111 | + # Adjust repeat count and retest if test duration is too short |
| 112 | + if total_time < 3: |
| 113 | + repeats = int(6 / avg_time) |
| 114 | + start = time.time() |
| 115 | + for _ in range(repeats): |
| 116 | + c = torch.matmul(a, b) |
| 117 | + device_synchronize(device) |
| 118 | + end = time.time() |
| 119 | + total_time = end - start |
| 120 | + avg_time = total_time / repeats |
| 121 | + |
| 122 | + del a, b, c |
| 123 | + empty_cache() |
| 124 | + |
| 125 | + tflops = (2 * dim**3 / avg_time) / 1e12 |
| 126 | + logger.info(f'[Device {device}] Total time: {total_time:.4f}s, dtype: {dtype}, Perf: {tflops:.4f} TFLOPS') |
| 127 | + return tflops |
| 128 | + |
| 129 | + @staticmethod |
| 130 | + def _retrieve_flops_from_map(device): |
| 131 | + """Retrieve theoretical FLOPS from Map. """ |
| 132 | + |
| 133 | + device_name = device.get_device_name() |
| 134 | + flops = None |
| 135 | + for name, value in device_flops_map.items(): |
| 136 | + if name in device_name: |
| 137 | + flops = value |
| 138 | + break |
| 139 | + |
| 140 | + return flops |
| 141 | + |
| 142 | + |
| 143 | +device_flops_map = { |
| 144 | + 'GB200': 2.5e15, |
| 145 | + 'B200': 2.25e15, |
| 146 | + 'MI300X': 1336e12, |
| 147 | + 'H100': 312e12, |
| 148 | + 'H800': 312e12, |
| 149 | + 'H200': 989e12, |
| 150 | + 'A100': 312e12, |
| 151 | + 'A800': 312e12, |
| 152 | + 'L40S': 362.05e12, |
| 153 | + 'L40': 181.05e12, |
| 154 | + 'A40': 149.7e12, |
| 155 | + 'L20': 119.5e12, |
| 156 | + 'H20': 148e12, |
| 157 | + '910B': 354e12, |
| 158 | + 'Ascend910': 354e12, |
| 159 | + 'RTX 3070 Ti': 21.75e12 |
| 160 | +} |
| 161 | + |
31 | 162 | extra_callbacks = [] |
32 | 163 | # This example shows a simple example of EarlyStop Callback, uncomment this to use |
33 | 164 | # extra_callbacks = [EarlyStopCallback()] |
0 commit comments