@@ -380,26 +380,31 @@ def _eval_models(
380380 m2 : DistributedModelParallel ,
381381 batch : ModelInput ,
382382 is_deterministic : bool = True ,
383+ tolerance : Optional [float ] = None ,
383384 ) -> None :
384385 with torch .no_grad ():
385386 loss1 , pred1 = m1 (batch )
386387 loss2 , pred2 = m2 (batch )
387-
388388 if is_deterministic :
389389 self .assertTrue (torch .equal (loss1 , loss2 ))
390390 self .assertTrue (torch .equal (pred1 , pred2 ))
391391 else :
392- rtol , atol = _get_default_rtol_and_atol (loss1 , loss2 )
393- torch .testing .assert_close (loss1 , loss2 , rtol = rtol , atol = atol )
394- rtol , atol = _get_default_rtol_and_atol (pred1 , pred2 )
395- torch .testing .assert_close (pred1 , pred2 , rtol = rtol , atol = atol )
392+ if tolerance :
393+ torch .testing .assert_close (loss1 , loss2 , rtol = tolerance , atol = tolerance )
394+ torch .testing .assert_close (pred1 , pred2 , rtol = tolerance , atol = tolerance )
395+ else :
396+ rtol , atol = _get_default_rtol_and_atol (loss1 , loss2 )
397+ torch .testing .assert_close (loss1 , loss2 , rtol = rtol , atol = atol )
398+ rtol , atol = _get_default_rtol_and_atol (pred1 , pred2 )
399+ torch .testing .assert_close (pred1 , pred2 , rtol = rtol , atol = atol )
396400
397401 def _compare_models (
398402 self ,
399403 m1 : DistributedModelParallel ,
400404 m2 : DistributedModelParallel ,
401405 is_deterministic : bool = True ,
402406 use_virtual_table : bool = False ,
407+ tolerance : Optional [float ] = None ,
403408 ) -> None :
404409 sd1 = m1 .state_dict ()
405410 sd2 = m2 .state_dict ()
@@ -437,7 +442,12 @@ def _compare_models(
437442 if is_deterministic :
438443 self .assertTrue (torch .allclose (src_tensor , dst_tensor ))
439444 else :
440- rtol , atol = _get_default_rtol_and_atol (src_tensor , dst_tensor )
445+ if tolerance :
446+ rtol , atol = tolerance , tolerance
447+ else :
448+ rtol , atol = _get_default_rtol_and_atol (
449+ src_tensor , dst_tensor
450+ )
441451 torch .testing .assert_close (
442452 src_tensor , dst_tensor , rtol = rtol , atol = atol
443453 )
@@ -453,7 +463,10 @@ def _compare_models(
453463 if is_deterministic :
454464 self .assertTrue (torch .equal (src , dst ))
455465 else :
456- rtol , atol = _get_default_rtol_and_atol (src , dst )
466+ if tolerance :
467+ rtol , atol = tolerance , tolerance
468+ else :
469+ rtol , atol = _get_default_rtol_and_atol (src , dst )
457470 torch .testing .assert_close (
458471 src ._local_tensor , dst ._local_tensor , rtol = rtol , atol = atol
459472 )
@@ -463,7 +476,10 @@ def _compare_models(
463476 if is_deterministic :
464477 self .assertTrue (torch .equal (src , dst ))
465478 else :
466- rtol , atol = _get_default_rtol_and_atol (src , dst )
479+ if tolerance :
480+ rtol , atol = tolerance , tolerance
481+ else :
482+ rtol , atol = _get_default_rtol_and_atol (src , dst )
467483 torch .testing .assert_close (src , dst , rtol = rtol , atol = atol )
468484
469485
0 commit comments