Skip to content

Commit 1152dbb

Browse files
Merge pull request #541 from cangtianhuang/develop
Update
2 parents 44e074c + c138dc0 commit 1152dbb

File tree

7 files changed

+189
-191
lines changed

7 files changed

+189
-191
lines changed

engineV2-README.md

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
```
3232
3. 确保 `engineV2.py`、`log_writer.py`、`run.sh` 路径正确
3333

34-
> [!CAUTION]
34+
> [!CAUTION]
> **CAUTION 1**
3535
> 目前 engineV2 仅支持 ***python>=3.10***,如报错 *`NameError: name 'torch' is not defined`*,请在 `run_test_case()` 函数首行手动添加导入语句:
3636
> ```python
3737
> import torch
@@ -46,6 +46,14 @@
4646
>
4747
> 由于 gpu 隔离的需求,主进程并没有导入 `torch` 和 `paddle`,所以当 `run_test_case` 在子进程中被反序列化并准备执行时,它找不到这些库,从而引发 `NameError`。
4848
49+
> [!CAUTION]
> **CAUTION 2**
50+
> 在更高 CUDA 版本下,比如 CUDA Version: 12.9,可能会出现以下报错:
51+
> ```bash
52+
> ImportError: /usr/local/cuda/lib64/libcusparse.so.12: undefined symbol: __nvJitLinkGetErrorLogSize_12_9, version libnvJitLink.so.12
53+
> ```
54+
>
55+
> 解决方案是在 `init_worker_gpu()` 函数中交换 `torch` 和 `paddle` 的导入顺序,将 `torch` 置于 `paddle` 之前,具体原因未知。
56+
4957
## 使用指南
5058
5159
### 命令行参数

tester/accuracy.py

Lines changed: 105 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -218,48 +218,7 @@ def process_torch_outputs(obj):
218218
write_to_log("paddle_error", self.api_config.config)
219219
raise
220220

221-
if self.api_config.api_name == "paddle.incubate.nn.functional.fused_rms_norm":
222-
paddle_output = paddle_output[0]
223-
elif self.api_config.api_name == "paddle.unique":
224-
if "return_index=True" in self.api_config.config:
225-
paddle_output = list(paddle_output)
226-
paddle_output.pop(1)
227-
paddle_output = tuple(paddle_output)
228-
elif self.api_config.api_name in {
229-
"paddle.mode",
230-
"paddle.Tensor.mode",
231-
"paddle.incubate.nn.functional.fused_layer_norm",
232-
"paddle.kthvalue",
233-
"paddle.Tensor.kthvalue",
234-
"paddle.topk",
235-
}:
236-
paddle_output = paddle_output[0]
237-
torch_output = torch_output[0]
238-
elif self.api_config.api_name in {
239-
"paddle.strided_slice",
240-
"paddle.vander",
241-
} and any(s < 0 for s in paddle_output.strides):
242-
# torch's from_dlpack now don't support negative strides
243-
paddle_output = paddle_output.contiguous()
244-
elif self.api_config.api_name == "paddle.linalg.eigh":
245-
# The output of eigen vectors are not unique, because multiplying an eigen vector by -1 in the real case
246-
# or by e^(i*\theta) in the complex case produces another set of valid eigen vectors of the matrix.
247-
# So we test whether the elements of each coef_vector (i.e. paddle_output / torch_output for each eigen vector)
248-
# are all the same and whether the |coef| == 1 for simplicity.
249-
paddle_output, torch_output = list(paddle_output), list(torch_output)
250-
eigvector_len = paddle_output[1].shape[-2]
251-
paddle_eigvectors = paddle_output.pop(1).matrix_transpose().reshape([-1, eigvector_len])
252-
torch_eigvectors = torch_output.pop(1).transpose(-1, -2).reshape((-1, eigvector_len))
253-
paddle_output, torch_output = [], []
254-
for i in range(paddle_eigvectors.shape[0]):
255-
coef_vector = paddle.to_tensor(paddle_eigvectors[i].numpy()/torch_eigvectors[i].numpy(), dtype=paddle_eigvectors[i].dtype)
256-
coef_vector = coef_vector.round(2)
257-
coef_0 = paddle_eigvectors[i].numpy()[0]/torch_eigvectors[i].numpy()[0]
258-
coef_vector_approx = torch.tensor([coef_0] * eigvector_len)
259-
abs_coef = coef_vector.abs().astype("float64")[0]
260-
one = torch.tensor(1.0, dtype=torch.float64)
261-
paddle_output.append([coef_vector, abs_coef])
262-
torch_output.append([coef_vector_approx, one])
221+
paddle_output, torch_output = process_output(self.api_config, paddle_output, torch_output)
263222

264223
self.is_backward = False
265224
def compare_paddle_and_torch(paddle_tensor, torch_tensor) -> bool:
@@ -369,48 +328,7 @@ def compare_paddle_and_torch(paddle_tensor, torch_tensor) -> bool:
369328
write_to_log("paddle_error", self.api_config.config)
370329
raise
371330

372-
if self.api_config.api_name == "paddle.Tensor.__setitem__":
373-
torch_out_grads = torch_out_grads[0]
374-
paddle_out_grads = paddle_out_grads[0]
375-
376-
# All configs that not compared with torch should be copied
377-
# to tester/api_config/5_accuracy/accuracy_gpu_error_grads_diff.txt
378-
if self.api_config.api_name in {
379-
"paddle.nn.functional.scaled_dot_product_attention",
380-
}:
381-
paddle_out_grads = paddle_out_grads[:3]
382-
torch_out_grads = torch_out_grads[:3]
383-
elif self.api_config.api_name in {
384-
"paddle.lerp",
385-
"paddle.tensordot",
386-
}:
387-
paddle_out_grads = paddle_out_grads[:2]
388-
torch_out_grads = torch_out_grads[:2]
389-
elif self.api_config.api_name in {
390-
"paddle.Tensor.fill_diagonal_tensor",
391-
"paddle.diagonal_scatter",
392-
"paddle.incubate.softmax_mask_fuse",
393-
"paddle.nn.functional.binary_cross_entropy",
394-
"paddle.nn.functional.binary_cross_entropy_with_logits",
395-
"paddle.nn.functional.cross_entropy",
396-
"paddle.nn.functional.sigmoid_focal_loss",
397-
"paddle.nn.functional.gaussian_nll_loss",
398-
"paddle.nn.functional.kl_div",
399-
"paddle.scale",
400-
}:
401-
paddle_out_grads = paddle_out_grads[:1]
402-
torch_out_grads = torch_out_grads[:1]
403-
elif self.api_config.api_name in {
404-
"paddle.combinations",
405-
"paddle.nn.utils.parameters_to_vector",
406-
"paddle.cdist",
407-
}:
408-
paddle_out_grads = []
409-
torch_out_grads = []
410-
elif self.api_config.api_name == "paddle.linalg.cholesky_solve":
411-
from .base import get_arg
412-
is_upper = get_arg(self.api_config, 2, 'upper', default=False)
413-
torch_out_grads[1] = torch.triu(torch_out_grads[1]) if is_upper else torch.tril(torch_out_grads[1])
331+
paddle_out_grads, torch_out_grads = process_grad_output(self.api_config, paddle_out_grads, torch_out_grads)
414332

415333
if isinstance(paddle_out_grads, paddle.Tensor):
416334
if isinstance(torch_out_grads, torch.Tensor):
@@ -447,3 +365,106 @@ def compare_paddle_and_torch(paddle_tensor, torch_tensor) -> bool:
447365

448366
print("[Pass]", self.api_config.config, flush=True)
449367
write_to_log("pass", self.api_config.config)
368+
369+
370+
def process_output(api_config, paddle_output, torch_output):
    """Normalize forward outputs of selected APIs before they are compared.

    Dispatches on ``api_config.api_name``; outputs of any API not listed
    below are returned untouched.

    Args:
        api_config: Object exposing ``api_name`` and, for ``paddle.unique``,
            the raw ``config`` string.
        paddle_output: Forward result produced by Paddle.
        torch_output: Forward result produced by Torch.

    Returns:
        A ``(paddle_output, torch_output)`` pair ready for comparison.
    """
    name = api_config.api_name

    if name == "paddle.unique":
        # With return_index=True, element 1 of the Paddle result is dropped
        # and the remainder is handed back as a list.
        if "return_index=True" not in api_config.config:
            return paddle_output, torch_output
        trimmed = list(paddle_output)
        del trimmed[1]
        return trimmed, torch_output

    # APIs for which only the leading element of each result is compared.
    # NOTE(review): torch_output is sliced here for every listed API,
    # including fused_rms_norm - confirm the torch side returns a sequence.
    first_only = (
        "paddle.mode",
        "paddle.Tensor.mode",
        "paddle.incubate.nn.functional.fused_layer_norm",
        "paddle.incubate.nn.functional.fused_rms_norm",
        "paddle.kthvalue",
        "paddle.Tensor.kthvalue",
        "paddle.topk",
    )
    if name in first_only:
        return paddle_output[:1], torch_output[:1]

    if name in ("paddle.strided_slice", "paddle.vander"):
        # torch's from_dlpack does not support negative strides yet, so the
        # Paddle tensor is made contiguous whenever one is present.
        has_negative_stride = any(
            stride < 0 for stride in paddle_output.strides
        )
        if has_negative_stride:
            paddle_output = paddle_output.contiguous()
        return paddle_output, torch_output

    if name != "paddle.linalg.eigh":
        return paddle_output, torch_output

    # Eigenvectors are not unique: multiplying one by -1 (real case) or by
    # e^(i*theta) (complex case) yields another valid eigenvector. Instead of
    # comparing them directly, compare the element-wise ratio
    # paddle / torch of every eigenvector against a constant vector, and
    # check that the ratio has magnitude 1.
    paddle_parts = list(paddle_output)
    torch_parts = list(torch_output)
    eigvector_len = paddle_parts[1].shape[-2]
    paddle_rows = (
        paddle_parts.pop(1).matrix_transpose().reshape([-1, eigvector_len])
    )
    torch_rows = (
        torch_parts.pop(1).transpose(-1, -2).reshape((-1, eigvector_len))
    )

    # Only the per-eigenvector pairs built below are returned; the remaining
    # entries of the original outputs are not part of the result.
    paddle_pairs = []
    torch_pairs = []
    for row in range(paddle_rows.shape[0]):
        paddle_vec = paddle_rows[row]
        paddle_np = paddle_vec.numpy()
        torch_np = torch_rows[row].numpy()

        ratio = paddle.to_tensor(paddle_np / torch_np, dtype=paddle_vec.dtype)
        ratio = ratio.round(2)
        # Expected ratio vector: the first ratio repeated for every element.
        first_ratio = paddle_np[0] / torch_np[0]
        expected_ratio = torch.tensor([first_ratio] * eigvector_len)
        ratio_magnitude = ratio.abs().astype("float64")[0]
        unit = torch.tensor(1.0, dtype=torch.float64)

        paddle_pairs.append([ratio, ratio_magnitude])
        torch_pairs.append([expected_ratio, unit])
    return paddle_pairs, torch_pairs
420+
421+
422+
def process_grad_output(api_config, paddle_out_grads, torch_out_grads):
    """Trim or adjust gradient outputs of selected APIs before comparison.

    Dispatches on ``api_config.api_name``. For the listed APIs only the
    leading one, two or three gradients are kept, or none at all. For
    ``paddle.linalg.cholesky_solve`` the second Torch gradient is masked to
    the triangle selected by the ``upper`` argument. Gradients of any other
    API are returned untouched.

    Args:
        api_config: Object exposing ``api_name`` and, for
            ``paddle.linalg.cholesky_solve``, positional ``args`` and
            keyword ``kwargs`` of the call.
        paddle_out_grads: Sequence of gradients produced by Paddle.
        torch_out_grads: Sequence of gradients produced by Torch. For
            ``cholesky_solve`` element 1 is replaced in place, so the
            sequence must support item assignment.

    Returns:
        A ``(paddle_out_grads, torch_out_grads)`` pair ready for comparison.
    """
    api_name = api_config.api_name

    # All configs whose gradients are not (fully) compared with torch should
    # be copied to
    # tester/api_config/5_accuracy/accuracy_gpu_error_grads_diff.txt
    if api_name in {
        "paddle.nn.functional.scaled_dot_product_attention",
    }:
        paddle_out_grads = paddle_out_grads[:3]
        torch_out_grads = torch_out_grads[:3]
    elif api_name in {
        "paddle.lerp",
        "paddle.tensordot",
    }:
        paddle_out_grads = paddle_out_grads[:2]
        torch_out_grads = torch_out_grads[:2]
    elif api_name in {
        "paddle.Tensor.__setitem__",
        "paddle.Tensor.fill_diagonal_tensor",
        "paddle.diagonal_scatter",
        "paddle.incubate.softmax_mask_fuse",
        "paddle.nn.functional.binary_cross_entropy",
        "paddle.nn.functional.binary_cross_entropy_with_logits",
        "paddle.nn.functional.cross_entropy",
        "paddle.nn.functional.gaussian_nll_loss",
        "paddle.nn.functional.kl_div",
        "paddle.nn.functional.sigmoid_focal_loss",
        "paddle.scale",
    }:
        paddle_out_grads = paddle_out_grads[:1]
        torch_out_grads = torch_out_grads[:1]
    elif api_name in {
        "paddle.combinations",
        "paddle.nn.utils.parameters_to_vector",
        "paddle.cdist",
    }:
        # No gradient of these APIs is compared.
        paddle_out_grads = []
        torch_out_grads = []
    elif api_name == "paddle.linalg.cholesky_solve":
        # The flag is the third positional argument or the ``upper`` keyword
        # (Paddle's signature is cholesky_solve(x, y, upper=False)). Looking
        # it up under any other key would silently ignore ``upper=True``.
        if len(api_config.args) > 2:
            is_upper = api_config.args[2]
        elif "upper" in api_config.kwargs:
            is_upper = api_config.kwargs["upper"]
        else:
            is_upper = False
        # Keep only the triangle of torch's gradient that matches the flag.
        # NOTE(review): presumably Paddle only yields a gradient on that
        # triangle - confirm against the Paddle kernel.
        factor_grad = torch_out_grads[1]
        torch_out_grads[1] = (
            torch.triu(factor_grad) if is_upper else torch.tril(factor_grad)
        )
    return paddle_out_grads, torch_out_grads

tester/accuracy_stable.py

Lines changed: 11 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from .api_config.log_writer import log_accuracy_stable, write_to_log
88
from .base import APITestBase
99
from .paddle_to_torch import get_converter
10+
from .accuracy import process_output, process_grad_output
1011

1112

1213
cuda_errors = frozenset(
@@ -94,10 +95,11 @@ def test(self):
9495
paddle.device.cuda.empty_cache()
9596

9697
# ======== format ========
97-
torch_output, torch_out_grads, paddle_output, paddle_out_grads = (
98-
self.format_output(
99-
torch_output, torch_out_grads, paddle_output, paddle_out_grads
100-
)
98+
paddle_output, torch_output = process_output(
99+
self.api_config, paddle_output, torch_output
100+
)
101+
paddle_out_grads, torch_out_grads = process_grad_output(
102+
self.api_config, paddle_out_grads, torch_out_grads
101103
)
102104

103105
# ======== add to pair ========
@@ -348,101 +350,6 @@ def process_paddle_outputs(obj):
348350
paddle_out_grads = process_paddle_outputs(paddle_out_grads)
349351
return paddle_output, paddle_out_grads
350352

351-
def format_output(
352-
self, paddle_output, torch_output, paddle_out_grads, torch_out_grads
353-
):
354-
# ======== format output ========
355-
if self.api_config.api_name == "paddle.incubate.nn.functional.fused_rms_norm":
356-
paddle_output = paddle_output[0]
357-
elif self.api_config.api_name == "paddle.unique":
358-
if "return_index=True" in self.api_config.config:
359-
paddle_output = list(paddle_output)
360-
paddle_output.pop(1)
361-
paddle_output = tuple(paddle_output)
362-
elif self.api_config.api_name in {
363-
"paddle.mode",
364-
"paddle.Tensor.mode",
365-
"paddle.incubate.nn.functional.fused_layer_norm",
366-
"paddle.kthvalue",
367-
}:
368-
paddle_output = paddle_output[0]
369-
torch_output = torch_output[0]
370-
elif self.api_config.api_name in {
371-
"paddle.strided_slice",
372-
"paddle.vander",
373-
} and any(s < 0 for s in paddle_output.strides):
374-
# torch's from_dlpack now don't support negative strides
375-
paddle_output = paddle_output.contiguous()
376-
elif self.api_config.api_name == "paddle.linalg.eigh":
377-
# The output of eigen vectors are not unique, because multiplying an eigen vector by -1 in the real case
378-
# or by e^(i*\theta) in the complex case produces another set of valid eigen vectors of the matrix.
379-
# So we test whether the elements of each coef_vector (i.e. paddle_output / torch_output for each eigen vector)
380-
# are all the same and whether the |coef| == 1 for simplicity.
381-
paddle_output, torch_output = list(paddle_output), list(torch_output)
382-
eigvector_len = paddle_output[1].shape[-2]
383-
paddle_eigvectors = (
384-
paddle_output.pop(1).matrix_transpose().reshape([-1, eigvector_len])
385-
)
386-
torch_eigvectors = (
387-
torch_output.pop(1).transpose(-1, -2).reshape((-1, eigvector_len))
388-
)
389-
paddle_output, torch_output = [], []
390-
for i in range(paddle_eigvectors.shape[0]):
391-
coef_vector = paddle.to_tensor(
392-
paddle_eigvectors[i].numpy() / torch_eigvectors[i].numpy(),
393-
dtype=paddle_eigvectors[i].dtype,
394-
)
395-
coef_vector = coef_vector.round(2)
396-
coef_0 = (
397-
paddle_eigvectors[i].numpy()[0] / torch_eigvectors[i].numpy()[0]
398-
)
399-
coef_vector_approx = torch.tensor([coef_0] * eigvector_len)
400-
abs_coef = coef_vector.abs().astype("float64")[0]
401-
one = torch.tensor(1.0, dtype=torch.float64)
402-
paddle_output.append([coef_vector, abs_coef])
403-
torch_output.append([coef_vector_approx, one])
404-
405-
# ======== format gradient ========
406-
if self.api_config.api_name == "paddle.Tensor.__setitem__":
407-
torch_out_grads = torch_out_grads[0]
408-
paddle_out_grads = paddle_out_grads[0]
409-
# All configs that not compared with torch should be copied
410-
# to tester/api_config/5_accuracy/accuracy_gpu_error_grads_diff.txt
411-
if self.api_config.api_name in {
412-
"paddle.nn.functional.scaled_dot_product_attention",
413-
}:
414-
paddle_out_grads = paddle_out_grads[:3]
415-
torch_out_grads = torch_out_grads[:3]
416-
elif self.api_config.api_name in {
417-
"paddle.lerp",
418-
"paddle.tensordot",
419-
}:
420-
paddle_out_grads = paddle_out_grads[:2]
421-
torch_out_grads = torch_out_grads[:2]
422-
elif self.api_config.api_name in {
423-
"paddle.Tensor.fill_diagonal_tensor",
424-
"paddle.diagonal_scatter",
425-
"paddle.incubate.softmax_mask_fuse",
426-
"paddle.nn.functional.binary_cross_entropy",
427-
"paddle.nn.functional.binary_cross_entropy_with_logits",
428-
"paddle.nn.functional.cross_entropy",
429-
"paddle.nn.functional.sigmoid_focal_loss",
430-
"paddle.nn.functional.gaussian_nll_loss",
431-
"paddle.nn.functional.kl_div",
432-
"paddle.scale",
433-
}:
434-
paddle_out_grads = paddle_out_grads[:1]
435-
torch_out_grads = torch_out_grads[:1]
436-
elif self.api_config.api_name in {
437-
"paddle.combinations",
438-
"paddle.nn.utils.parameters_to_vector",
439-
"paddle.cdist",
440-
}:
441-
paddle_out_grads = []
442-
torch_out_grads = []
443-
444-
return paddle_output, torch_output, paddle_out_grads, torch_out_grads
445-
446353
def compare(self, input1, input2, comp):
447354
if isinstance(input1, (paddle.Tensor, torch.Tensor)):
448355
if isinstance(input2, (paddle.Tensor, torch.Tensor)):
@@ -481,12 +388,12 @@ def compare(self, input1, input2, comp):
481388
)
482389
write_to_log("accuracy_error", self.api_config.config)
483390
return
484-
for item1, item2 in zip(input1, input2):
391+
for idx, (item1, item2) in enumerate(zip(input1, input2)):
485392
if isinstance(item1, (paddle.Tensor, torch.Tensor)) and isinstance(
486393
item2, (paddle.Tensor, torch.Tensor)
487394
):
488395
try:
489-
self.assert_accuracy(item1, item2, comp)
396+
self.assert_accuracy(item1, item2, comp, idx)
490397
except Exception as err:
491398
print(
492399
f"[{comp}] [accuracy error] {self.api_config.config}\n{str(err)}",
@@ -499,7 +406,7 @@ def compare(self, input1, input2, comp):
499406
) and not isinstance(item2, (paddle.Tensor, torch.Tensor)):
500407
try:
501408
self.assert_accuracy(
502-
torch.tensor(item1), torch.tensor(item2), comp
409+
torch.tensor(item1), torch.tensor(item2), comp, idx
503410
)
504411
except Exception as err:
505412
print(
@@ -527,7 +434,7 @@ def compare(self, input1, input2, comp):
527434
write_to_log("accuracy_error", self.api_config.config)
528435
return
529436

530-
def assert_accuracy(self, tensor1, tensor2, comp):
437+
def assert_accuracy(self, tensor1, tensor2, comp, idx=0):
531438
if not tensor1.is_contiguous():
532439
tensor1 = tensor1.contiguous()
533440
if not tensor2.is_contiguous():
@@ -619,5 +526,6 @@ def error_msg(msg):
619526
dtype,
620527
comp,
621528
)
529+
write_to_log("accuracy_diff", config)
622530
else:
623531
raise

0 commit comments

Comments
 (0)