Source code for easycv.datasets.shared.base

# Copyright (c) Alibaba, Inc. and its affiliates.
import random
from abc import ABCMeta, abstractmethod

from torch.utils.data import Dataset

from easycv.utils.registry import build_from_cfg
from ..builder import build_datasource
from ..registry import PIPELINES
from .pipelines.transforms import Compose


class BaseDataset(Dataset, metaclass=ABCMeta):
    """Abstract base class for datasets backed by a data source and a pipeline.

    Subclasses must implement ``__getitem__`` and ``evaluate``.
    """

    def __init__(self, data_source, pipeline, profiling=False):
        """Build the data source and the transform pipeline.

        Args:
            data_source: Config passed to ``build_datasource``; the built
                object is stored as ``self.data_source`` and must support
                ``len()``.
            pipeline: Iterable of configs, each built with
                ``build_from_cfg(cfg, PIPELINES)`` and then chained in
                order with ``Compose``.
            profiling: Forwarded unchanged to ``Compose`` as ``profiling``.
        """
        self.data_source = build_datasource(data_source)
        pipeline = [build_from_cfg(p, PIPELINES) for p in pipeline]
        self.pipeline = Compose(pipeline, profiling=profiling)

    def __len__(self):
        """Return the number of items, i.e. ``len(self.data_source)``."""
        return len(self.data_source)

    def _rand_another(self, idx):
        """Randomly get another item.

        Args:
            idx: Index that should be avoided in the result.

        Returns:
            int: Another index of item, uniformly drawn from
            ``[0, len(self) - 1]`` and different from ``idx`` whenever the
            dataset holds at least two items.

        Raises:
            ValueError: If the dataset is empty (raised by
                ``random.randint`` on an empty range).
        """
        num_items = len(self)
        if num_items < 2:
            # No alternative index exists: a single-item dataset can only
            # yield 0, and an empty one makes randint raise ValueError.
            return random.randint(0, num_items - 1)

        # Redraw until the candidate differs from ``idx`` so callers that
        # retry on a bad sample never get the same sample back.
        while True:
            candidate = random.randint(0, num_items - 1)
            if candidate != idx:
                return candidate

    @abstractmethod
    def __getitem__(self, idx):
        """Return the item at ``idx``; must be implemented by subclasses."""
        pass

    @abstractmethod
    def evaluate(self, results, evaluators, logger=None, **kwargs):
        """Evaluate ``results``; must be implemented by subclasses.

        Args:
            results: Model outputs to evaluate (format defined by the
                subclass).
            evaluators: Evaluator object(s) to use (format defined by the
                subclass).
            logger: Optional logger.
            **kwargs: Extra subclass-specific options.
        """
        pass

    def visualize(self, results, **kwargs):
        """Visualize the model output results on validation data.

        The base implementation performs no visualization and returns an
        empty dict; subclasses may override it.

        Returns:
            A dictionary. If image visualization is added, return a dict
            containing:
                images: List of visualized images.
                img_metas: List of length number of test images,
                    dict of image meta info, containing filename,
                    img_shape, origin_img_shape, scale_factor and so on.
        """
        return {}