# Copyright (c) Alibaba, Inc. and its affiliates.
import logging
import os
import random
from PIL import Image, ImageFile
from easycv.datasets.registry import DATASOURCES
from easycv.file import io
from easycv.file.image import load_image
from easycv.utils.dist_utils import dist_zero_exec
from .utils import split_listfile_byrank
@DATASOURCES.register_module
class ClsSourceImageListByClass(object):
    """Class-grouped image-list data source.

    Records of an image list are grouped by label; item ``idx`` of this
    source is ``m_per_class`` images randomly sampled from the ``idx``-th
    label group (labels are ordered by first appearance in the list file).

    Args:
        root: str, root path joined (``os.path.join``) in front of every
            ``image_path`` read from ``list_file``.
        list_file: str, path of an image list file whose records look like
            ``image_path label``. NOTE(review): earlier docs also advertised
            ``list(str)`` for ``list_file`` and ``root`` (one root per list
            file), but this implementation opens a single file and joins with
            a single root, so list inputs do not appear to be handled here —
            confirm before passing lists.
        m_per_class: int, number of samples returned for each class.
        delimeter: str, delimiter between the fields of each line in
            ``list_file``.
        split_huge_listfile_byrank: bool, for lists too large to load fully
            into memory. If True, the data list is split across ranks
            (label-balanced) and each rank loads only its own part.
        cache_path: str, directory where the split list files are saved when
            ``split_huge_listfile_byrank`` is True.
        max_try: int, maximum number of classes tried by ``__getitem__``
            (the requested one, then the following ones) when an image of
            the sampled class cannot be read.

    Raises:
        ValueError: if the list file has no records, or a record does not
            contain at least ``image_path`` and ``label``.
    """

    def __init__(self,
                 root,
                 list_file,
                 m_per_class=2,
                 delimeter=' ',
                 split_huge_listfile_byrank=False,
                 cache_path='data/',
                 max_try=20):
        # Allow PIL to decode truncated image files instead of raising.
        # Note: this is a process-wide PIL setting, not per-instance.
        ImageFile.LOAD_TRUNCATED_IMAGES = True

        # TODO: support return list, do not save split file
        # TODO: support loading list_file that have already been split
        if split_huge_listfile_byrank:
            # presumably serializes the split so rank 0 writes the cache
            # files before other ranks read them — confirm in dist_utils.
            with dist_zero_exec():
                list_file = split_listfile_byrank(
                    list_file=list_file,
                    label_balance=True,
                    save_path=cache_path)

        with io.open(list_file, 'r') as f:
            lines = f.readlines()

        self.m_per_class = m_per_class

        # Parse each non-blank line once; blank lines (e.g. a trailing empty
        # line) would otherwise break the `label` field lookup below.
        records = [
            line.strip().split(delimeter) for line in lines if line.strip()
        ]

        self.has_labels = bool(records) and len(records[0]) >= 2
        if not self.has_labels:
            # Raise instead of `assert` so the check survives `python -O`.
            raise ValueError(
                'list_file %s must contain `image_path label` records' %
                list_file)

        # Group image paths by label, preserving first-appearance order.
        label2files = {}
        for record in records:
            if len(record) < 2:
                raise ValueError(
                    'Invalid record %r in %s, expected `image_path label`' %
                    (delimeter.join(record), list_file))
            path, label = record[0], int(record[1])
            label2files.setdefault(label, []).append(path)

        self.labels = list(label2files.keys())
        self.fns_by_labels = [label2files[i] for i in self.labels]
        self.root = root
        self.initialized = False
        self.max_try = max_try

    def __len__(self):
        """Return the number of classes (one item per label)."""
        return len(self.fns_by_labels)

    def __getitem__(self, idx):
        """Sample ``m_per_class`` images of the ``idx``-th class.

        If the class has no images or one of its sampled images cannot be
        loaded, the next class (wrapping around to the first one) is tried,
        for at most ``max_try`` classes in total.

        Args:
            idx: int, class index. An out-of-range index on the first lookup
                raises ``IndexError``, as a Python sequence would.

        Returns:
            dict with ``'img'``: list of ``m_per_class`` PIL images, and
            ``'gt_labels'``: list of the same length repeating the label.

        Raises:
            IndexError: if ``idx`` is out of range.
            RuntimeError: if no class yields a full sample within
                ``max_try`` tries.
        """
        label_idx = idx
        # Always make at least one attempt, even if max_try <= 0.
        for _ in range(max(self.max_try, 1)):
            result = self._sample_class(label_idx)
            if result is not None:
                return result
            # Fall back to the next class; wrap around instead of running
            # off the end of `self.labels`.
            label_idx = (label_idx + 1) % len(self.labels)
        raise RuntimeError(
            'Failed to sample images for class index %s after %s tries' %
            (idx, self.max_try))

    def _sample_class(self, idx):
        """Build the sample dict for class ``idx``; return None on failure."""
        label = self.labels[idx]
        image_list = self.fns_by_labels[idx]
        if len(image_list) < 1:
            logging.info('%s :image list contain < 1 image', idx)
            return None
        if self.m_per_class > len(image_list):
            # Too few distinct images: repeat the list so random.sample can
            # draw m_per_class items (duplicates become possible).
            image_list = (self.m_per_class // len(image_list) + 1) * image_list
        sample_list = random.sample(image_list, self.m_per_class)

        images = []
        for path in sample_list:
            img = load_image(os.path.join(self.root, path), mode='RGB')
            if img is None:
                # The original code treats a None return as a failed read.
                logging.warning('Failed to load image %s, trying next class',
                                path)
                return None
            images.append(Image.fromarray(img))
        return {'img': images, 'gt_labels': [label] * len(images)}