PyTorch 实现 VOC 数据集的 Dataset

Pascal VOC2012 数据集下载地址:

http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar

代码

import os
import torch
import xml.etree.ElementTree as ET
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from config import Config
import numpy as np
from PIL import Image

# Image-only preprocessing/augmentation pipeline for training.
# NOTE(review): RandomCrop and RandomHorizontalFlip move pixels, but the VOC
# dataset below never applies the same geometry to its bounding boxes, so the
# boxes will not line up with augmented images -- confirm before training a
# detector with this pipeline.
image_transform = transforms.Compose(transforms=[
    transforms.Resize(size=256),             # int size: shorter side -> 256 px
    transforms.RandomCrop(size=256),         # random 256x256 patch
    transforms.RandomHorizontalFlip(p=0.5),  # mirror half of the images
    transforms.ToTensor(),                   # PIL image -> float tensor (C, H, W)
])


class VOCDataset(Dataset):
    """PyTorch ``Dataset`` over a Pascal VOC-style directory tree.

    ``data_dir`` must contain ``ImageSets/Main/<split>.txt`` (one image id per
    line), ``Annotations/<id>.xml`` and ``JPEGImages/<id>.jpg``.

    Each item is a tuple ``(img, bbox, label)``:

    * ``img`` -- ``transform`` applied to the RGB PIL image.  If there is no
      transform, or the transform still returns a PIL image, a float32
      ndarray of shape ``(C, H, W)``.
    * ``bbox`` -- float32 ndarray of shape ``(R, 4)``, one
      ``(ymin, xmin, ymax, xmax)`` row per annotated object, each value
      shifted by -1 (presumably converting VOC's 1-based pixel indices to
      0-based).  ``R`` may be 0.
    * ``label`` -- float32 ndarray of shape ``(R,)``, each object's index in
      ``label_names``.

    NOTE(review): ``transform`` is applied to the image only; ``bbox`` always
    stays in the original image's pixel coordinates, so any resize/crop/flip
    inside ``transform`` is NOT reflected in the boxes.
    NOTE(review): labels are float32 (kept for compatibility); losses such as
    ``CrossEntropyLoss`` expect int64 class indices.
    """

    def __init__(self, data_dir, train=True, transform=None, label_names=None):
        """Read the image-id list for the chosen split.

        Args:
            data_dir: root of the VOC tree (e.g. ``VOCdevkit/VOC2012``).
            train: ``True`` reads the ``trainval`` split, ``False`` reads ``val``.
            transform: optional callable applied to each RGB PIL image.
            label_names: class names whose positions define the numeric
                labels; defaults to ``Config.VOC_BBOX_LABEL_NAMES``.

        Raises:
            OSError: if the split file cannot be opened.
        """
        super(VOCDataset, self).__init__()
        self.data_dir = data_dir
        split = 'trainval' if train else 'val'
        id_list_file = os.path.join(self.data_dir, 'ImageSets', 'Main',
                                    '{0}.txt'.format(split))
        # The context manager closes the file.  Blank lines (e.g. a trailing
        # newline) are skipped so they do not become bogus empty ids.
        with open(id_list_file) as f:
            self.ids = [line.strip() for line in f if line.strip()]

        self.transform = transform

        # Resolved here rather than at class definition so callers can supply
        # their own class list without touching Config.
        if label_names is None:
            label_names = Config.VOC_BBOX_LABEL_NAMES
        self.label_names = tuple(label_names)

    def __getitem__(self, item):
        """Return ``(img, bbox, label)`` for the ``item``-th id in the split."""
        img_id = self.ids[item]
        bbox, label = self._load_annotation(img_id)
        img = self._load_image(img_id)
        return img, bbox, label

    def __len__(self):
        """Return the number of image ids in the split."""
        return len(self.ids)

    def _load_annotation(self, img_id):
        """Parse ``Annotations/<img_id>.xml`` into ``(bbox, label)`` arrays.

        Raises:
            ValueError: if an object's name is not in ``self.label_names``.
        """
        anno = ET.parse(
            os.path.join(self.data_dir, 'Annotations', img_id + '.xml'))

        bbox = []
        label = []
        for obj in anno.findall('object'):
            bndbox_anno = obj.find('bndbox')
            # float() accepts both "48" and "48.5"; int() would reject the latter.
            bbox.append([float(bndbox_anno.find(tag).text) - 1
                         for tag in ('ymin', 'xmin', 'ymax', 'xmax')])

            name = obj.find('name').text.lower().strip()
            label.append(self.label_names.index(name))

        # Explicit dtype/shape keeps an annotation with no objects well-formed
        # ((0, 4) and (0,)) instead of np.stack raising on an empty list.
        bbox = np.array(bbox, dtype=np.float32).reshape(-1, 4)
        label = np.array(label, dtype=np.float32)
        return bbox, label

    def _load_image(self, img_id):
        """Load ``JPEGImages/<img_id>.jpg`` and apply ``self.transform``."""
        img_file = os.path.join(self.data_dir, 'JPEGImages', img_id + '.jpg')
        # The context manager releases the file handle.  convert('RGB') loads
        # the pixels before the file closes and guarantees 3 channels even for
        # grayscale or palette images.
        with Image.open(img_file) as f:
            img = f.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        if isinstance(img, Image.Image):
            # No tensor-producing transform ran: (H, W, C) -> (C, H, W).
            # A ToTensor-style transform already yields (C, H, W), so it is
            # left untouched (transposing it again would scramble the axes).
            img = np.asarray(img, dtype=np.float32).transpose(2, 0, 1)
        return img


if __name__ == '__main__':
    # Smoke test: iterate the training split and print each image's boxes.
    dataset = VOCDataset(data_dir=Config.voc_data_dir, train=True,
                         transform=image_transform)
    # batch_size=1: bbox/label arrays are (R, 4)/(R,) with R = objects in the
    # image, so the default collate cannot stack items whose R differs.
    data_loader = DataLoader(dataset, batch_size=1)
    for _image, bbox, _label in data_loader:
        print(bbox)

常量文件 config.py

class Config:
    """Project-wide constants for the VOC dataset code."""

    # Root of the extracted VOC2012 devkit; a relative path, so it is resolved
    # against the current working directory.
    voc_data_dir = 'VOCdevkit/VOC2012'

    # The 20 Pascal VOC object classes.  A class's position in this tuple is
    # the numeric label assigned to it (looked up with .index(name)), so the
    # order must not change.
    VOC_BBOX_LABEL_NAMES = (
        'aeroplane',
        'bicycle',
        'bird',
        'boat',
        'bottle',
        'bus',
        'car',
        'cat',
        'chair',
        'cow',
        'diningtable',
        'dog',
        'horse',
        'motorbike',
        'person',
        'pottedplant',
        'sheep',
        'sofa',
        'train',
        'tvmonitor',
    )

© 版权声明
THE END
如果内容对您有所帮助,就支持一下吧!
点赞0 分享
评论 抢沙发

请登录后发表评论

    暂无评论内容