diff --git a/beginner_source/basics/data_tutorial.py b/beginner_source/basics/data_tutorial.py index 0ef1fb6b777..d12f275c32a 100644 --- a/beginner_source/basics/data_tutorial.py +++ b/beginner_source/basics/data_tutorial.py @@ -160,7 +160,7 @@ def __getitem__(self, idx): def __init__(self, annotations_file, img_dir, transform=None, target_transform=None): - self.img_labels = pd.read_csv(annotations_file) + self.img_labels = pd.read_csv(annotations_file, names=['file_name', 'label']) self.img_dir = img_dir self.transform = transform self.target_transform = target_transform