@@ -218,48 +218,7 @@ def process_torch_outputs(obj):
             write_to_log("paddle_error", self.api_config.config)
             raise

-        if self.api_config.api_name == "paddle.incubate.nn.functional.fused_rms_norm":
-            paddle_output = paddle_output[0]
-        elif self.api_config.api_name == "paddle.unique":
-            if "return_index=True" in self.api_config.config:
-                paddle_output = list(paddle_output)
-                paddle_output.pop(1)
-                paddle_output = tuple(paddle_output)
-        elif self.api_config.api_name in {
-            "paddle.mode",
-            "paddle.Tensor.mode",
-            "paddle.incubate.nn.functional.fused_layer_norm",
-            "paddle.kthvalue",
-            "paddle.Tensor.kthvalue",
-            "paddle.topk",
-        }:
-            paddle_output = paddle_output[0]
-            torch_output = torch_output[0]
-        elif self.api_config.api_name in {
-            "paddle.strided_slice",
-            "paddle.vander",
-        } and any(s < 0 for s in paddle_output.strides):
-            # torch's from_dlpack now don't support negative strides
-            paddle_output = paddle_output.contiguous()
-        elif self.api_config.api_name == "paddle.linalg.eigh":
-            # The output of eigen vectors are not unique, because multiplying an eigen vector by -1 in the real case
-            # or by e^(i*\theta) in the complex case produces another set of valid eigen vectors of the matrix.
-            # So we test whether the elements of each coef_vector (i.e. paddle_output / torch_output for each eigen vector)
-            # are all the same and whether the |coef| == 1 for simplicity.
-            paddle_output, torch_output = list(paddle_output), list(torch_output)
-            eigvector_len = paddle_output[1].shape[-2]
-            paddle_eigvectors = paddle_output.pop(1).matrix_transpose().reshape([-1, eigvector_len])
-            torch_eigvectors = torch_output.pop(1).transpose(-1, -2).reshape((-1, eigvector_len))
-            paddle_output, torch_output = [], []
-            for i in range(paddle_eigvectors.shape[0]):
-                coef_vector = paddle.to_tensor(paddle_eigvectors[i].numpy() / torch_eigvectors[i].numpy(), dtype=paddle_eigvectors[i].dtype)
-                coef_vector = coef_vector.round(2)
-                coef_0 = paddle_eigvectors[i].numpy()[0] / torch_eigvectors[i].numpy()[0]
-                coef_vector_approx = torch.tensor([coef_0] * eigvector_len)
-                abs_coef = coef_vector.abs().astype("float64")[0]
-                one = torch.tensor(1.0, dtype=torch.float64)
-                paddle_output.append([coef_vector, abs_coef])
-                torch_output.append([coef_vector_approx, one])
+        paddle_output, torch_output = process_output(self.api_config, paddle_output, torch_output)

         self.is_backward = False
         def compare_paddle_and_torch(paddle_tensor, torch_tensor) -> bool:
@@ -369,48 +328,7 @@ def compare_paddle_and_torch(paddle_tensor, torch_tensor) -> bool:
             write_to_log("paddle_error", self.api_config.config)
             raise

-        if self.api_config.api_name == "paddle.Tensor.__setitem__":
-            torch_out_grads = torch_out_grads[0]
-            paddle_out_grads = paddle_out_grads[0]
-
-        # All configs that not compared with torch should be copied
-        # to tester/api_config/5_accuracy/accuracy_gpu_error_grads_diff.txt
-        if self.api_config.api_name in {
-            "paddle.nn.functional.scaled_dot_product_attention",
-        }:
-            paddle_out_grads = paddle_out_grads[:3]
-            torch_out_grads = torch_out_grads[:3]
-        elif self.api_config.api_name in {
-            "paddle.lerp",
-            "paddle.tensordot",
-        }:
-            paddle_out_grads = paddle_out_grads[:2]
-            torch_out_grads = torch_out_grads[:2]
-        elif self.api_config.api_name in {
-            "paddle.Tensor.fill_diagonal_tensor",
-            "paddle.diagonal_scatter",
-            "paddle.incubate.softmax_mask_fuse",
-            "paddle.nn.functional.binary_cross_entropy",
-            "paddle.nn.functional.binary_cross_entropy_with_logits",
-            "paddle.nn.functional.cross_entropy",
-            "paddle.nn.functional.sigmoid_focal_loss",
-            "paddle.nn.functional.gaussian_nll_loss",
-            "paddle.nn.functional.kl_div",
-            "paddle.scale",
-        }:
-            paddle_out_grads = paddle_out_grads[:1]
-            torch_out_grads = torch_out_grads[:1]
-        elif self.api_config.api_name in {
-            "paddle.combinations",
-            "paddle.nn.utils.parameters_to_vector",
-            "paddle.cdist",
-        }:
-            paddle_out_grads = []
-            torch_out_grads = []
-        elif self.api_config.api_name == "paddle.linalg.cholesky_solve":
-            from .base import get_arg
-            is_upper = get_arg(self.api_config, 2, 'upper', default=False)
-            torch_out_grads[1] = torch.triu(torch_out_grads[1]) if is_upper else torch.tril(torch_out_grads[1])
+        paddle_out_grads, torch_out_grads = process_grad_output(self.api_config, paddle_out_grads, torch_out_grads)

         if isinstance(paddle_out_grads, paddle.Tensor):
             if isinstance(torch_out_grads, torch.Tensor):
@@ -447,3 +365,106 @@ def compare_paddle_and_torch(paddle_tensor, torch_tensor) -> bool:

         print("[Pass]", self.api_config.config, flush=True)
         write_to_log("pass", self.api_config.config)
+
+
+def process_output(api_config, paddle_output, torch_output):
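+    """Normalize paddle and torch forward outputs into comparable forms.
+
+    For APIs whose outputs differ in structure across frameworks (extra index
+    outputs, negative strides, non-unique eigenvectors), trim or rewrite both
+    sides so the generic comparison loop can treat them uniformly.
+    """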
+    if api_config.api_name == "paddle.unique":
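+        # paddle.unique's return_index output has no direct torch counterpart
+        # (torch.unique exposes return_inverse/return_counts instead), so it
+        # is dropped before comparison.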
372+ if "return_index=True" in api_config .config :
373+ paddle_output = list (paddle_output )
374+ paddle_output .pop (1 )
+    elif api_config.api_name in {
+        "paddle.mode",
+        "paddle.Tensor.mode",
+        "paddle.incubate.nn.functional.fused_layer_norm",
+        "paddle.incubate.nn.functional.fused_rms_norm",
+        "paddle.kthvalue",
+        "paddle.Tensor.kthvalue",
+        "paddle.topk",
+    }:
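+        # These APIs return multiple values (e.g. values plus indices); only
+        # the first output is compared against torch.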
+        paddle_output = paddle_output[:1]
+        torch_output = torch_output[:1]
+    elif api_config.api_name in {
+        "paddle.strided_slice",
+        "paddle.vander",
+    }:
+        if any(s < 0 for s in paddle_output.strides):
+            # torch's from_dlpack does not currently support negative strides
+            paddle_output = paddle_output.contiguous()
+    elif api_config.api_name == "paddle.linalg.eigh":
+        # Eigenvectors are not unique: multiplying an eigenvector by -1 in the
+        # real case, or by e^(i*\theta) in the complex case, produces another
+        # valid set of eigenvectors of the matrix. So instead of comparing the
+        # vectors directly, we check that the elements of each coef_vector
+        # (i.e. paddle_output / torch_output for each eigenvector) are all the
+        # same and that |coef| == 1.
+        paddle_output, torch_output = list(paddle_output), list(torch_output)
+        eigvector_len = paddle_output[1].shape[-2]
+        paddle_eigvectors = (
+            paddle_output.pop(1).matrix_transpose().reshape([-1, eigvector_len])
+        )
+        torch_eigvectors = (
+            torch_output.pop(1).transpose(-1, -2).reshape((-1, eigvector_len))
+        )
+        paddle_output, torch_output = [], []
+        for i in range(paddle_eigvectors.shape[0]):
+            coef_vector = paddle.to_tensor(
+                paddle_eigvectors[i].numpy() / torch_eigvectors[i].numpy(),
+                dtype=paddle_eigvectors[i].dtype,
+            )
+            coef_vector = coef_vector.round(2)
+            coef_0 = paddle_eigvectors[i].numpy()[0] / torch_eigvectors[i].numpy()[0]
+            coef_vector_approx = torch.tensor([coef_0] * eigvector_len)
+            abs_coef = coef_vector.abs().astype("float64")[0]
+            one = torch.tensor(1.0, dtype=torch.float64)
+            paddle_output.append([coef_vector, abs_coef])
+            torch_output.append([coef_vector_approx, one])
+    return paddle_output, torch_output
+
+
+def process_grad_output(api_config, paddle_out_grads, torch_out_grads):
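+    """Trim gradient lists for APIs where only a subset of the input grads
+    (or none at all) is comparable between paddle and torch.
+    """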
+    # All configs that are not compared with torch should be copied to
+    # tester/api_config/5_accuracy/accuracy_gpu_error_grads_diff.txt
+    if api_config.api_name in {
+        "paddle.nn.functional.scaled_dot_product_attention",
+    }:
+        paddle_out_grads = paddle_out_grads[:3]
+        torch_out_grads = torch_out_grads[:3]
+    elif api_config.api_name in {
+        "paddle.lerp",
+        "paddle.tensordot",
+    }:
+        paddle_out_grads = paddle_out_grads[:2]
+        torch_out_grads = torch_out_grads[:2]
+    elif api_config.api_name in {
+        "paddle.Tensor.__setitem__",
+        "paddle.Tensor.fill_diagonal_tensor",
+        "paddle.diagonal_scatter",
+        "paddle.incubate.softmax_mask_fuse",
+        "paddle.nn.functional.binary_cross_entropy",
+        "paddle.nn.functional.binary_cross_entropy_with_logits",
+        "paddle.nn.functional.cross_entropy",
+        "paddle.nn.functional.gaussian_nll_loss",
+        "paddle.nn.functional.kl_div",
+        "paddle.nn.functional.sigmoid_focal_loss",
+        "paddle.scale",
+    }:
+        paddle_out_grads = paddle_out_grads[:1]
+        torch_out_grads = torch_out_grads[:1]
+    elif api_config.api_name in {
+        "paddle.combinations",
+        "paddle.nn.utils.parameters_to_vector",
+        "paddle.cdist",
+    }:
+        paddle_out_grads = []
+        torch_out_grads = []
+    elif api_config.api_name == "paddle.linalg.cholesky_solve":
+        # cholesky_solve's third argument is `upper` (default False); read it
+        # from the positional args or the kwargs of the config.
+        if len(api_config.args) > 2:
+            is_upper = api_config.args[2]
+        elif "upper" in api_config.kwargs:
+            is_upper = api_config.kwargs["upper"]
+        else:
+            is_upper = False
+        torch_out_grads[1] = (
+            torch.triu(torch_out_grads[1]) if is_upper else torch.tril(torch_out_grads[1])
+        )
+    return paddle_out_grads, torch_out_grads