Source code for easycv.datasets.classification.data_sources.mnist

# Copyright (c) Alibaba, Inc. and its affiliates.
from PIL import Image
from torchvision.datasets import MNIST, FashionMNIST

from easycv.datasets.registry import DATASOURCES


[docs]@DATASOURCES.register_module class ClsSourceMnist(object):
[docs] def __init__(self, root, split, download=True): assert split in ['train', 'test'] self.mnist = MNIST( root=root, train=(split == 'train'), download=download) self.labels = self.mnist.targets # data label_classes self.CLASSES = self.mnist.classes
def __len__(self): return len(self.mnist) def __getitem__(self, idx): # img: HWC, RGB img = Image.fromarray(self.mnist.data[idx].numpy()).convert('RGB') label = self.labels[idx].item() result_dict = {'img': img, 'gt_labels': label} return result_dict
[docs]@DATASOURCES.register_module class ClsSourceFashionMnist(object):
[docs] def __init__(self, root, split, download=True): assert split in ['train', 'test'] self.fashion_mnist = FashionMNIST( root=root, train=(split == 'train'), download=download) self.labels = self.fashion_mnist.targets # data label_classes self.CLASSES = self.fashion_mnist.classes
def __len__(self): return len(self.fashion_mnist) def __getitem__(self, idx): # img: HWC, RGB img = Image.fromarray( self.fashion_mnist.data[idx].numpy()).convert('RGB') label = self.labels[idx].item() result_dict = {'img': img, 'gt_labels': label} return result_dict