Description
When I run the following code:
```python
for input_example_batch, target_example_batch in dataset.take(1):
    example_batch_predictions = model(input_example_batch)
    print(example_batch_predictions.shape, "# (batch_size, sequence_length, vocab_size)")
```
The following error is encountered:
```
InvalidArgumentError Traceback (most recent call last)
Cell In[42], line 3
1 for input_example_batch, target_example_batch in dataset.take(1):
2 print(input_example_batch.shape)
----> 3 example_batch_predictions = model(input_example_batch)
4 print(
5 example_batch_predictions.shape,
6 "# (batch_size, sequence_length, vocab_size)",
7 )
File ~/Desktop/coursera/venv/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py:122, in filter_traceback.<locals>.error_handler(*args, **kwargs)
119 filtered_tb = _process_traceback_frames(e.__traceback__)
120 # To get the full stack trace, call:
121 # keras.config.disable_traceback_filtering()
--> 122 raise e.with_traceback(filtered_tb) from None
123 finally:
124 del filtered_tb
Cell In[40], line 17, in MyModel.call(self, inputs, states, return_state, training)
13 # since we are training a text generation model,
14 # we use the previous state, in training. If there is no state,
15 # then we initialize the state
16 if states is None:
---> 17 states = self.gru.get_initial_state(x)
18 x, states = self.gru(x, initial_state=states, training=training)
19 x = self.dense(x, training=training)
InvalidArgumentError: Exception encountered when calling MyModel.call().
{{function_node __wrapped__Pack_N_2_device_/job:localhost/replica:0/task:0/device:CPU:0}} Shapes of all inputs must match: values[0].shape = [64,100,256] != values[1].shape = [] [Op:Pack] name:
Arguments received by MyModel.call():
• inputs=tf.Tensor(shape=(64, 100), dtype=int64)
• states=None
• return_state=False
• training=False
```
Kindly suggest a fix.
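A possible cause, assuming this environment runs Keras 3 (the `keras/src/...` paths and the `keras.config.disable_traceback_filtering()` hint in the traceback point that way): in Keras 3, `RNN.get_initial_state()` takes a batch size rather than an input tensor, so `self.gru.get_initial_state(x)` ends up trying to build a zeros tensor of shape `(x, rnn_units)`. Packing the `[64,100,256]` embedding output together with the scalar unit count would produce exactly the `Op:Pack` shape mismatch above. The sketch below is a hedged workaround, not a confirmed fix. It reconstructs `MyModel` from the visible part of `call()`; the `__init__` body and all layer sizes are assumptions. It drops the `get_initial_state(x)` call and lets the GRU create its own zero state by passing `initial_state=None`.

```python
import tensorflow as tf


class MyModel(tf.keras.Model):
    """GRU text-generation model; the __init__ body is a reconstruction (assumption)."""

    def __init__(self, vocab_size, embedding_dim, rnn_units):
        super().__init__()
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.gru = tf.keras.layers.GRU(
            rnn_units, return_sequences=True, return_state=True
        )
        self.dense = tf.keras.layers.Dense(vocab_size)

    def call(self, inputs, states=None, return_state=False, training=False):
        x = self.embedding(inputs)
        # No get_initial_state(x) call: when states is None, the GRU builds a
        # zero initial state sized to the current batch on its own.
        x, states = self.gru(x, initial_state=states, training=training)
        x = self.dense(x)
        if return_state:
            return x, states
        return x
```

If an explicit zero state is preferred, the Keras 3 form passes the batch size, e.g. `self.gru.get_initial_state(tf.shape(x)[0])`. Under Keras 2 (`tf.keras` before TF 2.16) the method took the input tensor instead, which is likely why older versions of this tutorial call it with `x`.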