21
21
22
22
from typing import Any
23
23
24
- from fairness_indicators import fairness_indicators_metrics # pylint: disable=unused-import
25
- from tensorflow import keras
26
24
import tensorflow .compat .v1 as tf
27
25
import tensorflow_model_analysis as tfma
26
+ from tensorflow import keras
28
27
29
-
30
- TEXT_FEATURE = 'comment_text'
31
- LABEL = 'toxicity'
32
- SLICE = 'slice'
28
+ TEXT_FEATURE = "comment_text"
29
+ LABEL = "toxicity"
30
+ SLICE = "slice"
33
31
FEATURE_MAP = {
34
32
LABEL : tf .io .FixedLenFeature ([], tf .float32 ),
35
33
TEXT_FEATURE : tf .io .FixedLenFeature ([], tf .string ),
38
36
39
37
40
38
class ExampleParser (keras .layers .Layer ):
41
- """A Keras layer that parses the tf.Example."""
39
+ """A Keras layer that parses the tf.Example."""
40
+
41
+ def __init__ (self , input_feature_key ):
42
+ self ._input_feature_key = input_feature_key
43
+ self .input_spec = keras .layers .InputSpec (shape = (1 ,), dtype = tf .string )
44
+ super ().__init__ ()
42
45
43
- def __init__ (self , input_feature_key ):
44
- self ._input_feature_key = input_feature_key
45
- self .input_spec = keras .layers .InputSpec (shape = (1 ,), dtype = tf .string )
46
- super ().__init__ ()
46
+ def compute_output_shape (self , input_shape : Any ):
47
+ return [1 , 1 ]
47
48
48
- def compute_output_shape (self , input_shape : Any ):
49
- return [1 , 1 ]
49
+ def call (self , serialized_examples ):
50
+ def get_feature (serialized_example ):
51
+ parsed_example = tf .io .parse_single_example (
52
+ serialized_example , features = FEATURE_MAP
53
+ )
54
+ return parsed_example [self ._input_feature_key ]
50
55
51
- def call (self , serialized_examples ):
52
- def get_feature (serialized_example ):
53
- parsed_example = tf .io .parse_single_example (
54
- serialized_example , features = FEATURE_MAP
55
- )
56
- return parsed_example [self ._input_feature_key ]
57
- serialized_examples = tf .cast (serialized_examples , tf .string )
58
- return tf .map_fn (get_feature , serialized_examples )
56
+ serialized_examples = tf .cast (serialized_examples , tf .string )
57
+ return tf .map_fn (get_feature , serialized_examples )
59
58
60
59
61
60
class Reshaper (keras .layers .Layer ):
62
- """A Keras layer that reshapes the input."""
61
+ """A Keras layer that reshapes the input."""
63
62
64
- def call (self , inputs ):
65
- return tf .reshape (inputs , (1 , 32 ))
63
+ def call (self , inputs ):
64
+ return tf .reshape (inputs , (1 , 32 ))
66
65
67
66
68
67
class Caster (keras .layers .Layer ):
69
- """A Keras layer that reshapes the input."""
68
+ """A Keras layer that reshapes the input."""
70
69
71
- def call (self , inputs ):
72
- return tf .cast (inputs , tf .float32 )
70
+ def call (self , inputs ):
71
+ return tf .cast (inputs , tf .float32 )
73
72
74
73
75
74
def get_example_model (input_feature_key : str ):
76
- """Returns a Keras model for testing."""
77
- parser = ExampleParser (input_feature_key )
78
- text_vectorization = keras .layers .TextVectorization (
79
- max_tokens = 32 ,
80
- output_mode = ' int' ,
81
- output_sequence_length = 32 ,
82
- )
83
- text_vectorization .adapt (
84
- [ ' nontoxic' , ' toxic comment' , ' test comment' , ' abc' , ' abcdef' , ' random' ]
85
- )
86
- dense1 = keras .layers .Dense (
87
- 32 ,
88
- activation = None ,
89
- use_bias = True ,
90
- kernel_initializer = ' glorot_uniform' ,
91
- bias_initializer = ' zeros' ,
92
- )
93
- dense2 = keras .layers .Dense (
94
- 1 ,
95
- activation = None ,
96
- use_bias = False ,
97
- kernel_initializer = ' glorot_uniform' ,
98
- bias_initializer = ' zeros' ,
99
- )
100
-
101
- inputs = tf .keras .Input (shape = (), dtype = tf .string )
102
- parsed_example = parser (inputs )
103
- text_vector = text_vectorization (parsed_example )
104
- text_vector = Reshaper ()(text_vector )
105
- text_vector = Caster ()(text_vector )
106
- output1 = dense1 (text_vector )
107
- output2 = dense2 (output1 )
108
- return tf .keras .Model (inputs = inputs , outputs = output2 )
75
+ """Returns a Keras model for testing."""
76
+ parser = ExampleParser (input_feature_key )
77
+ text_vectorization = keras .layers .TextVectorization (
78
+ max_tokens = 32 ,
79
+ output_mode = " int" ,
80
+ output_sequence_length = 32 ,
81
+ )
82
+ text_vectorization .adapt (
83
+ [ " nontoxic" , " toxic comment" , " test comment" , " abc" , " abcdef" , " random" ]
84
+ )
85
+ dense1 = keras .layers .Dense (
86
+ 32 ,
87
+ activation = None ,
88
+ use_bias = True ,
89
+ kernel_initializer = " glorot_uniform" ,
90
+ bias_initializer = " zeros" ,
91
+ )
92
+ dense2 = keras .layers .Dense (
93
+ 1 ,
94
+ activation = None ,
95
+ use_bias = False ,
96
+ kernel_initializer = " glorot_uniform" ,
97
+ bias_initializer = " zeros" ,
98
+ )
99
+
100
+ inputs = tf .keras .Input (shape = (), dtype = tf .string )
101
+ parsed_example = parser (inputs )
102
+ text_vector = text_vectorization (parsed_example )
103
+ text_vector = Reshaper ()(text_vector )
104
+ text_vector = Caster ()(text_vector )
105
+ output1 = dense1 (text_vector )
106
+ output2 = dense2 (output1 )
107
+ return tf .keras .Model (inputs = inputs , outputs = output2 )
109
108
110
109
111
110
def evaluate_model (
@@ -114,23 +113,23 @@ def evaluate_model(
114
113
tfma_eval_result_path ,
115
114
eval_config ,
116
115
):
117
- """Evaluate Model using Tensorflow Model Analysis.
118
-
119
- Args:
120
- classifier_model_path: Trained classifier model to be evaluted.
121
- validate_tf_file_path: File containing validation TFRecordDataset .
122
- tfma_eval_result_path: Path to export tfma-related eval path .
123
- eval_config: tfma eval_config .
124
- """
125
-
126
- eval_shared_model = tfma .default_eval_shared_model (
127
- eval_saved_model_path = classifier_model_path , eval_config = eval_config
128
- )
129
-
130
- # Run the fairness evaluation.
131
- tfma .run_model_analysis (
132
- eval_shared_model = eval_shared_model ,
133
- data_location = validate_tf_file_path ,
134
- output_path = tfma_eval_result_path ,
135
- eval_config = eval_config ,
136
- )
116
+ """Evaluate Model using Tensorflow Model Analysis.
117
+
118
+ Args:
119
+ ----
120
+ classifier_model_path: Trained classifier model to be evaluted .
121
+ validate_tf_file_path: File containing validation TFRecordDataset .
122
+ tfma_eval_result_path: Path to export tfma-related eval path .
123
+ eval_config: tfma eval_config.
124
+ """
125
+ eval_shared_model = tfma .default_eval_shared_model (
126
+ eval_saved_model_path = classifier_model_path , eval_config = eval_config
127
+ )
128
+
129
+ # Run the fairness evaluation.
130
+ tfma .run_model_analysis (
131
+ eval_shared_model = eval_shared_model ,
132
+ data_location = validate_tf_file_path ,
133
+ output_path = tfma_eval_result_path ,
134
+ eval_config = eval_config ,
135
+ )
0 commit comments