1
+
1
2
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3
#
3
4
# Licensed under the Apache License, Version 2.0 (the "License");
19
20
results can be visualized using tools like TensorBoard.
20
21
"""
21
22
23
+ from typing import Any
22
24
from fairness_indicators import fairness_indicators_metrics # pylint: disable=unused-import
23
25
from tensorflow import keras
24
26
import tensorflow .compat .v1 as tf
@@ -40,41 +42,50 @@ class ExampleParser(keras.layers.Layer):
40
42
41
43
def __init__ (self , input_feature_key ):
42
44
self ._input_feature_key = input_feature_key
45
+ self .input_spec = keras .layers .InputSpec (shape = (1 ,), dtype = tf .string )
43
46
super ().__init__ ()
44
47
48
+ def compute_output_shape (self , input_shape : Any ):
49
+ return [1 , 1 ]
50
+
45
51
def call (self , serialized_examples ):
46
52
def get_feature (serialized_example ):
47
53
parsed_example = tf .io .parse_single_example (
48
54
serialized_example , features = FEATURE_MAP
49
55
)
50
56
return parsed_example [self ._input_feature_key ]
51
-
57
+ serialized_examples = tf . cast ( serialized_examples , tf . string )
52
58
return tf .map_fn (get_feature , serialized_examples )
53
59
54
60
55
- class ExampleModel (keras .Model ):
56
- """A Example Keras NLP model ."""
61
+ class Reshaper (keras .layers . Layer ):
62
+ """A Keras layer that reshapes the input ."""
57
63
58
- def __init__ (self , input_feature_key ):
59
- super ().__init__ ()
60
- self .parser = ExampleParser (input_feature_key )
61
- self .text_vectorization = keras .layers .TextVectorization (
62
- max_tokens = 32 ,
63
- output_mode = 'int' ,
64
- output_sequence_length = 32 ,
65
- )
66
- self .text_vectorization .adapt (
67
- ['nontoxic' , 'toxic comment' , 'test comment' , 'abc' , 'abcdef' , 'random' ]
68
- )
69
- self .dense1 = keras .layers .Dense (32 , activation = 'relu' )
70
- self .dense2 = keras .layers .Dense (1 )
71
-
72
- def call (self , inputs , training = True , mask = None ):
73
- parsed_example = self .parser (inputs )
74
- text_vector = self .text_vectorization (parsed_example )
75
- output1 = self .dense1 (tf .cast (text_vector , tf .float32 ))
76
- output2 = self .dense2 (output1 )
77
- return output2
64
+ def call (self , inputs ):
65
+ return tf .reshape (inputs , (1 , 32 ))
66
+
67
+
68
+ def get_example_model (input_feature_key : str ):
69
+ """Returns a Keras model for testing."""
70
+ parser = ExampleParser (input_feature_key )
71
+ text_vectorization = keras .layers .TextVectorization (
72
+ max_tokens = 32 ,
73
+ output_mode = 'int' ,
74
+ output_sequence_length = 32 ,
75
+ )
76
+ text_vectorization .adapt (
77
+ ['nontoxic' , 'toxic comment' , 'test comment' , 'abc' , 'abcdef' , 'random' ]
78
+ )
79
+ dense1 = keras .layers .Dense (32 , activation = 'relu' )
80
+ dense2 = keras .layers .Dense (1 )
81
+ inputs = tf .keras .Input (shape = (), dtype = tf .string )
82
+ parsed_example = parser (inputs )
83
+ text_vector = text_vectorization (parsed_example )
84
+ text_vector = Reshaper ()(text_vector )
85
+ text_vector = tf .cast (text_vector , tf .float32 )
86
+ output1 = dense1 (text_vector )
87
+ output2 = dense2 (output1 )
88
+ return tf .keras .Model (inputs = inputs , outputs = output2 )
78
89
79
90
80
91
def evaluate_model (
@@ -83,6 +94,7 @@ def evaluate_model(
83
94
tfma_eval_result_path ,
84
95
eval_config ,
85
96
):
97
+
86
98
"""Evaluate Model using Tensorflow Model Analysis.
87
99
88
100
Args:
0 commit comments