2525import dpctl
2626import unittest
2727
28+
2829class TestGetNumPlatforms (unittest .TestCase ):
2930 @unittest .skipIf (not dpctl .has_sycl_platforms (),
3031 "No SYCL platforms available" )
3132 def test_dpctl_get_num_platforms (self ):
3233 if (dpctl .has_sycl_platforms ):
3334 self .assertGreaterEqual (dpctl .get_num_platforms (), 1 )
3435
36+
3537@unittest .skipIf (not dpctl .has_sycl_platforms (), "No SYCL platforms available" )
3638class TestDumpMethods (unittest .TestCase ):
3739 def test_dpctl_dump (self ):
@@ -47,6 +49,7 @@ def test_dpctl_dump_device_info (self):
4749 except Exception :
4850 self .fail ("Encountered an exception inside dump_device_info()." )
4951
52+
5053@unittest .skipIf (not dpctl .has_sycl_platforms (), "No SYCL platforms available" )
5154class TestIsInDeviceContext (unittest .TestCase ):
5255
@@ -65,6 +68,35 @@ def test_is_in_device_context_inside_nested_device_ctxt (self):
6568 self .assertTrue (dpctl .is_in_device_context ())
6669 self .assertFalse (dpctl .is_in_device_context ())
6770
71+
72+ @unittest .skipIf (not dpctl .has_sycl_platforms (), "No SYCL platforms available" )
73+ class TestIsInDeviceContext (unittest .TestCase ):
74+
75+ def test_get_current_device_type_outside_device_ctxt (self ):
76+ self .assertEqual (dpctl .get_current_device_type (), None )
77+
78+ def test_get_current_device_type_inside_device_ctxt (self ):
79+ self .assertEqual (dpctl .get_current_device_type (), None )
80+
81+ with dpctl .device_context (dpctl .device_type .gpu ):
82+ self .assertEqual (dpctl .get_current_device_type (), dpctl .device_type .gpu )
83+
84+ self .assertEqual (dpctl .get_current_device_type (), None )
85+
86+ @unittest .skipIf (not dpctl .has_cpu_queues (), "No CPU platforms available" )
87+ def test_get_current_device_type_inside_nested_device_ctxt (self ):
88+ self .assertEqual (dpctl .get_current_device_type (), None )
89+
90+ with dpctl .device_context (dpctl .device_type .cpu ):
91+ self .assertEqual (dpctl .get_current_device_type (), dpctl .device_type .cpu )
92+
93+ with dpctl .device_context (dpctl .device_type .gpu ):
94+ self .assertEqual (dpctl .get_current_device_type (), dpctl .device_type .gpu )
95+ self .assertEqual (dpctl .get_current_device_type (), dpctl .device_type .cpu )
96+
97+ self .assertEqual (dpctl .get_current_device_type (), None )
98+
99+
68100@unittest .skipIf (not dpctl .has_sycl_platforms (), "No SYCL platforms available" )
69101class TestGetCurrentQueueInMultipleThreads (unittest .TestCase ):
70102
@@ -96,5 +128,6 @@ def SessionThread (self):
96128 Session1 .start ()
97129 Session2 .start ()
98130
131+
99132if __name__ == '__main__' :
100133 unittest .main ()
0 commit comments