|
| 1 | +# Copyright 2025 Google LLC |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# https://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +"""Custom metrics: MeanRank, MaxRank. |
| 15 | +
|
| 16 | +Related metrics: |
| 17 | + MinRank: Is zero as soon as accuracy is non-zero. |
| 18 | + keras.metrics.TopKCategoricalAccuracy: How often the correct class |
| 19 | + is in the top K predictions (how often is rank less than K). |
| 20 | +""" |
| 21 | +from typing import Any, Optional |
| 22 | + |
| 23 | +import numpy as np |
| 24 | +import keras |
| 25 | +from keras.metrics import Metric, categorical_accuracy |
| 26 | + |
| 27 | +from scaaml.utils import requires |
| 28 | + |
| 29 | + |
class SignificanceTest(Metric):  # type: ignore[no-any-unimported,misc]
    """Probability that a random guess would achieve at least the observed
    accuracy (a one-sided binomial test p-value). The probability is in the
    interval [0, 1] (impossible to always). By convention one rejects the
    null hypothesis at a given p-value (say 0.005 if we want to be sure).

    The method `SignificanceTest.result` requires SciPy to be installed. We
    also mark the `__init__` so that users do not waste time without a chance
    to get the result.

    Args:
        name: (Optional) String name of the metric instance.
        dtype: (Optional) Data type of the metric result.

    Standalone usage:

    ```python
    >>> m = SignificanceTest()
    >>> m.update_state([[0., 1.], [1., 0.]], [[0.1, 0.9], [0.6, 0.4]])
    >>> float(m.result())
    0.25
    ```

    Usage with `compile()` API:

    ```python
    model.compile(optimizer="sgd",
                  loss="mse",
                  metrics=[SignificanceTest()])
    ```
    """

    @requires("scipy")
    def __init__(self, name: str = "SignificanceTest", **kwargs: Any) -> None:
        super().__init__(name=name, **kwargs)
        # Running count of correctly classified examples.
        self.correct = self.add_weight(name="correct", initializer="zeros")
        # Number of classes (size of the last axis of y_true); a constant
        # of the problem, re-assigned (not accumulated) on every update.
        self.possibilities = self.add_weight(name="possibilities",
                                             initializer="zeros")
        # Running count of examples seen.
        self.seen = self.add_weight(name="seen", initializer="zeros")

    def update_state(self,
                     y_true: Any,
                     y_pred: Any,
                     sample_weight: Optional[Any] = None) -> None:
        """Update the state.

        Args:
            y_true (batch of one-hot): One-hot ground truth values.
            y_pred (batch of one-hot): The prediction values.
            sample_weight (Optional weights): Does not make sense, as we count
                maximum.
        """
        del sample_weight  # unused

        # Make into arrays so that .shape is always available.
        y_true = np.array(y_true, dtype=np.float32)
        y_pred = np.array(y_pred, dtype=np.float32)

        # Update the number of seen examples.
        self.seen.assign(self.seen + y_true.shape[0])

        # Update the number of correctly predicted examples.
        correct_now = keras.ops.sum(categorical_accuracy(y_true, y_pred))
        self.correct.assign(self.correct + correct_now)

        # Record the number of possibilities (number of classes).
        self.possibilities.assign(y_true.shape[-1])

    def result(self) -> Any:
        """Return the p-value of the one-sided binomial test.

        Returns:
            The probability that a uniformly random guesser would score at
            least `self.correct` successes out of `self.seen` trials. Returns
            1.0 before any data has been seen (no evidence against the null
            hypothesis).
        """
        # SciPy is an optional dependency; `import scipy.stats` (rather than
        # bare `import scipy`) guarantees the submodule is loaded.
        import scipy.stats  # pylint: disable=import-outside-toplevel

        k = self.correct.numpy()
        n = self.seen.numpy()
        possibilities = self.possibilities.numpy()

        # Guard: before the first update (or right after reset_state) both
        # counters are zero and 1 / possibilities would divide by zero.
        if possibilities == 0 or n == 0:
            return 1.0

        # Binomial distribution(n, p) -- how many successes out of n trials,
        # each succeeding with probability p independently of the others.
        # We want the probability that a random guess has at least
        # self.correct successes, i.e. P(X >= k) = sf(k - 1). Using the
        # survival function instead of `1 - cdf(k - 1, ...)` avoids
        # catastrophic cancellation when the p-value is very small -- which
        # is precisely the interesting regime for a significance test.
        return scipy.stats.binom.sf(k - 1, n, 1.0 / possibilities)

    def reset_state(self) -> None:
        """Reset the state for new measurement."""
        # The state of the metric will be reset at the start of each epoch.
        self.seen.assign(0.0)
        self.correct.assign(0.0)
        self.possibilities.assign(0.0)
0 commit comments