21
21
22
22
from typing import Any
23
23
24
- from fairness_indicators import fairness_indicators_metrics # pylint: disable=unused-import
25
- from tensorflow import keras
26
24
import tensorflow .compat .v1 as tf
27
25
import tensorflow_model_analysis as tfma
26
+ from tensorflow import keras
28
27
28
+ from fairness_indicators import fairness_indicators_metrics # noqa: F401
29
29
30
- TEXT_FEATURE = ' comment_text'
31
- LABEL = ' toxicity'
32
- SLICE = ' slice'
30
+ TEXT_FEATURE = " comment_text"
31
+ LABEL = " toxicity"
32
+ SLICE = " slice"
33
33
FEATURE_MAP = {
34
34
LABEL : tf .io .FixedLenFeature ([], tf .float32 ),
35
35
TEXT_FEATURE : tf .io .FixedLenFeature ([], tf .string ),
38
38
39
39
40
40
class ExampleParser (keras .layers .Layer ):
41
- """A Keras layer that parses the tf.Example."""
41
+ """A Keras layer that parses the tf.Example."""
42
+
43
+ def __init__ (self , input_feature_key ):
44
+ self ._input_feature_key = input_feature_key
45
+ self .input_spec = keras .layers .InputSpec (shape = (1 ,), dtype = tf .string )
46
+ super ().__init__ ()
42
47
43
- def __init__ (self , input_feature_key ):
44
- self ._input_feature_key = input_feature_key
45
- self .input_spec = keras .layers .InputSpec (shape = (1 ,), dtype = tf .string )
46
- super ().__init__ ()
48
+ def compute_output_shape (self , input_shape : Any ):
49
+ return [1 , 1 ]
47
50
48
- def compute_output_shape (self , input_shape : Any ):
49
- return [1 , 1 ]
51
+ def call (self , serialized_examples ):
52
+ def get_feature (serialized_example ):
53
+ parsed_example = tf .io .parse_single_example (
54
+ serialized_example , features = FEATURE_MAP
55
+ )
56
+ return parsed_example [self ._input_feature_key ]
50
57
51
- def call (self , serialized_examples ):
52
- def get_feature (serialized_example ):
53
- parsed_example = tf .io .parse_single_example (
54
- serialized_example , features = FEATURE_MAP
55
- )
56
- return parsed_example [self ._input_feature_key ]
57
- serialized_examples = tf .cast (serialized_examples , tf .string )
58
- return tf .map_fn (get_feature , serialized_examples )
58
+ serialized_examples = tf .cast (serialized_examples , tf .string )
59
+ return tf .map_fn (get_feature , serialized_examples )
59
60
60
61
61
62
class Reshaper (keras .layers .Layer ):
62
- """A Keras layer that reshapes the input."""
63
+ """A Keras layer that reshapes the input."""
63
64
64
- def call (self , inputs ):
65
- return tf .reshape (inputs , (1 , 32 ))
65
+ def call (self , inputs ):
66
+ return tf .reshape (inputs , (1 , 32 ))
66
67
67
68
68
69
class Caster (keras .layers .Layer ):
69
- """A Keras layer that reshapes the input."""
70
+ """A Keras layer that reshapes the input."""
70
71
71
- def call (self , inputs ):
72
- return tf .cast (inputs , tf .float32 )
72
+ def call (self , inputs ):
73
+ return tf .cast (inputs , tf .float32 )
73
74
74
75
75
76
def get_example_model (input_feature_key : str ):
76
- """Returns a Keras model for testing."""
77
- parser = ExampleParser (input_feature_key )
78
- text_vectorization = keras .layers .TextVectorization (
79
- max_tokens = 32 ,
80
- output_mode = ' int' ,
81
- output_sequence_length = 32 ,
82
- )
83
- text_vectorization .adapt (
84
- [ ' nontoxic' , ' toxic comment' , ' test comment' , ' abc' , ' abcdef' , ' random' ]
85
- )
86
- dense1 = keras .layers .Dense (
87
- 32 ,
88
- activation = None ,
89
- use_bias = True ,
90
- kernel_initializer = ' glorot_uniform' ,
91
- bias_initializer = ' zeros' ,
92
- )
93
- dense2 = keras .layers .Dense (
94
- 1 ,
95
- activation = None ,
96
- use_bias = False ,
97
- kernel_initializer = ' glorot_uniform' ,
98
- bias_initializer = ' zeros' ,
99
- )
100
-
101
- inputs = tf .keras .Input (shape = (), dtype = tf .string )
102
- parsed_example = parser (inputs )
103
- text_vector = text_vectorization (parsed_example )
104
- text_vector = Reshaper ()(text_vector )
105
- text_vector = Caster ()(text_vector )
106
- output1 = dense1 (text_vector )
107
- output2 = dense2 (output1 )
108
- return tf .keras .Model (inputs = inputs , outputs = output2 )
77
+ """Returns a Keras model for testing."""
78
+ parser = ExampleParser (input_feature_key )
79
+ text_vectorization = keras .layers .TextVectorization (
80
+ max_tokens = 32 ,
81
+ output_mode = " int" ,
82
+ output_sequence_length = 32 ,
83
+ )
84
+ text_vectorization .adapt (
85
+ [ " nontoxic" , " toxic comment" , " test comment" , " abc" , " abcdef" , " random" ]
86
+ )
87
+ dense1 = keras .layers .Dense (
88
+ 32 ,
89
+ activation = None ,
90
+ use_bias = True ,
91
+ kernel_initializer = " glorot_uniform" ,
92
+ bias_initializer = " zeros" ,
93
+ )
94
+ dense2 = keras .layers .Dense (
95
+ 1 ,
96
+ activation = None ,
97
+ use_bias = False ,
98
+ kernel_initializer = " glorot_uniform" ,
99
+ bias_initializer = " zeros" ,
100
+ )
101
+
102
+ inputs = tf .keras .Input (shape = (), dtype = tf .string )
103
+ parsed_example = parser (inputs )
104
+ text_vector = text_vectorization (parsed_example )
105
+ text_vector = Reshaper ()(text_vector )
106
+ text_vector = Caster ()(text_vector )
107
+ output1 = dense1 (text_vector )
108
+ output2 = dense2 (output1 )
109
+ return tf .keras .Model (inputs = inputs , outputs = output2 )
109
110
110
111
111
112
def evaluate_model (
@@ -114,23 +115,23 @@ def evaluate_model(
114
115
tfma_eval_result_path ,
115
116
eval_config ,
116
117
):
117
- """Evaluate Model using Tensorflow Model Analysis.
118
-
119
- Args:
120
- classifier_model_path: Trained classifier model to be evaluted.
121
- validate_tf_file_path: File containing validation TFRecordDataset .
122
- tfma_eval_result_path: Path to export tfma-related eval path .
123
- eval_config: tfma eval_config .
124
- """
125
-
126
- eval_shared_model = tfma .default_eval_shared_model (
127
- eval_saved_model_path = classifier_model_path , eval_config = eval_config
128
- )
129
-
130
- # Run the fairness evaluation.
131
- tfma .run_model_analysis (
132
- eval_shared_model = eval_shared_model ,
133
- data_location = validate_tf_file_path ,
134
- output_path = tfma_eval_result_path ,
135
- eval_config = eval_config ,
136
- )
118
+ """Evaluate Model using Tensorflow Model Analysis.
119
+
120
+ Args:
121
+ ----
122
+ classifier_model_path: Trained classifier model to be evaluted .
123
+ validate_tf_file_path: File containing validation TFRecordDataset .
124
+ tfma_eval_result_path: Path to export tfma-related eval path .
125
+ eval_config: tfma eval_config.
126
+ """
127
+ eval_shared_model = tfma .default_eval_shared_model (
128
+ eval_saved_model_path = classifier_model_path , eval_config = eval_config
129
+ )
130
+
131
+ # Run the fairness evaluation.
132
+ tfma .run_model_analysis (
133
+ eval_shared_model = eval_shared_model ,
134
+ data_location = validate_tf_file_path ,
135
+ output_path = tfma_eval_result_path ,
136
+ eval_config = eval_config ,
137
+ )
0 commit comments