Skip to content

Commit 3525d1a

Browse files
vkarampudiResponsible ML Infra Team
authored andcommitted
Upgrade tensorflow-model-analysis and tensorflow to >=2.16.2.
PiperOrigin-RevId: 705990022
1 parent 8789c99 commit 3525d1a

File tree

4 files changed

+54
-37
lines changed

4 files changed

+54
-37
lines changed

fairness_indicators/example_model.py

Lines changed: 35 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
12
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
23
#
34
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -19,6 +20,7 @@
1920
results can be visualized using tools like TensorBoard.
2021
"""
2122

23+
from typing import Any
2224
from fairness_indicators import fairness_indicators_metrics # pylint: disable=unused-import
2325
from tensorflow import keras
2426
import tensorflow.compat.v1 as tf
@@ -40,41 +42,50 @@ class ExampleParser(keras.layers.Layer):
4042

4143
def __init__(self, input_feature_key):
4244
self._input_feature_key = input_feature_key
45+
self.input_spec = keras.layers.InputSpec(shape=(1,), dtype=tf.string)
4346
super().__init__()
4447

48+
def compute_output_shape(self, input_shape: Any):
49+
return [1, 1]
50+
4551
def call(self, serialized_examples):
4652
def get_feature(serialized_example):
4753
parsed_example = tf.io.parse_single_example(
4854
serialized_example, features=FEATURE_MAP
4955
)
5056
return parsed_example[self._input_feature_key]
51-
57+
serialized_examples = tf.cast(serialized_examples, tf.string)
5258
return tf.map_fn(get_feature, serialized_examples)
5359

5460

55-
class ExampleModel(keras.Model):
56-
"""A Example Keras NLP model."""
61+
class Reshaper(keras.layers.Layer):
62+
"""A Keras layer that reshapes the input."""
5763

58-
def __init__(self, input_feature_key):
59-
super().__init__()
60-
self.parser = ExampleParser(input_feature_key)
61-
self.text_vectorization = keras.layers.TextVectorization(
62-
max_tokens=32,
63-
output_mode='int',
64-
output_sequence_length=32,
65-
)
66-
self.text_vectorization.adapt(
67-
['nontoxic', 'toxic comment', 'test comment', 'abc', 'abcdef', 'random']
68-
)
69-
self.dense1 = keras.layers.Dense(32, activation='relu')
70-
self.dense2 = keras.layers.Dense(1)
71-
72-
def call(self, inputs, training=True, mask=None):
73-
parsed_example = self.parser(inputs)
74-
text_vector = self.text_vectorization(parsed_example)
75-
output1 = self.dense1(tf.cast(text_vector, tf.float32))
76-
output2 = self.dense2(output1)
77-
return output2
64+
def call(self, inputs):
65+
return tf.reshape(inputs, (1, 32))
66+
67+
68+
def get_example_model(input_feature_key: str):
69+
"""Returns a Keras model for testing."""
70+
parser = ExampleParser(input_feature_key)
71+
text_vectorization = keras.layers.TextVectorization(
72+
max_tokens=32,
73+
output_mode='int',
74+
output_sequence_length=32,
75+
)
76+
text_vectorization.adapt(
77+
['nontoxic', 'toxic comment', 'test comment', 'abc', 'abcdef', 'random']
78+
)
79+
dense1 = keras.layers.Dense(32, activation='relu')
80+
dense2 = keras.layers.Dense(1)
81+
inputs = tf.keras.Input(shape=(), dtype=tf.string)
82+
parsed_example = parser(inputs)
83+
text_vector = text_vectorization(parsed_example)
84+
text_vector = Reshaper()(text_vector)
85+
text_vector = tf.cast(text_vector, tf.float32)
86+
output1 = dense1(text_vector)
87+
output2 = dense2(output1)
88+
return tf.keras.Model(inputs=inputs, outputs=output2)
7889

7990

8091
def evaluate_model(
@@ -83,6 +94,7 @@ def evaluate_model(
8394
tfma_eval_result_path,
8495
eval_config,
8596
):
97+
8698
"""Evaluate Model using Tensorflow Model Analysis.
8799
88100
Args:

fairness_indicators/example_model_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,14 +81,15 @@ def _write_tf_records(self, examples):
8181

8282
def test_example_model(self):
8383
data = self._create_data()
84-
classifier = example_model.ExampleModel(example_model.TEXT_FEATURE)
84+
classifier = example_model.get_example_model(example_model.TEXT_FEATURE)
8585
classifier.compile(optimizer=keras.optimizers.Adam(), loss='mse')
8686
classifier.fit(
8787
tf.constant([e.SerializeToString() for e in data]),
8888
np.array([
8989
e.features.feature[example_model.LABEL].float_list.value[:][0]
9090
for e in data
9191
]),
92+
batch_size=1,
9293
)
9394
classifier.save(self._model_dir, save_format='tf')
9495

setup.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,16 +38,20 @@ def select_constraint(default, nightly=None, git_master=None):
3838
return default
3939

4040
REQUIRED_PACKAGES = [
41-
'tensorflow>=2.15,<2.16',
41+
'tensorflow>=2.16.2,<2.17.0',
4242
'tensorflow-hub>=0.16.1,<1.0.0',
43-
'tensorflow-data-validation' + select_constraint(
44-
default='>=1.15.1,<2.0.0',
45-
nightly='>=1.16.0.dev',
46-
git_master='@git+https://github.com/tensorflow/data-validation@master'),
47-
'tensorflow-model-analysis' + select_constraint(
48-
default='>=0.46,<0.47',
49-
nightly='>=0.47.0.dev',
50-
git_master='@git+https://github.com/tensorflow/model-analysis@master'),
43+
'tensorflow-data-validation'
44+
+ select_constraint(
45+
default='>=1.16.1,<2.0.0',
46+
nightly='>=1.17.0.dev',
47+
git_master='@git+https://github.com/tensorflow/data-validation@master',
48+
),
49+
'tensorflow-model-analysis'
50+
+ select_constraint(
51+
default='>=0.47,<0.48',
52+
nightly='>=0.48.0.dev',
53+
git_master='@git+https://github.com/tensorflow/model-analysis@master',
54+
),
5155
'witwidget>=1.4.4,<2',
5256
'protobuf>=3.20.3,<5',
5357
]

tensorboard_plugin/setup.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,12 @@ def select_constraint(default, nightly=None, git_master=None):
4343

4444
REQUIRED_PACKAGES = [
4545
'protobuf>=3.20.3,<5',
46-
'tensorboard>=2.15.2,<2.16.0',
47-
'tensorflow>=2.15,<2.16',
46+
'tensorboard>=2.16.2,<2.17.0',
47+
'tensorflow>=2.16.2,<2.17.0',
4848
'tensorflow-model-analysis'
4949
+ select_constraint(
50-
default='>=0.46,<0.47',
51-
nightly='>=0.47.0.dev',
50+
default='>=0.47,<0.48',
51+
nightly='>=0.48.0.dev',
5252
git_master='@git+https://github.com/tensorflow/model-analysis@master',
5353
),
5454
'werkzeug<2',

0 commit comments

Comments
 (0)