- 
                Notifications
    
You must be signed in to change notification settings  - Fork 333
 
Description
The images and bounding boxes display properly.
When I attempt to apply any augmentation function, the following stack trace is produced:
I must have something set up incorrectly; however, the `visualize_dataset` function works.
How do I determine where the problem is?
I am following this documentation:
https://keras.io/guides/keras_cv/object_detection_keras_cv/#training-our-model
    Image augmentation layers are expecting inputs to be rank 3 (HWC) or 4D (NHWC) tensors. Got shape: <unknown>
    
    Arguments received by AutoContrast.call():
      • inputs={'images': "<tf.Tensor 'args_2:0' shape=<unknown> dtype=float32>", 'bounding_boxes': {'classes': 'tf.RaggedTensor(values=Tensor("RaggedFromVariant_1/RaggedTensorFromVariant:1", shape=(None,), dtype=float32), row_splits=Tensor("RaggedFromVariant_1/RaggedTensorFromVariant:0", shape=(5,), dtype=int64))', 'boxes': 'tf.RaggedTensor(values=tf.RaggedTensor(values=Tensor("Cast:0", shape=(None,), dtype=float32), row_splits=Tensor("RaggedFromVariant/RaggedTensorFromVariant:1", shape=(None,), dtype=int64)), row_splits=Tensor("RaggedFromVariant/RaggedTensorFromVariant:0", shape=(5,), dtype=int64))'}}
def load_image(image_path):
    """Read an image file and decode it into a rank-3 (H, W, 3) tensor.

    By default ``tf.image.decode_image`` may decode animated GIFs into a 4-D
    (frames, H, W, C) tensor, so the static shape of its result is unknown.
    Inside a ``tf.data`` map that unknown rank propagates downstream, and the
    KerasCV augmentation layers reject it with "expecting inputs to be rank 3
    (HWC) or 4D (NHWC) tensors. Got shape: <unknown>".

    ``expand_animations=False`` makes the op always return a single 3-D frame
    (animated GIFs are truncated to their first frame). This gives the tensor
    a known rank.

    Args:
        image_path: Path (string tensor) of the image file to read.

    Returns:
        The decoded image with 3 channels, in the dtype produced by
        ``decode_image`` (uint8 by default).
    """
    raw = tf.io.read_file(image_path)
    image = tf.image.decode_image(raw, channels=3, expand_animations=False)
    return image
def load_dataset(image_path, classes, bbox):
    """Assemble one object-detection example from a path and its annotations.

    Args:
        image_path: Path of the image file, handed to ``load_image``.
        classes: Per-box class labels; cast to float32.
        bbox: Box coordinates; passed through unchanged.

    Returns:
        A dict ``{"images": <float32 image>, "bounding_boxes": {"classes": ...,
        "boxes": ...}}``.
    """
    decoded = load_image(image_path)
    annotations = {
        "classes": tf.cast(classes, dtype=tf.float32),
        "boxes": bbox,
    }
    example = {
        "images": tf.cast(decoded, tf.float32),
        "bounding_boxes": annotations,
    }
    return example
# Augmentation layers to apply in sequence. AutoContrast is told that pixel
# values span [0, 255].
augmenters = [keras_cv.layers.AutoContrast(value_range=(0, 255))]
def create_augmenter_fn(augmenters):
    """Compose a sequence of callables into one function.

    The returned function threads its argument through each element of
    ``augmenters`` in order. The output of one call becomes the input of the
    next, and the last result is returned. An empty sequence yields the
    identity.
    """

    def augmenter_fn(inputs):
        result = inputs
        for layer in augmenters:
            result = layer(result)
        return result

    return augmenter_fn
augmenter_fn = create_augmenter_fn(augmenters)

# Decode images and package each (image, classes, boxes) record in parallel.
train_ds = train_data.map(load_dataset, num_parallel_calls=tf.data.AUTOTUNE)
train_ds = train_ds.shuffle(BATCH_SIZE * 4)
# ragged_batch combines elements whose component shapes differ (e.g. a
# varying number of boxes per image) into RaggedTensors instead of failing.
train_ds = train_ds.ragged_batch(BATCH_SIZE, drop_remainder=True)
# Augmentation runs on whole batches. Use tf.data.AUTOTUNE, consistent with
# the first map above; `tf_data` is not defined in the visible code.
train_ds = train_ds.map(augmenter_fn, num_parallel_calls=tf.data.AUTOTUNE)