diff --git a/examples/vision/oxford_pets_image_segmentation.py b/examples/vision/oxford_pets_image_segmentation.py index 9d99ea7328..617d1a4f36 100644 --- a/examples/vision/oxford_pets_image_segmentation.py +++ b/examples/vision/oxford_pets_image_segmentation.py @@ -90,13 +90,13 @@ format. """ -key_rename_fn = lambda inputs: { +rescale_images_and_correct_masks = lambda inputs: { "images": tf.cast(inputs["image"], dtype=tf.float32) / 255.0, "segmentation_masks": inputs["segmentation_mask"] - 1, } -train_ds = orig_train_ds.map(key_rename_fn, num_parallel_calls=AUTOTUNE) -val_ds = orig_val_ds.map(key_rename_fn, num_parallel_calls=AUTOTUNE) +train_ds = orig_train_ds.map(rescale_images_and_correct_masks, num_parallel_calls=AUTOTUNE) +val_ds = orig_val_ds.map(rescale_images_and_correct_masks, num_parallel_calls=AUTOTUNE) """ ## Utility Function