[go: nahoru, domu]

Skip to content

Commit

Permalink
datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
samyuh committed Jun 1, 2022
1 parent 554698e commit 88a005e
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 5 deletions.
8 changes: 4 additions & 4 deletions Project 2/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from sklearn.metrics import accuracy_score
from torchvision import transforms, models
import matplotlib.pyplot as plt
from dataset import TrafficSignsDataset
from dataset import ImageClassificationDataset

# Run on the GPU when PyTorch reports one is available; fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")
Expand Down Expand Up @@ -207,9 +207,9 @@ def test(self, best_model, test_dl, loss_fn):
# Split the combined train/validation image list: the first `train_ratio`
# names go to training, the last `validation_ratio` names to validation.
train_images = list(val_train_images[:train_ratio])
# Slicing with `[-validation_ratio:]` would return the WHOLE list when
# validation_ratio == 0 (because -0 == 0), so compute the start index
# explicitly. For any other value this selects exactly the same tail.
validation_start = max(0, len(val_train_images) - validation_ratio)
validation_images = list(val_train_images[validation_start:])

train_data = ImageClassificationDataset(train_images, train_transform)
validation_data = ImageClassificationDataset(validation_images, validation_transform)
test_data = ImageClassificationDataset(test_images, test_transform)

# Only the training loader shuffles and drops the incomplete last batch;
# the validation loader keeps every sample in a fixed order.
train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=16, shuffle=True, drop_last=True)
validation_dataloader = torch.utils.data.DataLoader(validation_data, batch_size=16, shuffle=False, drop_last=False)
Expand Down
55 changes: 54 additions & 1 deletion Project 2/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,60 @@
ANNOTATIONS_DIR = './annotations/'
CLASSES = ['trafficlight', 'stop', 'speedlimit', 'crosswalk']

class ImageClassificationDataset(Dataset):
    """Single-label image classification dataset.

    Each entry of ``images`` is an image base name (without extension). For a
    name ``n`` the image is read from ``f'{IMAGES_DIR}{n}.png'`` and its
    annotation from ``ANNOTATIONS_DIR + f'{n}.xml'``. The text of the first
    ``<name>`` element in that XML file is mapped to its index in ``CLASSES``.

    Args:
        images: sequence of image base names.
        transform: optional callable applied to the RGB image array. Its
            result must support ``.float()`` (e.g. a ``torch.Tensor``),
            because ``__getitem__`` calls it on the transformed image.
    """

    def __init__(self, images, transform=None):
        self.images = pd.DataFrame(images, columns=['image_name'])
        self.transform = transform

    def __len__(self):
        return len(self.images)

    @staticmethod
    def _load_rgb_image(name):
        """Read ``f'{IMAGES_DIR}{name}.png'`` as an RGB array, or return None on failure."""
        # cv2.imread reports a missing/unreadable file by returning None
        # rather than raising, so check for it explicitly.
        image = cv2.imread(f'{IMAGES_DIR}{name}.png')
        if image is None:
            return None
        try:
            # OpenCV decodes colour images in BGR channel order.
            return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        except cv2.error:
            return None

    @staticmethod
    def _class_index(class_name):
        """Return the position of ``class_name`` in CLASSES; raise KeyError if unknown."""
        return {cls_name: index for index, cls_name in enumerate(CLASSES)}[class_name]

    def __getitem__(self, idx):
        """Return the sample at row ``idx``.

        Returns:
            A dict with keys ``'name'`` (image base name), ``'image'`` (the
            transformed image, cast with ``.float()``) and ``'labels'`` (the
            class index as a float tensor), or ``None`` when the image file
            could not be read or converted.

        Raises:
            KeyError: if the first annotated class name is not in CLASSES.
        """
        name = self.images.iloc[idx, 0]

        image = self._load_rgb_image(name)
        if image is None:
            # Best-effort: report and skip this sample. PyTorch's default
            # collate cannot batch None, so callers must filter such samples.
            print(f'Error reading image {name}.png')
            return None
        if self.transform:
            image = self.transform(image)

        tree = ET.parse(ANNOTATIONS_DIR + f'{name}.xml')
        correct_labels = [node.text for node in tree.getroot().iter('name')]

        # TODO: Detect multiple classes, not just one -- currently only the
        # first annotated object decides the label.
        # NOTE(review): an annotation with no <name> element yields an empty
        # label tensor of shape (0,), which will not stack with the 0-d labels
        # of other samples when batched.
        labels = self._class_index(correct_labels[0]) if correct_labels else []
        labels = torch.from_numpy(np.asarray(labels).astype('long'))

        return {
            'name': name,
            'image': image.float(),
            'labels': labels.float(),
        }


# Backward-compatible alias: this class was previously named TrafficSignsDataset.
TrafficSignsDataset = ImageClassificationDataset

class ImageMultiLabelDataset(Dataset):
def __init__(self, images, transform=None):
self.images = pd.DataFrame(images, columns=['image_name'])
self.transform = transform
Expand Down

0 comments on commit 88a005e

Please sign in to comment.