135 changes: 133 additions & 2 deletions swift/plugin/callback.py
@@ -1,5 +1,8 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import time

import numpy as np
import torch
from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments

from swift.utils import get_logger
@@ -28,6 +31,134 @@ def on_save(self, args: TrainingArguments, state: TrainerState, control: Trainer
control.should_training_stop = True


extra_callbacks = []
class PerfMetricsLogCallback(TrainerCallback):
"""An callback for perf metrics (MFU etc) log implementation"""

def __init__(self):
self.device_tflops = None
self.elapsed = 0.0
self.step_start_time = None

def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
from swift.utils import get_current_device, get_device_count, get_env_args

# Top priority. Specify by ENV
tflops = get_env_args('DEVICE_TFLOPS', int, None)
Contributor (medium):
DEVICE_TFLOPS is parsed as an integer, which might be too restrictive as TFLOPS values are often floating-point numbers (e.g., from the estimation or lookup table). Using float would be more appropriate and consistent.
Suggested change:
tflops = get_env_args('DEVICE_TFLOPS', int, None)
tflops = get_env_args('DEVICE_TFLOPS', float, None)

Collaborator:
It seems float is correct.
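A brief usage sketch (the value is hypothetical; this assumes get_env_args reads the variable from the process environment, as the "Specify by ENV" comment above indicates):

import os

# Set the per-device theoretical peak before training starts so that
# on_init_end uses this value instead of running the benchmark.
os.environ['DEVICE_TFLOPS'] = '312.0'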

device_count = max(get_device_count(), 1)
if tflops is not None:
logger.info(f"Specify theoretical max TFLOPS through ENV 'DEVICE_TFLOPS'. [{tflops} TFLOPS]")
else:
# Run an estimating test.
dtype = kwargs.get('model').dtype
device = torch.device(get_current_device())
logger.info(f'Estimating device TFLOPS baseline. Device: [{device}] dtype: [{dtype}]')
tflops = self._estimate_device_tflops_by_dtype(device, dtype)
logger.info(f'Estimate test finished. [{tflops} TFLOPS] Device count: [{device_count}]')
Contributor commented on lines +50 to +56 (medium):
The _retrieve_flops_from_map function provides an efficient way to get device TFLOPS from a lookup table. It should be called as a fallback before running the performance estimation test. This avoids running a potentially time-consuming benchmark if the device's performance is already known.
Suggested change:
else:
# Run an estimating test.
dtype = kwargs.get('model').dtype
device = torch.device(get_current_device())
logger.info(f'Estimating device TFLOPS baseline. Device: [{device}] dtype: [{dtype}]')
tflops = self._estimate_device_tflops_by_dtype(device, dtype)
logger.info(f'Estimate test finished. [{tflops} TFLOPS] Device count: [{device_count}]')
else:
device = torch.device(get_current_device())
tflops = self._retrieve_flops_from_map(device)
if tflops is not None:
device_name = torch.cuda.get_device_name(device) if device.type == 'cuda' else str(device)
logger.info(f'Retrieved TFLOPS from lookup table for {device_name}: {tflops} TFLOPS')
else:
# Run an estimating test.
dtype = kwargs.get('model').dtype
logger.info(f'Estimating device TFLOPS baseline. Device: [{device}] dtype: [{dtype}]')
tflops = self._estimate_device_tflops_by_dtype(device, dtype)
logger.info(f'Estimate test finished. [{tflops} TFLOPS] Device count: [{device_count}]')

# TODO Collect comprehensive TFLOPS data. Then provide a fallback strategy based on lookup tables.

self.device_tflops = tflops * device_count

def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
self.step_start_time = time.time()

def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
self.elapsed += time.time() - self.step_start_time

def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, logs=None, **kwargs):
total_flos = getattr(state, 'total_flos', 0)
actual_flops = total_flos / self.elapsed
Contributor (medium):
There is a potential ZeroDivisionError here if self.elapsed is 0. This could happen if on_log is called before on_step_end has completed. It's safer to add a guard for this case.
Suggested change:
actual_flops = total_flos / self.elapsed
actual_flops = total_flos / self.elapsed if self.elapsed > 0 else 0

theoretical_max_flops = self.device_tflops * 1e12
mfu = actual_flops / theoretical_max_flops
logger.debug(f'Total_flos[{total_flos}] elapsed_time[{self.elapsed}]sec Average MFU[{mfu}]')
logs['MFU'] = round(mfu, 6)
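# Worked example with hypothetical numbers: total_flos = 1.2e18, elapsed = 600 s and
# device_tflops = 8 * 312 = 2496 give actual_flops = 2e15 and MFU ≈ 2e15 / 2.496e15 ≈ 0.80.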

@staticmethod
def _estimate_device_tflops_by_dtype(device: torch.device, dtype: torch.dtype, repeats: int = 60, dim: int = 8192):
from swift.utils.torch_utils import empty_cache

def device_synchronize(sync_device):
if backend == 'cuda':
torch.cuda.synchronize(sync_device)
elif backend == 'npu':
torch.npu.synchronize(sync_device)
elif backend == 'cpu':
torch.cpu.synchronize(sync_device)
Contributor commented on lines +84 to +85 (critical):
The helper function device_synchronize has a bug when the backend is 'cpu'. torch.cpu.synchronize does not exist in PyTorch, and calling it will cause an AttributeError. CPU operations are typically synchronous, so an explicit synchronization is not needed.
Suggested change:
elif backend == 'cpu':
torch.cpu.synchronize(sync_device)
elif backend == 'cpu':
pass # CPU operations are synchronous

Contributor (author):
The function torch.cpu.synchronize exists and has been tested.
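If compatibility with PyTorch builds that lack torch.cpu.synchronize is a concern, a minimal version-tolerant sketch of the helper could resolve the backend's synchronize function dynamically and treat a missing one as a no-op:

def device_synchronize(sync_device):
    # Assumption: torch.cuda / torch.npu / torch.cpu, when available, expose synchronize().
    backend_module = getattr(torch, sync_device.type, None)
    sync_fn = getattr(backend_module, 'synchronize', None)
    if callable(sync_fn):
        sync_fn(sync_device)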


# Default matrix size
shape = (dim, dim)
backend = device.type
if backend == 'npu':
import torch_npu

# Create the matrices
a = torch.randn(*shape, device=device, dtype=dtype)
b = torch.randn(*shape, device=device, dtype=dtype)

# Warm up
for _ in range(5):
c = torch.matmul(a, b)
device_synchronize(device)

# Run the benchmark
start = time.time()
for _ in range(repeats):
c = torch.matmul(a, b)
device_synchronize(device)
end = time.time()
total_time = end - start
avg_time = total_time / repeats

# If the test time is too short, increase the repeat count and rerun
if total_time < 3:
repeats = int(6 / avg_time)
start = time.time()
for _ in range(repeats):
c = torch.matmul(a, b)
device_synchronize(device)
end = time.time()
total_time = end - start
avg_time = total_time / repeats

del a, b, c
empty_cache()

tflops = (2 * dim**3 / avg_time) / 1e12
logger.info(f'[Device {device}] Total time: {total_time:.4f}s, dtype: {dtype}, Perf: {tflops:.4f} TFLOPS')
return tflops
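# Hypothetical quick check of the estimator outside of training (assumes a CUDA device):
# PerfMetricsLogCallback._estimate_device_tflops_by_dtype(torch.device('cuda:0'), torch.bfloat16)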

@staticmethod
def _retrieve_flops_from_map(device):
Collaborator:
I noticed this function is not being used. Just curious, what's the reason?

Contributor (author):
More data should be collected, and hopefully the function _retrieve_flops_from_map can become a more efficient and accurate way to obtain device TFLOPS. This avoids running a potentially time-consuming benchmark if the device's performance is already known.

"""Retrieve theoretical FLOPS from Map. """

device_name = device.get_device_name()
Collaborator:
What is the get_device_name function?

flops = None
for name, value in device_flops_map.items():
if name in device_name:
flops = value
break

return flops
Contributor commented on lines +130 to +140 (high):
This method has a few issues and is currently not used:

  1. device.get_device_name() will raise an AttributeError as torch.device objects do not have this method. You should use torch.cuda.get_device_name(device) for CUDA devices and handle other device types gracefully.
  2. The values in device_flops_map are in FLOPS, but this function should return TFLOPS to be consistent with other parts of the callback. The result should be divided by 1e12.
  3. This function is currently unused. It should be called in on_init_end to provide an efficient way to determine TFLOPS.

    def _retrieve_flops_from_map(device):
        """Retrieve theoretical FLOPS from Map and convert to TFLOPS."""
        if device.type != 'cuda':
            # Add other supported device types like 'npu' if needed
            return None

        device_name = torch.cuda.get_device_name(device)
        flops = None
        for name, value in device_flops_map.items():
            if name in device_name:
                flops = value
                break

        if flops is not None:
            return flops / 1e12
        return None



device_flops_map = {
'GB200': 2.5e15,
'B200': 2.25e15,
'MI300X': 1336e12,
'H100': 312e12,
'H800': 312e12,
'H200': 989e12,
'A100': 312e12,
'A800': 312e12,
'L40S': 362.05e12,
'L40': 181.05e12,
'A40': 149.7e12,
'L20': 119.5e12,
'H20': 148e12,
'910B': 354e12,
'Ascend910': 354e12,
'RTX 3070 Ti': 21.75e12
}

extra_callbacks = [PerfMetricsLogCallback()]
Collaborator:
Please do not enable by default.

Contributor (author):
Thanks for your comment. Is a switch argument required?
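One possible opt-in shape, sketched under the assumption of a hypothetical ENABLE_PERF_METRICS environment switch rather than a new CLI argument:

import os

# Register the perf-metrics callback only when explicitly requested.
if os.environ.get('ENABLE_PERF_METRICS', '0').lower() in ('1', 'true'):
    extra_callbacks = [PerfMetricsLogCallback()]
else:
    extra_callbacks = []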

# This is a simple example of an EarlyStop callback; uncomment the line below to use it
# extra_callbacks = [EarlyStopCallback()]