Source code for easycv.datasets.shared.data_sources.image_npy

# Copyright (c) Alibaba, Inc. and its affiliates.
import os

import cv2
import numpy as np
from PIL import Image

from easycv.datasets.registry import DATASOURCES
from easycv.file import io, is_oss_path
from easycv.utils.dist_utils import dist_zero_exec


@DATASOURCES.register_module
class ImageNpy(object):
    """Data source backed by ``.npy`` files of images and optional labels.

    Images are read from a single ``.npy`` file holding a sequence of cv2
    images. Labels, when given, are read from a second ``.npy`` file and are
    indexed with the same ``idx`` as the images.
    """

    def __init__(self, image_file, label_file=None, cache_root='data_cache/'):
        """
        Args:
            image_file: (local or oss) path of one .npy file holding the image
                data, saved as [cv2.img, cv2.img, ...].
            label_file: (local or oss) path of one .npy file holding the label
                data. Pass None (or an empty value) for an unlabeled dataset.
            cache_root: local directory that oss files are copied under
                before being loaded.
        """
        image_file = self._to_local_path(image_file, cache_root)

        self.labels = None
        if label_file:
            label_file = self._to_local_path(label_file, cache_root)
            # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary
            # Python objects, which can run code; only load trusted files.
            self.labels = np.load(label_file, allow_pickle=True)
        # Derive the flag from what was actually loaded, so that it can never
        # disagree with ``self.labels`` (previously label_file='' gave
        # has_labels=True while labels stayed None).
        self.has_labels = self.labels is not None

        # allow_pickle is needed to load an object array of images; see the
        # security note above.
        self.data = np.load(image_file, allow_pickle=True)

    @staticmethod
    def _to_local_path(path, cache_root):
        """Return a local path for ``path``, copying oss files under cache_root.

        Non-oss paths are returned unchanged. For oss paths the file is copied
        to ``os.path.join(cache_root, path)`` inside ``dist_zero_exec()``
        (presumably so that only one distributed rank performs the copy while
        the others wait — confirm in easycv.utils.dist_utils).
        """
        if not is_oss_path(path):
            return path
        dst_path = os.path.join(cache_root, path)
        with dist_zero_exec():
            io.copy(path, dst_path)
        return dst_path

    def __len__(self):
        """Return the number of images (size of the image array's first axis)."""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Return ``{'img': PIL.Image}``, plus ``'gt_labels'`` if labels were loaded."""
        # Stored images are converted BGR -> RGB before wrapping them for PIL.
        img = Image.fromarray(cv2.cvtColor(self.data[idx], cv2.COLOR_BGR2RGB))
        results = {'img': img}
        if self.labels is not None:
            results['gt_labels'] = self.labels[idx]
        return results