
Commit 8789c99

zhouhao138 authored and Responsible ML Infra Team committed
Add fairness indicator metrics in the third_party library.
PiperOrigin-RevId: 705969893
1 parent 1545d46 commit 8789c99

3 files changed: +211 −1 lines changed

RELEASE.md

Lines changed: 2 additions & 0 deletions
@@ -4,6 +4,8 @@
 
 ## Major Features and Improvements
 
+*   Add fairness indicator metrics in the third_party library.
+
 ## Bug Fixes and Other Changes
 
 ## Breaking Changes

fairness_indicators/example_model.py

Lines changed: 1 addition & 1 deletion
@@ -19,10 +19,10 @@
 results can be visualized using tools like TensorBoard.
 """
 
+from fairness_indicators import fairness_indicators_metrics  # pylint: disable=unused-import
 from tensorflow import keras
 import tensorflow.compat.v1 as tf
 import tensorflow_model_analysis as tfma
-from tensorflow_model_analysis.addons.fairness.post_export_metrics import fairness_indicators  # pylint: disable=unused-import
 
 
 TEXT_FEATURE = 'comment_text'
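
Note on this import swap: importing fairness_indicators_metrics runs metric_types.register_metric(FairnessIndicators) (see the new file below), which makes the metric addressable by class name in a TFMA eval config. A minimal sketch of that wiring, where the label key, thresholds, and slicing feature are illustrative assumptions rather than part of this commit:

import tensorflow_model_analysis as tfma
# Importing the module registers FairnessIndicators with TFMA.
from fairness_indicators import fairness_indicators_metrics  # pylint: disable=unused-import

eval_config = tfma.EvalConfig(
    model_specs=[tfma.ModelSpec(label_key='label')],  # assumed label key
    metrics_specs=[
        tfma.MetricsSpec(metrics=[
            tfma.MetricConfig(
                class_name='FairnessIndicators',
                config='{"thresholds": [0.25, 0.5, 0.75]}'),
        ])
    ],
    # Assumed slice feature for per-group fairness comparisons.
    slicing_specs=[tfma.SlicingSpec(feature_keys=['gender'])],
)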
fairness_indicators/fairness_indicators_metrics.py

Lines changed: 208 additions & 0 deletions

@@ -0,0 +1,208 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fairness Indicators Metrics."""

import collections
from typing import Any, Dict, List, Optional, Sequence

from tensorflow_model_analysis.metrics import binary_confusion_matrices
from tensorflow_model_analysis.metrics import metric_types
from tensorflow_model_analysis.metrics import metric_util
from tensorflow_model_analysis.proto import config_pb2

FAIRNESS_INDICATORS_METRICS_NAME = 'fairness_indicators_metrics'
FAIRNESS_INDICATORS_SUB_METRICS = (
    'false_positive_rate',
    'false_negative_rate',
    'true_positive_rate',
    'true_negative_rate',
    'positive_rate',
    'negative_rate',
    'false_discovery_rate',
    'false_omission_rate',
    'precision',
    'recall',
)

DEFAULT_THRESHOLDS = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)


class FairnessIndicators(metric_types.Metric):
  """Fairness indicators metrics."""

  def computations_with_logging(self):
    """Add streamz logging for fairness indicators."""

    computations_fn = metric_util.merge_per_key_computations(
        _fairness_indicators_metrics_at_thresholds
    )

    # Forwards all arguments unchanged to the merged per-key computations.
    def merge_and_log_computations_fn(
        eval_config: Optional[config_pb2.EvalConfig] = None,
        # A tf metadata schema.
        schema: Optional[Any] = None,
        model_names: Optional[List[str]] = None,
        output_names: Optional[List[str]] = None,
        sub_keys: Optional[List[Optional[metric_types.SubKey]]] = None,
        aggregation_type: Optional[metric_types.AggregationType] = None,
        class_weights: Optional[Dict[int, float]] = None,
        example_weighted: bool = False,
        query_key: Optional[str] = None,
        **kwargs
    ):
      return computations_fn(
          eval_config,
          schema,
          model_names,
          output_names,
          sub_keys,
          aggregation_type,
          class_weights,
          example_weighted,
          query_key,
          **kwargs
      )

    return merge_and_log_computations_fn

  def __init__(
      self,
      thresholds: Sequence[float] = DEFAULT_THRESHOLDS,
      name: str = FAIRNESS_INDICATORS_METRICS_NAME,
  ):
    """Initializes fairness indicators metrics.

    Args:
      thresholds: Thresholds to use for fairness metrics.
      name: Metric name.
    """
    super().__init__(
        self.computations_with_logging(), thresholds=thresholds, name=name
    )


def calculate_digits(thresholds):
  # Number of decimal digits needed to print the longest threshold (at least
  # one), used below to format thresholds into metric key names.
  digits = [len(str(t)) - 2 for t in thresholds]
  return max(max(digits), 1)


def _fairness_indicators_metrics_at_thresholds(
    thresholds: List[float],
    name: str = FAIRNESS_INDICATORS_METRICS_NAME,
    eval_config: Optional[config_pb2.EvalConfig] = None,
    model_name: str = '',
    output_name: str = '',
    aggregation_type: Optional[metric_types.AggregationType] = None,
    sub_key: Optional[metric_types.SubKey] = None,
    class_weights: Optional[Dict[int, float]] = None,
    example_weighted: bool = False,
) -> metric_types.MetricComputations:
  """Returns computations for fairness metrics at thresholds."""
  metric_key_by_name_by_threshold = collections.defaultdict(dict)
  keys = []
  digits_num = calculate_digits(thresholds)
  for t in thresholds:
    for m in FAIRNESS_INDICATORS_SUB_METRICS:
      key = metric_types.MetricKey(
          name='%s/%s@%.*f'
          % (
              name,
              m,
              digits_num,
              t,
          ),  # e.g. "fairness_indicators_metrics/positive_rate@0.5"
          model_name=model_name,
          output_name=output_name,
          sub_key=sub_key,
          example_weighted=example_weighted,
      )
      keys.append(key)
      metric_key_by_name_by_threshold[t][m] = key

  # Make sure matrices are calculated.
  computations = binary_confusion_matrices.binary_confusion_matrices(
      eval_config=eval_config,
      model_name=model_name,
      output_name=output_name,
      sub_key=sub_key,
      aggregation_type=aggregation_type,
      class_weights=class_weights,
      example_weighted=example_weighted,
      thresholds=thresholds,
  )
  confusion_matrices_key = computations[-1].keys[-1]

  def result(
      metrics: Dict[metric_types.MetricKey, Any],
  ) -> Dict[metric_types.MetricKey, Any]:
    """Returns fairness metrics values."""
    metric = metrics[confusion_matrices_key]
    output = {}

    for i, threshold in enumerate(thresholds):
      num_positives = metric.tp[i] + metric.fn[i]
      num_negatives = metric.tn[i] + metric.fp[i]

      # `denominator or float('nan')` yields NaN for an empty denominator,
      # avoiding ZeroDivisionError.
      tpr = metric.tp[i] / (num_positives or float('nan'))
      tnr = metric.tn[i] / (num_negatives or float('nan'))
      fpr = metric.fp[i] / (num_negatives or float('nan'))
      fnr = metric.fn[i] / (num_positives or float('nan'))
      pr = (metric.tp[i] + metric.fp[i]) / (
          (num_positives + num_negatives) or float('nan')
      )
      nr = (metric.tn[i] + metric.fn[i]) / (
          (num_positives + num_negatives) or float('nan')
      )
      precision = metric.tp[i] / ((metric.tp[i] + metric.fp[i]) or float('nan'))
      recall = metric.tp[i] / ((metric.tp[i] + metric.fn[i]) or float('nan'))

      fdr = metric.fp[i] / ((metric.fp[i] + metric.tp[i]) or float('nan'))
      fomr = metric.fn[i] / ((metric.fn[i] + metric.tn[i]) or float('nan'))

      output[
          metric_key_by_name_by_threshold[threshold]['false_positive_rate']
      ] = fpr
      output[
          metric_key_by_name_by_threshold[threshold]['false_negative_rate']
      ] = fnr
      output[
          metric_key_by_name_by_threshold[threshold]['true_positive_rate']
      ] = tpr
      output[
          metric_key_by_name_by_threshold[threshold]['true_negative_rate']
      ] = tnr
      output[metric_key_by_name_by_threshold[threshold]['positive_rate']] = pr
      output[metric_key_by_name_by_threshold[threshold]['negative_rate']] = nr
      output[
          metric_key_by_name_by_threshold[threshold]['false_discovery_rate']
      ] = fdr
      output[
          metric_key_by_name_by_threshold[threshold]['false_omission_rate']
      ] = fomr
      output[metric_key_by_name_by_threshold[threshold]['precision']] = (
          precision
      )
      output[metric_key_by_name_by_threshold[threshold]['recall']] = recall

    return output

  derived_computation = metric_types.DerivedMetricComputation(
      keys=keys, result=result
  )

  computations.append(derived_computation)
  return computations


metric_types.register_metric(FairnessIndicators)
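
Usage note: a minimal sketch of instantiating the new metric directly, with thresholds chosen only for illustration. calculate_digits pads every threshold suffix to the width of the longest one, so the generated key names stay uniformly formatted:

from fairness_indicators import fairness_indicators_metrics

# calculate_digits([0.25, 0.5]) returns 2, so both suffixes get two decimals.
metric = fairness_indicators_metrics.FairnessIndicators(thresholds=[0.25, 0.5])
# Resulting metric key names include, for example:
#   fairness_indicators_metrics/false_positive_rate@0.25
#   fairness_indicators_metrics/false_positive_rate@0.50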
