1010
1111class ListMLELossTest (testing .TestCase , parameterized .TestCase ):
1212 def setUp (self ):
13- self .unbatched_scores = ops .array ([1.0 , 3.0 , 2.0 , 4.0 , 0.8 ])
14- self .unbatched_labels = ops .array ([1.0 , 0.0 , 1.0 , 3.0 , 2.0 ])
13+ self .unbatched_scores = ops .array ([1.0 , 3.0 , 2.0 , 4.0 , 0.8 ], dtype = "float32" )
14+ self .unbatched_labels = ops .array ([1.0 , 0.0 , 1.0 , 3.0 , 2.0 ], dtype = "float32" )
1515
1616 self .batched_scores = ops .array (
17- [[1.0 , 3.0 , 2.0 , 4.0 , 0.8 ], [1.0 , 1.8 , 2.0 , 3.0 , 2.0 ]]
17+ [[1.0 , 3.0 , 2.0 , 4.0 , 0.8 ], [1.0 , 1.8 , 2.0 , 3.0 , 2.0 ]], dtype = "float32"
1818 )
1919 self .batched_labels = ops .array (
20- [[1.0 , 0.0 , 1.0 , 3.0 , 2.0 ], [0.0 , 1.0 , 2.0 , 3.0 , 1.5 ]]
20+ [[1.0 , 0.0 , 1.0 , 3.0 , 2.0 ], [0.0 , 1.0 , 2.0 , 3.0 , 1.5 ]], dtype = "float32"
2121 )
22- self .expected_output = ops .array ([6.865693 , 3.088192 ])
22+ self .expected_output = ops .array ([6.865693 , 3.088192 ], dtype = "float32" )
2323
2424 def test_unbatched_input (self ):
2525 loss = ListMLELoss (reduction = "none" )
@@ -43,7 +43,6 @@ def test_temperature(self):
4343 output_temp = loss_temp (
4444 y_true = self .batched_labels , y_pred = self .batched_scores
4545 )
46-
4746 self .assertAllClose (
4847 output_temp ,
4948 [10.969891 , 2.1283305 ],
@@ -60,7 +59,6 @@ def test_invalid_input_rank(self):
6059 def test_loss_reduction (self ):
6160 loss = ListMLELoss (reduction = "sum_over_batch_size" )
6261 output = loss (y_true = self .batched_labels , y_pred = self .batched_scores )
63-
6462 self .assertAlmostEqual (
6563 ops .convert_to_numpy (output ), 4.9769425 , places = 5
6664 )
0 commit comments