-
Notifications
You must be signed in to change notification settings - Fork 0
/
tf_dataset.py
68 lines (50 loc) · 1.78 KB
/
tf_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import os

# TF_CPP_MIN_LOG_LEVEL has to be in the environment *before* TensorFlow is
# imported; setting it afterwards is too late to silence the log output the
# native runtime emits while it is being loaded.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import tensorflow as tf  # noqa: E402  (deliberately imported after the env var is set)
def rescaling(img, label):
    """Return ``img`` cast to float32 and divided by 255, with ``label`` unchanged.

    Written as an ``(img, label) -> (img, label)`` function so it can be
    passed directly to ``Dataset.map``.
    """
    as_float = tf.cast(img, tf.float32)
    scaled = tf.divide(as_float, 255.)
    return scaled, label
"""
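Example of tf.tile:

x = [[2, 2],
     [3, 3]]
tf.tile(x, multiples=(2, 2))

x is repeated twice along each axis, giving a 2 x 2 grid of copies of x:
|x   x|
|     |
|     |
|x   x|

i.e. the result is
[[2, 2, 2, 2],
 [3, 3, 3, 3],
 [2, 2, 2, 2],
 [3, 3, 3, 3]]
---------------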
"""
def augmentation(img, label):
    """Apply random photometric and geometric augmentation to ``img``.

    Steps, in order:
      1. random brightness shift (``max_delta=0.1``),
      2. with probability 0.2, conversion to grayscale replicated back to
         three channels,
      3. random horizontal flip,
      4. random contrast with a factor drawn from [0.2, 0.4].

    Args:
        img: RGB image tensor whose last dimension is the 3 colour channels;
            either a single image ``[H, W, 3]`` or a batch ``[B, H, W, 3]``.
        label: passed through untouched.

    Returns:
        ``(augmented_img, label)``; ``augmented_img`` has the same shape as
        ``img``.
    """
    img = tf.image.random_brightness(img, max_delta=0.1)

    # RGB image ----> [..., 3]; grayscale is [..., 1]; repeating that channel
    # 3 times gives [..., 3] again, so both branches produce the same shape.
    if tf.random.uniform(shape=(), minval=0, maxval=1) <= 0.2:
        grayscale = tf.image.rgb_to_grayscale(img)
        # grayscale_to_rgb tiles the last axis 3x for any rank, unlike a
        # hard-coded tf.tile(..., [1, 1, 1, 3]) which only accepts batches.
        img = tf.image.grayscale_to_rgb(grayscale)

    img = tf.image.random_flip_left_right(img)
    img = tf.image.random_contrast(img, lower=0.2, upper=0.4)
    return img, label
if __name__ == "__main__":
    # Demo entry point: build training/validation input pipelines from an
    # image directory and print the shapes of one training batch.
    root_dir = r'C:\Users\Yash PC\PycharmProjects\tftuts\data\train'

    # Both subsets must use the same seed and validation_split so that they
    # are complementary partitions of the same file list.
    split_kwargs = dict(seed=1234, validation_split=0.2)
    train = tf.keras.utils.image_dataset_from_directory(
        root_dir,
        subset='training',
        **split_kwargs,
    )
    val = tf.keras.utils.image_dataset_from_directory(
        root_dir,
        subset='validation',
        **split_kwargs,
    )

    AUTOTUNE = tf.data.AUTOTUNE
    BUFFER_SIZE = 1000

    # Cache only the deterministic part of the pipeline (decode + rescale).
    # Caching after the random augmentation would freeze the first epoch's
    # augmented images and replay them unchanged in every later epoch.
    train = (
        train.map(rescaling, num_parallel_calls=AUTOTUNE)
        .cache()
        .shuffle(BUFFER_SIZE)
        .map(augmentation, num_parallel_calls=AUTOTUNE)
        .prefetch(AUTOTUNE)
    )

    # Validation data is neither augmented nor shuffled so that evaluation
    # metrics are computed on the unmodified images and are reproducible.
    val = (
        val.map(rescaling, num_parallel_calls=AUTOTUNE)
        .cache()
        .prefetch(AUTOTUNE)
    )

    for img, label in train.take(1):
        print(img.shape, label.shape)