|
10 | 10 |
|
11 | 11 | class ListMLELossTest(testing.TestCase, parameterized.TestCase): |
12 | 12 | def setUp(self): |
13 | | - self.unbatched_scores = ops.array([1.0, 3.0, 2.0, 4.0, 0.8], dtype="float32") |
14 | | - self.unbatched_labels = ops.array([1.0, 0.0, 1.0, 3.0, 2.0], dtype="float32") |
| 13 | + self.unbatched_scores = ops.array( |
| 14 | + [1.0, 3.0, 2.0, 4.0, 0.8], dtype="float32" |
| 15 | + ) |
| 16 | + self.unbatched_labels = ops.array( |
| 17 | + [1.0, 0.0, 1.0, 3.0, 2.0], dtype="float32" |
| 18 | + ) |
15 | 19 |
|
16 | 20 | self.batched_scores = ops.array( |
17 | | - [[1.0, 3.0, 2.0, 4.0, 0.8], [1.0, 1.8, 2.0, 3.0, 2.0]], dtype="float32" |
| 21 | + [[1.0, 3.0, 2.0, 4.0, 0.8], [1.0, 1.8, 2.0, 3.0, 2.0]], |
| 22 | + dtype="float32", |
18 | 23 | ) |
19 | 24 | self.batched_labels = ops.array( |
20 | | - [[1.0, 0.0, 1.0, 3.0, 2.0], [0.0, 1.0, 2.0, 3.0, 1.5]], dtype="float32" |
| 25 | + [[1.0, 0.0, 1.0, 3.0, 2.0], [0.0, 1.0, 2.0, 3.0, 1.5]], |
| 26 | + dtype="float32", |
21 | 27 | ) |
22 | 28 | self.expected_output = ops.array([6.865693, 3.088192], dtype="float32") |
23 | 29 |
|
|
0 commit comments