from PIL import Image
import numpy as np
import os
import cv2
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

from voc import VOCDataset


class VOCCLassifyDataset(Dataset):
    """Wrap the VOC detection annotations as a classification dataset: every
    non-difficult object is cropped from its image and returned as a single
    (image, category index) sample."""

    def __init__(self, root_dir, image_set='train', transform=None):
        super(VOCCLassifyDataset, self).__init__()
        voc_dataset = VOCDataset(root_dir, image_set=image_set)
        item_list = list()
        for idx in range(len(voc_dataset)):
            _, target = voc_dataset.__getitem__(idx)
            folder_name, img_name, objects = self.parse_target(target)
            img_path = os.path.join(root_dir, 'VOCdevkit', folder_name, 'JPEGImages', img_name)
            for obj in objects:
                name = obj['name']
                cate_idx = voc_dataset.cate_list.index(name)
                xmin = obj['bndbox']['xmin']
                ymin = obj['bndbox']['ymin']
                xmax = obj['bndbox']['xmax']
                ymax = obj['bndbox']['ymax']
                difficult = obj['difficult']
                # Skip objects marked as difficult; every remaining object
                # becomes one classification sample.
                if int(difficult) == 1:
                    continue
                item_list.append(
                    {'idx': idx, 'img_path': img_path, 'cate': name, 'cate_idx': cate_idx,
                     'bndbox': [int(xmin), int(ymin), int(xmax), int(ymax)]})
        self.transform = transform
        self.voc_dataset = voc_dataset
        self.item_list = item_list

    def __getitem__(self, idx):
        assert idx < len(self.item_list), 'the total num is %d' % len(self.item_list)
        item_dict = self.item_list[idx]
        image, _ = self.voc_dataset.__getitem__(item_dict['idx'])
        xmin, ymin, xmax, ymax = item_dict['bndbox']
        cate_idx = item_dict['cate_idx']
        # Crop the object region out of the full image before applying the
        # optional transform.
        image = np.array(image)
        image = image[ymin:ymax, xmin:xmax]
        if self.transform:
            image = Image.fromarray(image)
            image = self.transform(image)
        return image, cate_idx

    def __len__(self):
        return len(self.item_list)

    def parse_target(self, target):
        folder_name = target['annotation']['folder']
        img_name = target['annotation']['filename']
        objects = target['annotation']['object']
        return folder_name, img_name, objects
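

# For reference: parse_target() assumes the per-image target follows the usual
# VOC XML-to-dict layout. The sketch below only illustrates that structure; the
# concrete values (folder, filename, category, coordinates) are made up for
# illustration and are not produced by this file.
EXAMPLE_TARGET = {
    'annotation': {
        'folder': 'VOC2007',
        'filename': '000066.jpg',
        'object': [
            {'name': 'dog',
             'difficult': '0',
             'bndbox': {'xmin': '48', 'ymin': '240', 'xmax': '195', 'ymax': '371'}},
        ],
    },
}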


def test():
    dataset = VOCCLassifyDataset('./data', image_set='test')
    image, target = dataset.__getitem__(66)
    print(image.shape)
    print(target)
    cv2.imshow('img', image)
    cv2.waitKey(0)
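

def test_rgb_display():
    # Variant of test(), added as a sketch: voc.VOCDataset presumably yields PIL
    # images in RGB order, while cv2.imshow() interprets arrays as BGR, so the
    # channels are converted before display. This assumes RGB input and is not
    # part of the original script.
    dataset = VOCCLassifyDataset('./data', image_set='test')
    image, target = dataset.__getitem__(66)
    print(image.shape, target)
    cv2.imshow('img', cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
    cv2.waitKey(0)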


def test2():
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor()
    ])
    dataset = VOCCLassifyDataset('./data', image_set='trainval', transform=transform)
    print('dataset num:', len(dataset))
    dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=8)
    item = next(iter(dataloader))
    inputs, targets = item
    print(inputs.shape)
    print(targets)
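

def class_distribution(image_set='trainval'):
    # Sketch (not in the original script): count how many cropped objects each
    # category contributes, using the 'cate' field stored in item_list, e.g. to
    # check class balance before training a classifier.
    from collections import Counter
    dataset = VOCCLassifyDataset('./data', image_set=image_set)
    counts = Counter(item['cate'] for item in dataset.item_list)
    for cate, num in counts.most_common():
        print('{:>12s}: {}'.format(cate, num))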


if __name__ == '__main__':
    for name in ['train', 'val', 'trainval', 'test']:
        dataset = VOCCLassifyDataset('./data', image_set=name)
        print('{} - images: {} objects: {}'.format(name, len(dataset.voc_dataset), len(dataset)))

Using downloaded and verified file: ./data/VOCtrainval_06-Nov-2007.tar
Using downloaded and verified file: ./data/VOCtrainval_11-May-2012.tar
train - images: 8218 objects: 19910
Using downloaded and verified file: ./data/VOCtrainval_06-Nov-2007.tar
Using downloaded and verified file: ./data/VOCtrainval_11-May-2012.tar
val - images: 8333 objects: 20148
Using downloaded and verified file: ./data/VOCtrainval_06-Nov-2007.tar
Using downloaded and verified file: ./data/VOCtrainval_11-May-2012.tar
trainval - images: 16551 objects: 40058
Using downloaded and verified file: ./data/VOCtest_06-Nov-2007.tar
test - images: 4952 objects: 12032