# Copyright (c) Alibaba, Inc. and its affiliates.
import logging
import os
import xml.etree.ElementTree as ET
from multiprocessing import cpu_count
import numpy as np
from easycv.datasets.detection.data_sources.base import DetSourceBase
from easycv.datasets.registry import DATASOURCES
from easycv.datasets.utils.download_data.download_voc import (
check_data_exists, download_voc)
from easycv.file import io
# Lower-case image file suffixes (each with its leading dot).
# NOTE(review): not referenced by any definition visible in this module;
# presumably kept for importers elsewhere -- confirm before removing.
img_formats = ['.bmp', '.jpg', '.jpeg', '.png', '.tif', '.tiff', '.dng']
def parse_xml(source_item, classes):
    """Parse one VOC-style XML annotation file into a detection sample.

    Objects flagged as difficult (``<difficult>1</difficult>``) are skipped
    silently. Objects whose ``<name>`` is not listed in ``classes`` are
    skipped with a warning. An object without a ``<difficult>`` tag (or with
    an empty one) is treated as not difficult.

    Args:
        source_item: ``(img_path, xml_path)`` pair. ``xml_path`` is opened
            with ``io.open`` and parsed as XML; ``img_path`` is only copied
            into the result.
        classes: sequence of class names; the index of a name in this
            sequence is used as its label id.

    Returns:
        dict with the keys
            ``gt_bboxes``: float32 array of shape ``(num_kept, 4)`` holding
            ``(xmin, ymin, xmax, ymax)`` for every kept object,
            ``gt_labels``: int64 array of shape ``(num_kept,)``,
            ``filename``: ``img_path`` unchanged.
    """
    img_path, xml_path = source_item

    with io.open(xml_path, 'r') as f:
        root = ET.parse(f).getroot()

    gt_bboxes = []
    gt_labels = []
    for obj in root.iter('object'):
        # ``findtext`` returns None for a missing tag and '' for an empty
        # one; both cases fall back to "not difficult" instead of crashing.
        difficult = (obj.findtext('difficult') or '').strip() or '0'
        cls = obj.find('name').text
        if int(difficult) == 1:
            continue
        if cls not in classes:
            logging.warning(
                'class: %s not in given class list, skip the object!', cls)
            continue

        xmlbox = obj.find('bndbox')
        gt_bboxes.append(
            tuple(
                float(xmlbox.find(tag).text)
                for tag in ('xmin', 'ymin', 'xmax', 'ymax')))
        gt_labels.append(classes.index(cls))

    if gt_bboxes:
        bboxes = np.array(gt_bboxes, dtype=np.float32)
    else:
        # np.array([]) would have shape (0,); keep the 4-column layout so
        # downstream code can rely on a 2-D box array.
        bboxes = np.zeros((0, 4), dtype=np.float32)

    return {
        'gt_bboxes': bboxes,
        'gt_labels': np.array(gt_labels, dtype=np.int64),
        'filename': img_path,
    }
@DATASOURCES.register_module
class DetSourceVOC(DetSourceBase):
    """Detection data source for datasets laid out in the VOC format.

    data dir is as follows:
    ```
    |- voc_data
        |-ImageSets
            |-Main
                |-train.txt
                |-...
        |-JPEGImages
            |-00001.jpg
            |-...
        |-Annotations
            |-00001.xml
            |-...
    ```
    Example1:
        data_source = DetSourceVOC(
            path='/your/voc_data/ImageSets/Main/train.txt',
            classes=${VOC_CLASSES},
        )
    Example2:
        data_source = DetSourceVOC(
            path='/your/voc_data/train.txt',
            classes=${VOC_CLASSES},
            img_root_path='/your/voc_data/images',
            label_root_path='/your/voc_data/annotations'
        )
    """

    def __init__(self,
                 path,
                 classes=[],
                 img_root_path=None,
                 label_root_path=None,
                 cache_at_init=False,
                 cache_on_the_fly=False,
                 img_suffix='.jpg',
                 label_suffix='.xml',
                 parse_fn=parse_xml,
                 num_processes=int(cpu_count() / 2),
                 **kwargs):
        """
        Args:
            path: path of img id list file in ImageSets/Main/
            classes: classes list
            img_root_path: image dir path, if None, default to detect the image dir by the relative path of the `path`
                according to the VOC data format.
            label_root_path: label dir path, if None, default to detect the label dir by the relative path of the `path`
                according to the VOC data format.
            cache_at_init: if set True, will cache in memory in __init__ for faster training
            cache_on_the_fly: if set True, will cache in memory during training
            img_suffix: suffix of image file
            label_suffix: suffix of label file
            parse_fn: parse function to parse item of source iterator
            num_processes: number of processes to parse samples
            **kwargs: accepted but not used here and not forwarded to the
                base class.
        """
        # These attributes are assigned before the base initializer runs.
        # NOTE(review): presumably the base class calls
        # `get_source_iterator` during __init__ -- confirm before reordering.
        self.path = path
        self.img_root_path = img_root_path
        self.label_root_path = label_root_path
        self.img_suffix = img_suffix
        self.label_suffix = label_suffix
        super(DetSourceVOC, self).__init__(
            classes=classes,
            cache_at_init=cache_at_init,
            cache_on_the_fly=cache_on_the_fly,
            parse_fn=parse_fn,
            num_processes=num_processes)

    def get_source_iterator(self):
        """Build the list of ``(image_path, label_path)`` pairs to parse.

        Reads the id list file at ``self.path``; the first
        whitespace-separated token of every non-blank line is taken as the
        image id. When ``img_root_path`` / ``label_root_path`` were not
        given, they are derived from the part of ``self.path`` preceding
        ``ImageSets/Main`` (``JPEGImages`` and ``Annotations`` respectively)
        and stored back on the instance.

        Returns:
            list of ``(img_path, label_path)`` tuples, one per image id, in
            file order.
        """
        if not self.img_root_path:
            self.img_root_path = os.path.join(
                self.path.split('ImageSets/Main')[0], 'JPEGImages')
        if not self.label_root_path:
            self.label_root_path = os.path.join(
                self.path.split('ImageSets/Main')[0], 'Annotations')

        with io.open(self.path, 'r') as t:
            id_lines = t.read().splitlines()

        source_items = []
        for id_line in id_lines:
            fields = id_line.split()
            if not fields:
                # A blank line would otherwise yield a bogus '<root>/.jpg'
                # entry that fails later at parse time.
                continue
            img_id = fields[0]
            img_path = os.path.join(self.img_root_path,
                                    img_id + self.img_suffix)
            label_path = os.path.join(self.label_root_path,
                                      img_id + self.label_suffix)
            source_items.append((img_path, label_path))
        return source_items
@DATASOURCES.register_module
class DetSourceVOC2012(DetSourceVOC):
    """VOC data source bound to the 'voc2012' dataset, with optional download."""

    def __init__(self,
                 path=None,
                 download=True,
                 split='train',
                 classes=[],
                 img_root_path=None,
                 label_root_path=None,
                 cache_at_init=False,
                 cache_on_the_fly=False,
                 img_suffix='.jpg',
                 label_suffix='.xml',
                 parse_fn=parse_xml,
                 num_processes=int(cpu_count() / 2),
                 **kwargs):
        """
        Args:
            path: optional data directory. With ``download=True`` it is
                passed to ``download_voc`` as the target dir (when omitted,
                ``download_voc`` chooses the location). With
                ``download=False`` it is required and must already contain
                the data.
            download: if True, resolve the data through ``download_voc``;
                if False, only validate the data under ``path`` with
                ``check_data_exists``.
            split: train or val
            classes: classes list
            img_root_path: image dir path; if None, it is derived from the
                resolved id list path following the VOC layout.
            label_root_path: label dir path; if None, it is derived from the
                resolved id list path following the VOC layout.
            cache_at_init: if True, cache in memory in __init__
            cache_on_the_fly: if True, cache in memory during training
            img_suffix: suffix of image file
            label_suffix: suffix of label file
            parse_fn: function used to parse each item of the source iterator
            num_processes: number of processes used to parse samples

        Raises:
            KeyError: if ``download`` is False and no ``path`` is given.
        """
        # Without download there is nothing to fall back on when path is
        # missing.
        if not download and not path:
            raise KeyError('your path is None')

        # Any explicitly given path must be an existing directory.
        if path:
            assert os.path.isdir(path), f'{path} is not dir'

        if download:
            target_kwargs = {'target_dir': path} if path else {}
            voc_info = download_voc('voc2012', split=split, **target_kwargs)
        else:
            voc_info = check_data_exists('voc2012', path, split)
        path = voc_info['path']

        super(DetSourceVOC2012, self).__init__(
            path=path,
            classes=classes,
            img_root_path=img_root_path,
            label_root_path=label_root_path,
            cache_at_init=cache_at_init,
            cache_on_the_fly=cache_on_the_fly,
            img_suffix=img_suffix,
            label_suffix=label_suffix,
            parse_fn=parse_fn,
            num_processes=num_processes)
@DATASOURCES.register_module
class DetSourceVOC2007(DetSourceVOC):
    """VOC data source bound to the 'voc2007' dataset, with optional download."""

    def __init__(self,
                 path=None,
                 download=True,
                 split='train',
                 classes=[],
                 img_root_path=None,
                 label_root_path=None,
                 cache_at_init=False,
                 cache_on_the_fly=False,
                 img_suffix='.jpg',
                 label_suffix='.xml',
                 parse_fn=parse_xml,
                 num_processes=int(cpu_count() / 2),
                 **kwargs):
        """
        Args:
            path: optional data directory. With ``download=True`` it is
                passed to ``download_voc`` as the target dir (when omitted,
                ``download_voc`` chooses the location). With
                ``download=False`` it is required and must already contain
                the data.
            download: if True, resolve the data through ``download_voc``;
                if False, only validate the data under ``path`` with
                ``check_data_exists``.
            split: train or val
            classes: classes list
            img_root_path: image dir path; if None, it is derived from the
                resolved id list path following the VOC layout.
            label_root_path: label dir path; if None, it is derived from the
                resolved id list path following the VOC layout.
            cache_at_init: if True, cache in memory in __init__
            cache_on_the_fly: if True, cache in memory during training
            img_suffix: suffix of image file
            label_suffix: suffix of label file
            parse_fn: function used to parse each item of the source iterator
            num_processes: number of processes used to parse samples

        Raises:
            KeyError: if ``download`` is False and no ``path`` is given.
        """
        # Without download there is nothing to fall back on when path is
        # missing.
        if not download and not path:
            raise KeyError('your path is None')

        # Any explicitly given path must be an existing directory.
        if path:
            assert os.path.isdir(path), f'{path} is not dir'

        if download:
            target_kwargs = {'target_dir': path} if path else {}
            voc_info = download_voc('voc2007', split=split, **target_kwargs)
        else:
            voc_info = check_data_exists('voc2007', path, split)
        path = voc_info['path']

        super(DetSourceVOC2007, self).__init__(
            path=path,
            classes=classes,
            img_root_path=img_root_path,
            label_root_path=label_root_path,
            cache_at_init=cache_at_init,
            cache_on_the_fly=cache_on_the_fly,
            img_suffix=img_suffix,
            label_suffix=label_suffix,
            parse_fn=parse_fn,
            num_processes=num_processes)