Source code for easycv.datasets.classification.data_sources.image_list

# Copyright (c) Alibaba, Inc. and its affiliates.
import json
import logging
import os

from PIL import Image, ImageFile

from easycv.datasets.registry import DATASOURCES
from easycv.file import io
from easycv.file.image import load_image
from easycv.framework.errors import TypeError, ValueError
from easycv.utils.dist_utils import dist_zero_exec
from .utils import split_listfile_byrank


[docs]@DATASOURCES.register_module class ClsSourceImageList(object): """ data source for classification Args: list_file : str / list(str), str means a input image list file path, this file contains records as `image_path label` in list_file list(str) means multi image list, each one contains some records as `image_path label` root: str / list(str), root path for image_path, each list_file will need a root, if len(root) < len(list_file), we will use root[-1] to fill root list. delimeter: str, delimeter of each line in the `list_file` split_huge_listfile_byrank: Adapt to the situation that the memory cannot fully load a huge amount of data list. If split, data list will be split to each rank. split_label_balance: if `split_huge_listfile_byrank` is true, whether split with label balance cache_path: if `split_huge_listfile_byrank` is true, cache list_file will be saved to cache_path. """
[docs] def __init__(self, list_file, root='', delimeter=' ', split_huge_listfile_byrank=False, split_label_balance=False, cache_path='data/', class_list=None): ImageFile.LOAD_TRUNCATED_IMAGES = True # DistributedMPSampler need this attr self.has_labels = True self.class_list = class_list if self.class_list is None: logging.warning( 'It is recommended to specify the ``class_list`` parameter!') self.label_dict = {} else: self.label_dict = dict( zip(self.class_list, range(len(self.class_list)))) if isinstance(list_file, str): assert isinstance(root, str), 'list_file is str, root must be str' list_file = [list_file] root = [root] else: assert isinstance(list_file, list), \ 'list_file should be str or list(str)' root = [root] if isinstance(root, str) else root if not isinstance(root, list): raise TypeError('root must be str or list(str), but get %s' % type(root)) if len(root) < len(list_file): logging.warning( 'len(root) < len(list_file), fill root with root last!') root = root + [root[-1]] * (len(list_file) - len(root)) # TODO: support return list, donot save split file # TODO: support loading list_file that have already been split if split_huge_listfile_byrank: with dist_zero_exec(): list_file = split_listfile_byrank( list_file=list_file, label_balance=split_label_balance, save_path=cache_path) self.fns = [] self.labels = [] for l, r in zip(list_file, root): fns, labels = self.parse_list_file(l, r, delimeter, self.label_dict) self.fns += fns self.labels += labels
[docs] @staticmethod def parse_list_file(list_file, root, delimeter, label_dict={}): with io.open(list_file, 'r') as f: lines = f.readlines() fns = [] labels = [] for l in lines: splits = l.strip().split(delimeter) if len(root) > 0: fns.append(os.path.join(root, splits[0])) else: fns.append(splits[0]) if len(label_dict) == 0: # must be int,other with mmcv collect will crash label = [int(i) for i in splits[1:]] else: label = [label_dict[i] for i in splits[1:]] labels.append( label[0]) if len(label) == 1 else labels.append(label) return fns, labels
def __len__(self): return len(self.fns) def __getitem__(self, idx): img = load_image(self.fns[idx], mode='RGB') if img is None: return self[idx + 1] img = Image.fromarray(img) label = self.labels[idx] result_dict = {'img': img, 'gt_labels': label} return result_dict
[docs]@DATASOURCES.register_module class ClsSourceItag(ClsSourceImageList): """ data source itag for classification Args: list_file : str / list(str), str means a input image list file path, this file contains records as `image_path label` in list_file list(str) means multi image list, each one contains some records as `image_path label` """
[docs] def __init__(self, list_file, root='', class_list=None): assert root is None or len( root) < 1, 'The "root" param is not used and will be removed soon!' ImageFile.LOAD_TRUNCATED_IMAGES = True # DistributedMPSampler need this attr self.has_labels = True self.class_list = class_list if self.class_list is None: logging.warning( 'It is recommended to specify the ``class_list`` parameter!') self._auto_collect_labels = True self.label_dict = {} else: self.label_dict = dict( zip(self.class_list, range(len(self.class_list)))) self._auto_collect_labels = False self.fns, self.labels, self.label_dict = self.parse_list_file( list_file, self.label_dict, self._auto_collect_labels)
[docs] @staticmethod def parse_list_file(list_file, label_dict, auto_collect_labels=True): with io.open(list_file, 'r') as f: rows = f.read().splitlines() fns = [] labels_id = [] for row_str in rows: data_i = json.loads(row_str.strip()) img_path = data_i['data']['source'] label_id = [] priority = 2 for k in data_i.keys(): if 'verify' in k: priority = 0 break elif 'check' in k: priority = 1 for k, v in data_i.items(): if 'label' in k: label_id = [] result_list = v['results'] for j in range(len(result_list)): label = result_list[j]['data'] if 'labels' in label: label = label['labels']['单选'] if label not in label_dict: if auto_collect_labels: label_dict[label] = len(label_dict) else: raise ValueError( f'Not find label "{label}" in label dict: {label_dict}' ) label_id.append(label_dict[label]) if 'verify' in k: break elif 'check' in k and priority == 1: break fns.append(img_path) labels_id.append(label_id[0]) if len( label_id) == 1 else labels_id.append(label_id) return fns, labels_id, label_dict