Source code for easycv.datasets.classification.data_sources.imagenet

# Copyright (c) Alibaba, Inc. and its affiliates.

import os

from PIL import Image

from easycv.datasets.registry import DATASOURCES
from easycv.file.image import load_image


def generate_label_map(label_path):
    label_txt = open(label_path).readlines()
    assert len(label_txt) > 1, f'{label_path} is None'
    label_map = dict()
    for line in label_txt:
        deal_line = line.strip().split()
        label_map[deal_line[0]] = (deal_line[1], deal_line[2])
    return label_map


def get_images_list(txt_path):
    assert os.path.exists(txt_path), f'{txt_path} is not exists'
    train_txt = open(txt_path).readlines()
    assert len(txt_path) > 1, f'{txt_path} is None'
    return train_txt


[docs]@DATASOURCES.register_module class ClsSourceImageNet1k(object):
[docs] def __init__(self, root, split): """ Args: root: The root directory of the data example:if data/imagenet └── train └── n01440764 └── n01443537 └── ... └── val └── n01440764 └── n01443537 └── ... └── meta ├── train.txt ├── val.txt ├── ... has input root = data/imagenet split : train or val """ self.root = root assert split in ['train', 'test'] if split == 'train': self.split = 'meta/train.txt' else: self.split = 'meta/val.txt' self.txt_path = get_images_list(os.path.join(root, self.split)) self.label_json = generate_label_map( os.path.join(root, 'meta/label_name.txt'))
[docs] def read_data(self, image_path): img_path = os.path.join(self.root, image_path.strip()) assert os.path.exists(img_path), f'{img_path} is not exists' img = load_image(img_path, mode='RGB') img = Image.fromarray(img) label_key = image_path.split('/')[1] assert label_key in self.label_json, f'{label_key} label is not exists' label = self.label_json[label_key] return {'img': img, 'gt_labels': int(label[0].strip())}
def __len__(self): return len(self.txt_path) def __getitem__(self, idx): return self.read_data(self.txt_path[idx])