diff --git a/beginner_source/basics/data_tutorial.py b/beginner_source/basics/data_tutorial.py index 2baef464..c82cd2b0 100755 --- a/beginner_source/basics/data_tutorial.py +++ b/beginner_source/basics/data_tutorial.py @@ -125,7 +125,7 @@ class CustomImageDataset(Dataset): def __init__(self, annotations_file, img_dir, transform=None, target_transform=None): - self.img_labels = pd.read_csv(annotations_file, names=['file_name', 'label']) + self.img_labels = pd.read_csv(annotations_file) self.img_dir = img_dir self.transform = transform self.target_transform = target_transform