19
19
results can be visualized using tools like TensorBoard.
20
20
"""
21
21
22
+ from typing import Any
23
+
22
24
from fairness_indicators import fairness_indicators_metrics # pylint: disable=unused-import
23
25
from tensorflow import keras
24
26
import tensorflow .compat .v1 as tf
@@ -40,41 +42,58 @@ class ExampleParser(keras.layers.Layer):
40
42
41
43
def __init__ (self , input_feature_key ):
42
44
self ._input_feature_key = input_feature_key
45
+ self .input_spec = keras .layers .InputSpec (shape = (1 ,), dtype = tf .string )
43
46
super ().__init__ ()
44
47
48
+ def compute_output_shape (self , input_shape : Any ):
49
+ return [1 , 1 ]
50
+
45
51
def call (self , serialized_examples ):
46
52
def get_feature (serialized_example ):
47
53
parsed_example = tf .io .parse_single_example (
48
54
serialized_example , features = FEATURE_MAP
49
55
)
50
56
return parsed_example [self ._input_feature_key ]
51
-
57
+ serialized_examples = tf . cast ( serialized_examples , tf . string )
52
58
return tf .map_fn (get_feature , serialized_examples )
53
59
54
60
55
- class ExampleModel (keras .Model ):
56
- """A Example Keras NLP model ."""
61
+ class Reshaper (keras .layers . Layer ):
62
+ """A Keras layer that reshapes the input ."""
57
63
58
- def __init__ (self , input_feature_key ):
59
- super ().__init__ ()
60
- self .parser = ExampleParser (input_feature_key )
61
- self .text_vectorization = keras .layers .TextVectorization (
62
- max_tokens = 32 ,
63
- output_mode = 'int' ,
64
- output_sequence_length = 32 ,
65
- )
66
- self .text_vectorization .adapt (
67
- ['nontoxic' , 'toxic comment' , 'test comment' , 'abc' , 'abcdef' , 'random' ]
68
- )
69
- self .dense1 = keras .layers .Dense (32 , activation = 'relu' )
70
- self .dense2 = keras .layers .Dense (1 )
71
-
72
- def call (self , inputs , training = True , mask = None ):
73
- parsed_example = self .parser (inputs )
74
- text_vector = self .text_vectorization (parsed_example )
75
- output1 = self .dense1 (tf .cast (text_vector , tf .float32 ))
76
- output2 = self .dense2 (output1 )
77
- return output2
64
+ def call (self , inputs ):
65
+ return tf .reshape (inputs , (1 , 32 ))
66
+
67
+
68
+ class Caster (keras .layers .Layer ):
69
+ """A Keras layer that reshapes the input."""
70
+
71
+ def call (self , inputs ):
72
+ return tf .cast (inputs , tf .float32 )
73
+
74
+
75
+ def get_example_model (input_feature_key : str ):
76
+ """Returns a Keras model for testing."""
77
+ parser = ExampleParser (input_feature_key )
78
+ text_vectorization = keras .layers .TextVectorization (
79
+ max_tokens = 32 ,
80
+ output_mode = 'int' ,
81
+ output_sequence_length = 32 ,
82
+ )
83
+ text_vectorization .adapt (
84
+ ['nontoxic' , 'toxic comment' , 'test comment' , 'abc' , 'abcdef' , 'random' ]
85
+ )
86
+ dense1 = keras .layers .Dense (32 , activation = 'relu' )
87
+ dense2 = keras .layers .Dense (1 )
88
+
89
+ inputs = tf .keras .Input (shape = (), dtype = tf .string )
90
+ parsed_example = parser (inputs )
91
+ text_vector = text_vectorization (parsed_example )
92
+ text_vector = Reshaper ()(text_vector )
93
+ text_vector = Caster ()(text_vector )
94
+ output1 = dense1 (text_vector )
95
+ output2 = dense2 (output1 )
96
+ return tf .keras .Model (inputs = inputs , outputs = output2 )
78
97
79
98
80
99
def evaluate_model (
0 commit comments