import unittest

from transformers import DetrConfig, ResNetConfig, is_torch_available, is_vision_available
- from transformers.testing_utils import require_timm, require_torch, require_vision, slow, torch_device
+ from transformers.testing_utils import Expectations, require_timm, require_torch, require_vision, slow, torch_device
from transformers.utils import cached_property

from ...test_configuration_common import ConfigTester
@@ -585,13 +585,23 @@ def test_inference_no_head(self):
expected_shape = torch.Size((1, 100, 256))
assert outputs.last_hidden_state.shape == expected_shape
- expected_slice = torch.tensor(
-     [
-         [0.0622, -0.5142, -0.4034],
-         [-0.7628, -0.4935, -1.7153],
-         [-0.4751, -0.6386, -0.7818],
-     ]
- ).to(torch_device)
+ expected_slices = Expectations(
+     {
+         (None, None):
+             [
+                 [0.0622, -0.5142, -0.4034],
+                 [-0.7628, -0.4935, -1.7153],
+                 [-0.4751, -0.6386, -0.7818],
+             ],
+         ("rocm", (9, 5)):
+             [
+                 [ 0.0616, -0.5146, -0.4032],
+                 [-0.7629, -0.4934, -1.7153],
+                 [-0.4768, -0.6403, -0.7826],
+             ],
+     }
+ )  # fmt: skip
+ expected_slice = torch.tensor(expected_slices.get_expectation(), device=torch_device)
torch.testing.assert_close(outputs.last_hidden_state[0, :3, :3], expected_slice, rtol=2e-4, atol=2e-4)

def test_inference_object_detection_head(self):
@@ -609,13 +619,23 @@ def test_inference_object_detection_head(self):
# verify outputs
expected_shape_logits = torch.Size((1, model.config.num_queries, model.config.num_labels + 1))
self.assertEqual(outputs.logits.shape, expected_shape_logits)
- expected_slice_logits = torch.tensor(
-     [
-         [-19.1211, -0.0881, -11.0188],
-         [-17.3641, -1.8045, -14.0229],
-         [-20.0415, -0.5833, -11.1005],
-     ]
- ).to(torch_device)
+ expected_slices = Expectations(
+     {
+         (None, None):
+             [
+                 [-19.1211, -0.0881, -11.0188],
+                 [-17.3641, -1.8045, -14.0229],
+                 [-20.0415, -0.5833, -11.1005],
+             ],
+         ("rocm", (9, 5)):
+             [
+                 [-19.1194, -0.0893, -11.0154],
+                 [-17.3640, -1.8035, -14.0219],
+                 [-20.0461, -0.5837, -11.1060],
+             ],
+     }
+ )  # fmt: skip
+ expected_slice_logits = torch.tensor(expected_slices.get_expectation(), device=torch_device)
torch.testing.assert_close(outputs.logits[0, :3, :3], expected_slice_logits, rtol=2e-4, atol=2e-4)

expected_shape_boxes = torch.Size((1, model.config.num_queries, 4))
@@ -657,27 +677,65 @@ def test_inference_panoptic_segmentation_head(self):
# verify outputs
expected_shape_logits = torch.Size((1, model.config.num_queries, model.config.num_labels + 1))
self.assertEqual(outputs.logits.shape, expected_shape_logits)
- expected_slice_logits = torch.tensor(
-     [
-         [-18.1523, -1.7592, -13.5019],
-         [-16.8866, -1.4139, -14.1025],
-         [-17.5735, -2.5090, -11.8666],
-     ]
- ).to(torch_device)
+ expected_slices = Expectations(
+     {
+         (None, None):
+             [
+                 [-18.1523, -1.7592, -13.5019],
+                 [-16.8866, -1.4139, -14.1025],
+                 [-17.5735, -2.5090, -11.8666],
+             ],
+         ("rocm", (9, 5)):
+             [
+                 [-18.1565, -1.7568, -13.5029],
+                 [-16.8888, -1.4138, -14.1028],
+                 [-17.5709, -2.5080, -11.8654],
+             ],
+     }
+ )  # fmt: skip
+ expected_slice_logits = torch.tensor(expected_slices.get_expectation(), device=torch_device)
torch.testing.assert_close(outputs.logits[0, :3, :3], expected_slice_logits, rtol=2e-4, atol=2e-4)

expected_shape_boxes = torch.Size((1, model.config.num_queries, 4))
self.assertEqual(outputs.pred_boxes.shape, expected_shape_boxes)
- expected_slice_boxes = torch.tensor(
-     [[0.5344, 0.1790, 0.9284], [0.4421, 0.0571, 0.0875], [0.6632, 0.6886, 0.1015]]
- ).to(torch_device)
+ expected_slices = Expectations(
+     {
+         (None, None):
+             [
+                 [0.5344, 0.1790, 0.9284],
+                 [0.4421, 0.0571, 0.0875],
+                 [0.6632, 0.6886, 0.1015],
+             ],
+         ("rocm", (9, 5)):
+             [
+                 [0.5344, 0.1789, 0.9285],
+                 [0.4420, 0.0572, 0.0875],
+                 [0.6630, 0.6887, 0.1017],
+             ],
+     }
+ )  # fmt: skip
+ expected_slice_boxes = torch.tensor(expected_slices.get_expectation(), device=torch_device)
torch.testing.assert_close(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, rtol=2e-4, atol=2e-4)

expected_shape_masks = torch.Size((1, model.config.num_queries, 200, 267))
self.assertEqual(outputs.pred_masks.shape, expected_shape_masks)
- expected_slice_masks = torch.tensor(
-     [[-7.8408, -11.0104, -12.1279], [-12.0299, -16.6498, -17.9806], [-14.8995, -19.9940, -20.5646]]
- ).to(torch_device)
+ expected_slices = Expectations(
+     {
+         (None, None):
+             [
+                 [-7.8408, -11.0104, -12.1279],
+                 [-12.0299, -16.6498, -17.9806],
+                 [-14.8995, -19.9940, -20.5646],
+             ],
+         ("rocm", (9, 5)):
+             [
+                 [ -7.7558, -10.8789, -11.9798],
+                 [-11.8882, -16.4330, -17.7452],
+                 [-14.7317, -19.7384, -20.3005],
+             ],
+     }
+ )  # fmt: skip
+ expected_slice_masks = torch.tensor(expected_slices.get_expectation(), device=torch_device)
torch.testing.assert_close(outputs.pred_masks[0, 0, :3, :3], expected_slice_masks, rtol=2e-3, atol=2e-3)

# verify postprocessing
@@ -731,11 +789,21 @@ def test_inference_no_head(self):

expected_shape = torch.Size((1, 100, 256))
assert outputs.last_hidden_state.shape == expected_shape
- expected_slice = torch.tensor(
-     [
-         [0.0622, -0.5142, -0.4034],
-         [-0.7628, -0.4935, -1.7153],
-         [-0.4751, -0.6386, -0.7818],
-     ]
- ).to(torch_device)
+ expected_slices = Expectations(
+     {
+         (None, None):
+             [
+                 [0.0622, -0.5142, -0.4034],
+                 [-0.7628, -0.4935, -1.7153],
+                 [-0.4751, -0.6386, -0.7818],
+             ],
+         ("rocm", (9, 5)):
+             [
+                 [ 0.0616, -0.5146, -0.4032],
+                 [-0.7629, -0.4934, -1.7153],
+                 [-0.4768, -0.6403, -0.7826],
+             ],
+     }
+ )  # fmt: skip
+ expected_slice = torch.tensor(expected_slices.get_expectation(), device=torch_device)
torch.testing.assert_close(outputs.last_hidden_state[0, :3, :3], expected_slice, rtol=1e-4, atol=1e-4)
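
Note on the pattern used above: the diff replaces hard-coded expected tensors with the Expectations helper imported from transformers.testing_utils, so ROCm 9.5 runners can carry slightly different reference slices while other devices fall back to the (None, None) entry, and get_expectation() returns the values for the device the test is running on. The snippet below is only a minimal sketch of that lookup idea under the assumption that the most specific (device_type, version) key wins; SimpleExpectations and its explicit device_type/version arguments are hypothetical stand-ins, not the library's actual implementation, which resolves the device on its own.

# Minimal sketch of a device-keyed expectation lookup (illustrative only,
# not the real transformers.testing_utils.Expectations implementation).
from typing import Any, Dict, Optional, Tuple


class SimpleExpectations:
    def __init__(self, data: Dict[Tuple, Any]):
        self.data = data

    def get_expectation(
        self, device_type: Optional[str] = None, version: Optional[Tuple[int, int]] = None
    ) -> Any:
        # Prefer an exact (device_type, version) match, then a device-only match,
        # then the generic (None, None) default used by most devices.
        for key in ((device_type, version), (device_type, None), (None, None)):
            if key in self.data:
                return self.data[key]
        raise KeyError(f"no expectation registered for {device_type} {version}")


# Usage mirroring the structure added in this diff:
expectations = SimpleExpectations(
    {
        (None, None): [[0.0622, -0.5142, -0.4034]],
        ("rocm", (9, 5)): [[0.0616, -0.5146, -0.4032]],
    }
)
print(expectations.get_expectation("rocm", (9, 5)))  # ROCm 9.5-specific reference values
print(expectations.get_expectation("cuda", (8, 0)))  # falls back to the (None, None) default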