Source code for easycv.datasets.detection.data_sources.raw

# Copyright (c) Alibaba, Inc. and its affiliates.
import functools
import logging
import os
from multiprocessing import cpu_count

import numpy as np

from easycv.core.bbox.bbox_util import batched_cxcywh2xyxy_with_shape
from easycv.datasets.registry import DATASOURCES
from easycv.file import io
from .base import DetSourceBase

img_formats = ['.bmp', '.jpg', '.jpeg', '.png', '.tif', '.tiff', '.dng']
label_formats = ['.txt']


[docs]def parse_raw(source_iter, classes=None, delimeter=' '): img_path, label_path = source_iter source_info = {'filename': img_path} with io.open(label_path, 'r') as f: labels_and_boxes = np.array( [line.split(delimeter) for line in f.read().splitlines()]) if not len(labels_and_boxes): return source_info labels = labels_and_boxes[:, 0] bboxes = labels_and_boxes[:, 1:] source_info.update({ 'gt_bboxes': np.array(bboxes, dtype=np.float32), 'gt_labels': labels.astype(np.int64) }) return source_info
[docs]@DATASOURCES.register_module class DetSourceRaw(DetSourceBase): """ data dir is as follows: ``` |- data_dir |-images |-1.jpg |-... |-labels |-1.txt |-... ``` Label txt file is as follows: The first column is the label id, and columns 2 to 5 are coordinates relative to the image width and height [x_center, y_center, bbox_w, bbox_h]. ``` 15 0.519398 0.544087 0.476359 0.572061 2 0.501859 0.820726 0.996281 0.332178 ... ``` Example: data_source = DetSourceRaw( img_root_path='/your/data_dir/images', label_root_path='/your/data_dir/labels', ) """
[docs] def __init__(self, img_root_path, label_root_path, classes=[], cache_at_init=False, cache_on_the_fly=False, delimeter=' ', parse_fn=parse_raw, num_processes=int(cpu_count() / 2), **kwargs): """ Args: img_root_path: images dir path label_root_path: labels dir path classes(list, optional): classes list cache_at_init: if set True, will cache in memory in __init__ for faster training cache_on_the_fly: if set True, will cache in memroy during training delimeter: delimeter of txt file parse_fn: parse function to parse item of source iterator num_processes: number of processes to parse samples """ self.delimeter = delimeter self.img_root_path = img_root_path self.label_root_path = label_root_path parse_fn = functools.partial(parse_fn, delimeter=delimeter) super(DetSourceRaw, self).__init__( classes=classes, cache_at_init=cache_at_init, cache_on_the_fly=cache_on_the_fly, parse_fn=parse_fn, num_processes=num_processes)
[docs] def get_source_iterator(self): self.img_files = [ os.path.join(self.img_root_path, i) for i in io.listdir(self.img_root_path, recursive=True) if os.path.splitext(i)[-1].lower() in img_formats ] self.label_files = [] self.img_files_effec = [] for img_path in self.img_files: img_name = os.path.splitext(os.path.basename(img_path))[0] find_label_path = False for label_format in label_formats: lable_path = os.path.join(self.label_root_path, img_name + label_format) if io.exists(lable_path): find_label_path = True self.label_files.append(lable_path) self.img_files_effec.append(img_path) break if not find_label_path: logging.warning( 'Not find label file %s for img: %s, skip the sample!' % (lable_path, img_path)) assert len(self.img_files_effec) == len(self.label_files) assert len(self.img_files_effec ) > 0, 'No samples found in %s' % self.img_root_path return list(zip(self.img_files_effec, self.label_files))
[docs] def post_process_fn(self, result_dict): result_dict = super(DetSourceRaw, self).post_process_fn(result_dict) result_dict['gt_bboxes'] = batched_cxcywh2xyxy_with_shape( result_dict['gt_bboxes'], shape=result_dict['img_shape'][:2]) return result_dict
[docs] def get_ann_info(self, idx): """ Get raw annotation info, include bounding boxes, labels and so on. `bboxes` format is as [x1, y1, x2, y2] without normalization. """ sample_info = self.samples_list[idx] result_dict = self[idx] groundtruth_is_crowd = sample_info.get('groundtruth_is_crowd', None) if groundtruth_is_crowd is None: groundtruth_is_crowd = np.zeros_like(sample_info['gt_labels']) annotations = { 'bboxes': result_dict['gt_bboxes'], 'labels': sample_info['gt_labels'], 'groundtruth_is_crowd': groundtruth_is_crowd } return annotations