Source code for easycv.datasets.classification.data_sources.class_list

# Copyright (c) Alibaba, Inc. and its affiliates.
import logging
import os
import random

from PIL import Image, ImageFile

from easycv.datasets.registry import DATASOURCES
from easycv.file import io
from easycv.file.image import load_image
from easycv.utils.dist_utils import dist_zero_exec
from .utils import split_listfile_byrank


@DATASOURCES.register_module
class ClsSourceImageListByClass(object):
    """Data source that yields `m_per_class` images of one class per item.

    Each index of this source corresponds to one distinct label found in
    `list_file` (in order of first appearance), not to one image.

    Args:
        root: str, directory joined in front of every image path read from
            `list_file` (via `os.path.join`).
        list_file: str, path of a text file whose non-empty lines are
            records of the form `image_path<delimeter>label`, where `label`
            is parsed with `int`. It is opened with `easycv.file.io.open`.
            NOTE(review): earlier documentation also advertised
            `list(str)` for `root` / `list_file`; this implementation
            passes both straight to `os.path.join` / `io.open`, so only a
            single path is handled here.
        m_per_class: int, number of images sampled for each item.
        delimeter: str, field separator of each line in `list_file`.
        split_huge_listfile_byrank: bool, if true the list file is first
            passed through `split_listfile_byrank` (inside
            `dist_zero_exec`) and the returned file is read instead, for
            the case where a huge list cannot be fully loaded in memory.
        cache_path: str, forwarded as `save_path` to
            `split_listfile_byrank` when `split_huge_listfile_byrank` is
            true.
        max_try: int, maximum number of classes tried (starting at the
            requested index and wrapping around) before `__getitem__`
            gives up.
    """

    def __init__(self,
                 root,
                 list_file,
                 m_per_class=2,
                 delimeter=' ',
                 split_huge_listfile_byrank=False,
                 cache_path='data/',
                 max_try=20):
        # Let PIL decode images whose file is truncated instead of raising.
        ImageFile.LOAD_TRUNCATED_IMAGES = True

        # TODO: support return list, donot save split file
        # TODO: support loading list_file that have already been split
        if split_huge_listfile_byrank:
            with dist_zero_exec():
                list_file = split_listfile_byrank(
                    list_file=list_file,
                    label_balance=True,
                    save_path=cache_path)

        with io.open(list_file, 'r') as f:
            lines = f.readlines()

        # Split every record once; blank lines (e.g. a trailing newline at
        # the end of the file) carry no record and are skipped.
        records = [
            line.strip().split(delimeter) for line in lines if line.strip()
        ]

        self.m_per_class = m_per_class
        # As before, the format is judged from the first record only.
        self.has_labels = bool(records) and len(records[0]) >= 2
        assert self.has_labels is True, \
            'list_file must contain records of `image_path label`'

        # Group image paths by label; dict insertion order keeps labels in
        # order of first appearance.
        label2files = {}
        for fields in records:
            label2files.setdefault(int(fields[1]), []).append(fields[0])

        self.labels = list(label2files)
        self.fns_by_labels = list(label2files.values())

        self.root = root
        self.initialized = False
        self.max_try = max_try

    def __len__(self):
        """Return the number of distinct labels."""
        return len(self.fns_by_labels)

    def __getitem__(self, idx):
        """Return `m_per_class` images of the class at position `idx`.

        If the class has no image or one of the sampled images cannot be
        loaded, the following classes are tried instead (wrapping around
        at the end of the class list), at most `max_try` classes in total.

        Returns:
            dict with `img` (list of PIL images) and `gt_labels` (list
            holding the class label once per image).

        Raises:
            IndexError: if `idx` is outside the valid index range.
            RuntimeError: if no class could be loaded within `max_try`
                attempts.
        """
        num_classes = len(self.labels)
        # Keep sequence semantics: an out-of-range index must raise
        # IndexError (and must not be wrapped), otherwise iterating the
        # source through __getitem__ would never terminate.
        if not -num_classes <= idx < num_classes:
            raise IndexError('class index %s out of range' % idx)

        start = idx % num_classes
        for offset in range(max(1, self.max_try)):
            result = self._sample_class((start + offset) % num_classes)
            if result is not None:
                return result

        raise RuntimeError(
            'failed to load samples after %s tries, starting from class '
            'index %s' % (self.max_try, idx))

    def _sample_class(self, idx):
        """Sample and load images of class `idx`; return None on failure."""
        label = self.labels[idx]
        image_list = self.fns_by_labels[idx]

        if len(image_list) < 1:
            logging.info('%s :image list contain < 1 image', idx)
            return None

        # random.sample draws without replacement, so repeat the list
        # until it holds at least `m_per_class` entries.
        if self.m_per_class > len(image_list):
            repeats = self.m_per_class // len(image_list) + 1
            image_list = repeats * image_list

        sample_list = random.sample(image_list, self.m_per_class)

        return_img = []
        for path in sample_list:
            img = load_image(os.path.join(self.root, path), mode='RGB')
            if img is None:
                # One unreadable image invalidates the whole item.
                return None
            return_img.append(Image.fromarray(img))

        return_label = [label] * len(return_img)
        return {'img': return_img, 'gt_labels': return_label}