25
25
from tensorboard_plugin_fairness_indicators import plugin
26
26
from tensorboard_plugin_fairness_indicators import summary_v2
27
27
import six
28
- import tensorflow .compat .v1 as tf
29
- import tensorflow .compat .v2 as tf2
28
+ import tensorflow as tf2
29
+ from tensorflow .keras import layers
30
+ from tensorflow .keras import models
30
31
import tensorflow_model_analysis as tfma
31
- from tensorflow_model_analysis .eval_saved_model .example_trainers import linear_classifier
32
32
from werkzeug import test as werkzeug_test
33
33
from werkzeug import wrappers
34
34
35
35
from tensorboard .backend import application
36
36
from tensorboard .backend .event_processing import plugin_event_multiplexer as event_multiplexer
37
37
from tensorboard .plugins import base_plugin
38
38
39
- tf .enable_eager_execution ()
39
+ Sequential = models .Sequential
40
+ Dense = layers .Dense
41
+
40
42
tf = tf2
41
43
42
44
45
+ # Define keras based linear classifier.
46
+ def create_linear_classifier (model_dir ):
47
+
48
+ model = Sequential ([Dense (1 , activation = "sigmoid" , input_shape = (2 ,))])
49
+ model .compile (
50
+ optimizer = "adam" , loss = "binary_crossentropy" , metrics = ["accuracy" ]
51
+ )
52
+ # Convert the Sequential model to a tf.Module before saving
53
+ model = tf .keras .models .Model (inputs = model .inputs , outputs = model .outputs )
54
+ tf .saved_model .save (model , model_dir )
55
+ return model
56
+
57
+
43
58
class PluginTest (tf .test .TestCase ):
44
59
"""Tests for Fairness Indicators plugin server."""
45
60
@@ -74,19 +89,19 @@ def tearDown(self):
74
89
super (PluginTest , self ).tearDown ()
75
90
shutil .rmtree (self ._log_dir , ignore_errors = True )
76
91
77
- def _exportEvalSavedModel (self , classifier ):
92
+ def _export_eval_saved_model (self ):
93
+ """Export the evaluation saved model."""
78
94
temp_eval_export_dir = os .path .join (self .get_temp_dir (), "eval_export_dir" )
79
- _ , eval_export_dir = classifier (None , temp_eval_export_dir )
80
- return eval_export_dir
95
+ return create_linear_classifier (temp_eval_export_dir )
81
96
82
- def _writeTFExamplesToTFRecords (self , examples ):
97
+ def _write_tf_examples_to_tfrecords (self , examples ):
83
98
data_location = os .path .join (self .get_temp_dir (), "input_data.rio" )
84
99
with tf .io .TFRecordWriter (data_location ) as writer :
85
100
for example in examples :
86
101
writer .write (example .SerializeToString ())
87
102
return data_location
88
103
89
- def _makeExample (self , age , language , label ):
104
+ def _make_tf_example (self , age , language , label ):
90
105
example = tf .train .Example ()
91
106
example .features .feature ["age" ].float_list .value [:] = [age ]
92
107
example .features .feature ["language" ].bytes_list .value [:] = [
@@ -112,14 +127,14 @@ def testRoutes(self):
112
127
"foo" : "" .encode ("utf-8" )
113
128
}},
114
129
)
115
- def testIsActive (self , get_random_stub ):
130
+ def testIsActive (self ):
116
131
self .assertTrue (self ._plugin .is_active ())
117
132
118
133
@mock .patch .object (
119
134
event_multiplexer .EventMultiplexer ,
120
135
"PluginRunToTagToContent" ,
121
136
return_value = {})
122
- def testIsInactive (self , get_random_stub ):
137
+ def testIsInactive (self ):
123
138
self .assertFalse (self ._plugin .is_active ())
124
139
125
140
def testIndexJsRoute (self ):
@@ -134,16 +149,15 @@ def testVulcanizedTemplateRoute(self):
134
149
self .assertEqual (200 , response .status_code )
135
150
136
151
def testGetEvalResultsRoute (self ):
137
- model_location = self ._exportEvalSavedModel (
138
- linear_classifier .simple_linear_classifier )
152
+ model_location = self ._export_eval_saved_model () # Call the method
139
153
examples = [
140
- self ._makeExample (age = 3.0 , language = "english" , label = 1.0 ),
141
- self ._makeExample (age = 3.0 , language = "chinese" , label = 0.0 ),
142
- self ._makeExample (age = 4.0 , language = "english" , label = 1.0 ),
143
- self ._makeExample (age = 5.0 , language = "chinese" , label = 1.0 ),
144
- self ._makeExample (age = 5.0 , language = "hindi" , label = 1.0 )
154
+ self ._make_tf_example (age = 3.0 , language = "english" , label = 1.0 ),
155
+ self ._make_tf_example (age = 3.0 , language = "chinese" , label = 0.0 ),
156
+ self ._make_tf_example (age = 4.0 , language = "english" , label = 1.0 ),
157
+ self ._make_tf_example (age = 5.0 , language = "chinese" , label = 1.0 ),
158
+ self ._make_tf_example (age = 5.0 , language = "hindi" , label = 1.0 ),
145
159
]
146
- data_location = self ._writeTFExamplesToTFRecords (examples )
160
+ data_location = self ._write_tf_examples_to_tfrecords (examples )
147
161
_ = tfma .run_model_analysis (
148
162
eval_shared_model = tfma .default_eval_shared_model (
149
163
eval_saved_model_path = model_location , example_weight_key = "age" ),
@@ -155,32 +169,36 @@ def testGetEvalResultsRoute(self):
155
169
self .assertEqual (200 , response .status_code )
156
170
157
171
def testGetEvalResultsFromURLRoute (self ):
158
- model_location = self ._exportEvalSavedModel (
159
- linear_classifier .simple_linear_classifier )
172
+ model_location = self ._export_eval_saved_model () # Call the method
160
173
examples = [
161
- self ._makeExample (age = 3.0 , language = "english" , label = 1.0 ),
162
- self ._makeExample (age = 3.0 , language = "chinese" , label = 0.0 ),
163
- self ._makeExample (age = 4.0 , language = "english" , label = 1.0 ),
164
- self ._makeExample (age = 5.0 , language = "chinese" , label = 1.0 ),
165
- self ._makeExample (age = 5.0 , language = "hindi" , label = 1.0 )
174
+ self ._make_tf_example (age = 3.0 , language = "english" , label = 1.0 ),
175
+ self ._make_tf_example (age = 3.0 , language = "chinese" , label = 0.0 ),
176
+ self ._make_tf_example (age = 4.0 , language = "english" , label = 1.0 ),
177
+ self ._make_tf_example (age = 5.0 , language = "chinese" , label = 1.0 ),
178
+ self ._make_tf_example (age = 5.0 , language = "hindi" , label = 1.0 ),
166
179
]
167
- data_location = self ._writeTFExamplesToTFRecords (examples )
180
+ data_location = self ._write_tf_examples_to_tfrecords (examples )
168
181
_ = tfma .run_model_analysis (
169
182
eval_shared_model = tfma .default_eval_shared_model (
170
183
eval_saved_model_path = model_location , example_weight_key = "age" ),
171
184
data_location = data_location ,
172
185
output_path = self ._eval_result_output_dir )
173
186
174
187
response = self ._server .get (
175
- "/data/plugin/fairness_indicators/" +
176
- "get_evaluation_result_from_remote_path?evaluation_output_path=" +
177
- os .path .join (self ._eval_result_output_dir , tfma .METRICS_KEY ))
188
+ "/data/plugin/fairness_indicators/"
189
+ + "get_evaluation_result_from_remote_path?evaluation_output_path="
190
+ + self ._eval_result_output_dir
191
+ )
178
192
self .assertEqual (200 , response .status_code )
179
193
180
- def testGetOutputFileFormat (self ):
181
- self .assertEqual ("" , self ._plugin ._get_output_file_format ("abc_path" ))
182
- self .assertEqual ("tfrecord" ,
183
- self ._plugin ._get_output_file_format ("abc_path.tfrecord" ))
194
+ def test_get_output_file_format (self ):
195
+ evaluation_output_path = os .path .join (
196
+ self ._eval_result_output_dir , "eval_result.tfrecord"
197
+ )
198
+ self .assertEqual (
199
+ self ._plugin ._get_output_file_format (evaluation_output_path ),
200
+ "tfrecord" ,
201
+ )
184
202
185
203
186
204
if __name__ == "__main__" :
0 commit comments