Source code for easycv.datasets.classification.data_sources.flower

# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from pathlib import Path

from PIL import Image
from scipy.io import loadmat
from torchvision.datasets.utils import (check_integrity,
                                        download_and_extract_archive,
                                        download_url)

from easycv.datasets.registry import DATASOURCES


[docs]@DATASOURCES.register_module class ClsSourceFlowers102(object): _download_url_prefix = 'https://www.robots.ox.ac.uk/~vgg/data/flowers/102/' _file_dict = dict( # filename, md5 image=('102flowers.tgz', '52808999861908f626f3c1f4e79d11fa'), label=('imagelabels.mat', 'e0620be6f572b9609742df49c70aed4d'), setid=('setid.mat', 'a5357ecc9cb78c4bef273ce3793fc85c')) _splits_map = {'train': 'trnid', 'val': 'valid', 'test': 'tstid'}
[docs] def __init__(self, root, split, download=False) -> None: assert split in ['train', 'test', 'val'] self._base_folder = Path(root) / 'flowers-102' self._images_folder = self._base_folder / 'jpg' if download: self.download() # verify that the path exists if not self._check_integrity(): raise FileNotFoundError( f'The data in the {self._base_folder} file directory is incomplete' ) # Data reading in progress set_ids = loadmat( self._base_folder / self._file_dict['setid'][0], squeeze_me=True) image_ids = set_ids[self._splits_map[split]].tolist() labels = loadmat( self._base_folder / self._file_dict['label'][0], squeeze_me=True) image_id_to_label = dict(enumerate((labels['labels'] - 1).tolist(), 1)) self._labels = [] self._image_files = [] for image_id in image_ids: self._labels.append(image_id_to_label[image_id]) self._image_files.append(self._images_folder / f'image_{image_id:05d}.jpg')
def __len__(self): return len(self._labels) def __getitem__(self, idx): image_file, label = self._image_files[idx], self._labels[idx] img = Image.open(image_file).convert('RGB') result_dict = {'img': img, 'gt_labels': label} return result_dict # verify that the path exists def _check_integrity(self): if not (self._images_folder.exists() and self._images_folder.is_dir()): return False for id in ['label', 'setid']: filename, md5 = self._file_dict[id] if not check_integrity(str(self._base_folder / filename), md5): return False return True
[docs] def download(self): os.makedirs(self._base_folder, exist_ok=True) if self._check_integrity(): return # Download and extract download_and_extract_archive( f"{self._download_url_prefix}{self._file_dict['image'][0]}", str(self._base_folder), md5=self._file_dict['image'][1], remove_finished=True) for id in ['label', 'setid']: filename, md5 = self._file_dict[id] download_url( self._download_url_prefix + filename, str(self._base_folder), md5=md5)