Source code for easycv.datasets.shared.dali_tfrecord_multi_view

# Copyright (c) Alibaba, Inc. and its affiliates.
import math

from mmcv.runner import get_dist_info

from easycv.datasets.builder import build_datasource
from easycv.datasets.loader.sampler import DistributedSampler
from easycv.datasets.registry import DATASETS, PIPELINES
from easycv.datasets.shared.pipelines.transforms import Compose
from easycv.utils import dist_utils
from easycv.utils.registry import build_from_cfg


class DaliLoaderWrapper(object):

    def __init__(self, dali_loader, batch_size, return_list):
        self.dali_loader = dali_loader
        self.batch_size = batch_size
        self.return_list = return_list
        self.loader = None
        rank, world_size = get_dist_info()
        self.rank = rank
        # placeholder sampler over dummy data, so callers that expect a
        # `.sampler` attribute (e.g. to call `set_epoch`) still work
        self.sampler = DistributedSampler(
            [1] * 100, world_size, rank, shuffle=True, replace=False)
    def __len__(self):
        return math.ceil(self.dali_loader._size / self.batch_size)

    def __iter__(self):
        self.loader = iter(self.dali_loader)
        return self

    def __next__(self):
        try:
            data = next(self.loader)
        except StopIteration:
            self.dali_loader.reset()
            raise StopIteration
        # the last name in `return_list` is 'label'; collect only the
        # image views here
        img = [
            data[0][self.return_list[i]]
            for i in range(len(self.return_list) - 1)
        ]
        return dict(img=img)
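# A minimal illustration of the wrapper's contract (not part of the module):
# anything that exposes `_size`, `reset()` and yields `[{name: tensor}]`
# batches can be wrapped. `FakeDaliLoader` below is a made-up stand-in, not
# a DALI API.
#
#     class FakeDaliLoader:
#         _size = 4
#
#         def reset(self):
#             pass
#
#         def __iter__(self):
#             return iter([[{'image_0': 'v0', 'image_1': 'v1', 'label': 0}]])
#
#     wrapper = DaliLoaderWrapper(
#         FakeDaliLoader(), batch_size=2,
#         return_list=['image_0', 'image_1', 'label'])
#     for batch in wrapper:
#         print(batch)  # {'img': ['v0', 'v1']} -- the label is dropped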
def _load_DaliTFRecordMultiViewPipe():
    """Avoid nvidia import errors when DALI is not used."""
    import nvidia.dali.ops as ops
    from nvidia.dali.pipeline import Pipeline

    from easycv.datasets.utils.tfrecord_util import \
        get_imagenet_dali_tfrecord_feature

    class DaliTFRecordMultiViewPipe(Pipeline):

        def __init__(self,
                     data_source,
                     transforms_list,
                     batch_size,
                     distributed,
                     random_shuffle=True,
                     workers_per_gpu=2,
                     device='gpu'):
            self.device = device
            if distributed:
                self.rank, self.world_size = get_dist_info()
                self.local_rank = dist_utils.local_rank()
                self.local_size = dist_utils.get_num_gpu_per_node()
            else:
                self.rank, self.local_rank, self.world_size = 0, 0, 1
            super(DaliTFRecordMultiViewPipe, self).__init__(
                batch_size,
                workers_per_gpu,
                self.local_rank,
                seed=12 + self.rank)

            imagenet_feature = get_imagenet_dali_tfrecord_feature()
            self.input = ops.TFRecordReader(
                path=data_source.data_list,
                index_path=data_source.index_list,
                shard_id=self.rank,
                num_shards=self.world_size,
                random_shuffle=random_shuffle,
                features=imagenet_feature)
            self.transforms_list = transforms_list

        def define_graph(self):
            inputs = self.input(name='Reader')
            jpegs, labels = inputs['image/encoded'], inputs[
                'image/class/label']
            # apply every view's transform to the same decoded image
            output_list = []
            for t_i in self.transforms_list:
                output_list.append(t_i(jpegs))
            if self.device == 'gpu':
                labels = labels.gpu()
            output_list.append(labels)
            return output_list

    return DaliTFRecordMultiViewPipe
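# Note: DALI's TFRecordReader needs an index file alongside each TFRecord.
# If the index files referenced by `data_source.index_list` do not exist
# yet, they can be generated with the `tfrecord2idx` script shipped with
# DALI (paths below are placeholders):
#
#     tfrecord2idx train-00000-of-01024 train-00000-of-01024.idx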
@DATASETS.register_module
class DaliTFRecordMultiViewDataset(object):
    """Adapt to DALI: the dataset outputs multiple views of an image.

    The number of views in the output dict depends on `num_views`. The
    image can be processed by one pipeline or by multiple pipelines.

    Args:
        num_views (list): The number of different views.
        pipelines (list[list[dict]]): A list of pipelines.
    """
    def __init__(self,
                 data_source,
                 num_views,
                 pipelines,
                 distributed,
                 batch_size,
                 random_shuffle=True,
                 workers_per_gpu=2):
        assert len(num_views) == len(pipelines)
        self.batch_size = batch_size
        if distributed:
            self.rank, self.world_size = get_dist_info()
        else:
            self.rank, self.world_size = 0, 1

        data_source = build_datasource(data_source)
        transforms = []
        for pipeline in pipelines:
            transform = Compose(
                [build_from_cfg(p, PIPELINES) for p in pipeline])
            transforms.append(transform)

        # repeat each pipeline's transform `num_views[i]` times, so every
        # output view has its own transform
        self.transforms_list = []
        assert isinstance(num_views, list)
        for i in range(len(num_views)):
            self.transforms_list.extend([transforms[i]] * num_views[i])

        DaliTFRecordMultiViewPipe = _load_DaliTFRecordMultiViewPipe()
        self.pipeline = DaliTFRecordMultiViewPipe(
            data_source=data_source,
            transforms_list=self.transforms_list,
            batch_size=batch_size,
            distributed=distributed,
            random_shuffle=random_shuffle,
            workers_per_gpu=workers_per_gpu)
    def get_dataloader(self):
        from nvidia.dali.plugin.pytorch import DALIGenericIterator

        batch_size = self.batch_size
        self.pipeline.build()
        data_size = int(self.pipeline.epoch_size('Reader') / self.world_size)
        # one output name per view, plus the label appended by define_graph
        return_list = [
            'image_%s' % i for i in range(len(self.transforms_list))
        ] + ['label']
        data_loader = DALIGenericIterator([self.pipeline], return_list,
                                          data_size, fill_last_batch=True)
        data_loader = DaliLoaderWrapper(data_loader, batch_size, return_list)
        return data_loader
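# Usage sketch (illustrative only; the data_source cfg, pipeline configs
# and view counts below are assumptions, not taken from this module):
#
#     dataset = DaliTFRecordMultiViewDataset(
#         data_source=dict(...),  # a registered tfrecord data source cfg
#         num_views=[2, 6],       # e.g. 2 global + 6 local crops
#         pipelines=[global_pipeline_cfg, local_pipeline_cfg],
#         distributed=False,
#         batch_size=32)
#     loader = dataset.get_dataloader()
#     for batch in loader:
#         views = batch['img']  # list of 8 view tensors per batch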