Source code for easycv.datasets.shared.multi_view

# Copyright (c) Alibaba, Inc. and its affiliates.
import copy

from PIL import Image

from easycv.datasets.builder import build_datasource
from easycv.datasets.registry import DATASETS, PIPELINES
from easycv.datasets.shared.base import BaseDataset
from easycv.datasets.shared.pipelines.transforms import Compose
from easycv.framework.errors import NotImplementedError
from easycv.utils.registry import build_from_cfg

@DATASETS.register_module
class MultiViewDataset(BaseDataset):
    """Dataset that outputs multiple views of an image.

    The number of views in the output dict depends on `num_views`.
    The image can be processed by one pipeline or multiple pipelines.

    Args:
        data_source: Config passed to ``build_datasource`` to create the
            underlying data source.
        num_views (list): Number of views to produce with each pipeline;
            ``pipelines[i]`` is applied ``num_views[i]`` times.
        pipelines (list[list[dict]]): A list of pipelines, each given as a
            list of transform configs built from the ``PIPELINES`` registry.

    Raises:
        AssertionError: If `num_views` is not a list.
        IndexError: If `num_views` has more entries than `pipelines`.
    """

    def __init__(self, data_source, num_views, pipelines):
        self.data_source = build_datasource(data_source)

        # One composed transform per pipeline config.
        pipelines_list = [
            Compose([build_from_cfg(p, PIPELINES) for p in pipe])
            for pipe in pipelines
        ]

        self.transforms_list = []
        assert isinstance(num_views, list)
        # Fail with an explicit message instead of a bare
        # "list index out of range" when a view count has no pipeline.
        if len(num_views) > len(pipelines_list):
            raise IndexError(
                f'num_views has {len(num_views)} entries but only '
                f'{len(pipelines_list)} pipelines were given.')

        # Repeat each pipeline once per requested view; pipelines beyond
        # len(num_views) are left unused.
        for pipeline, count in zip(pipelines_list, num_views):
            self.transforms_list.extend([pipeline] * count)

    def __getitem__(self, idx):
        """Return the sample at `idx` with ``'img'`` replaced by its views.

        Args:
            idx: Index forwarded to the data source.

        Returns:
            The dict produced by the data source, whose ``'img'`` entry is
            a list holding the output of every transform in
            ``self.transforms_list``, in order.

        Raises:
            AssertionError: If the data source's ``'img'`` is not a
                ``PIL.Image.Image``.
        """
        results = self.data_source[idx]
        img = results['img']
        assert isinstance(img, Image.Image), (
            f'The output from the data source must be an Image, '
            f'got: {type(img)}. '
            'Please ensure that the list file does not contain labels.')

        # Only the image is transformed; each view works on its own copy so
        # that one transform cannot affect the input of the next.
        results['img'] = [
            trans({'img': copy.deepcopy(img)})['img']
            for trans in self.transforms_list
        ]
        return results

    def evaluate(self, results, evaluators, logger=None):
        """Not supported for this dataset; always raises NotImplementedError."""
        raise NotImplementedError