Source code for easycv.datasets.classification.data_sources.utils

import os

from mmcv.runner import get_dist_info
from tqdm import tqdm

from easycv.file import io


def split_listfile_byrank(list_file,
                          label_balance,
                          save_path='data/',
                          delimeter=' '):
    """Split a sample list file into one shard file per distributed rank.

    When only one process is running (``world_size == 1``) nothing is written
    and ``list_file`` is returned unchanged. Otherwise this call writes the
    shard files ``rank_0.txt`` ... ``rank_{world_size-1}.txt`` into
    ``save_path`` and returns the path of the shard for the current rank.

    Every shard receives the same number of slots
    (``count // world_size + 1``); indices wrap around to the beginning of
    the list, so some entries are repeated across shards.

    Args:
        list_file: Path of the list file; each non-blank line is one record
            whose first field is the sample path and whose last field is the
            label, separated by ``delimeter``.
        label_balance: If true, records are grouped by label and whole labels
            (with all of their files) are assigned to ranks. If false,
            consecutive lines are assigned to ranks.
        save_path: Directory the shard files are written to; created if it
            does not exist.
        delimeter: Field separator used to parse the input lines and to write
            the label-balanced shards.

    Returns:
        ``list_file`` when ``world_size == 1``, otherwise the path of
        ``rank_<rank>.txt`` inside ``save_path``.

    Raises:
        ValueError: If ``list_file`` contains no non-blank lines.
    """
    rank, world_size = get_dist_info()
    if world_size == 1:
        return list_file

    # Close the handle deterministically instead of leaking it, and drop
    # blank lines: they carry no sample and would otherwise become a bogus
    # empty label in the balanced mode.
    with io.open(list_file) as f:
        lines = [line for line in f.readlines() if line.strip()]
    if not lines:
        # Without this guard the wrap-around modulo below would fail with an
        # uninformative ZeroDivisionError.
        raise ValueError('list file %s contains no samples' % list_file)

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(save_path, exist_ok=True)

    def rank_file(idx):
        return os.path.join(save_path, 'rank_%d.txt' % idx)

    if label_balance:
        # Group sample paths by label, keeping first-seen label order.
        label2files = {}
        for line in tqdm(lines):
            fields = line.rstrip('\r\n').split(delimeter)
            label2files.setdefault(fields[-1].strip(), []).append(fields[0])

        label_list = list(label2files)
        length = len(label_list) // world_size + 1
        for idx in tqdm(range(world_size)):
            with io.open(rank_file(idx), 'w') as f:
                for j in range(idx * length, (idx + 1) * length):
                    label = label_list[j % len(label_list)]
                    for path in label2files[label]:
                        # Write with the same delimiter the input was parsed
                        # with, so the shard has the input's format.
                        f.write('%s%s%s\n' % (path, delimeter, label))
    else:
        # A final line without a newline would be glued to the following
        # record when indices wrap around, so terminate every line.
        lines = [
            line if line.endswith('\n') else line + '\n' for line in lines
        ]
        length = len(lines) // world_size + 1
        for idx in range(world_size):
            with io.open(rank_file(idx), 'w') as f:
                for j in range(idx * length, (idx + 1) * length):
                    f.write(lines[j % len(lines)])

    return rank_file(rank)