diff --git a/examples/vision/oxford_pets_image_segmentation.py b/examples/vision/oxford_pets_image_segmentation.py index 79fc81c312..9d99ea7328 100644 --- a/examples/vision/oxford_pets_image_segmentation.py +++ b/examples/vision/oxford_pets_image_segmentation.py @@ -168,7 +168,7 @@ def unpackage_inputs(inputs): """ augmented_train_ds = ( - train_ds.cache() + train_ds .shuffle(BATCH_SIZE * 2) .map(augment_fn, num_parallel_calls=AUTOTUNE) .batch(BATCH_SIZE) @@ -176,7 +176,7 @@ def unpackage_inputs(inputs): .prefetch(buffer_size=tf.data.AUTOTUNE) ) resized_val_ds = ( - val_ds.cache() + val_ds .map(resize_fn, num_parallel_calls=AUTOTUNE) .batch(BATCH_SIZE) .map(unpackage_inputs)