加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
datasets.py 4.10 KB
一键复制 编辑 原始数据 按行查看 历史
pengzhiliang 提交于 2021-05-09 11:44 . init
import os
import json
from torchvision import datasets, transforms
from torchvision.datasets.folder import ImageFolder, default_loader
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import create_transform
class INatDataset(ImageFolder):
    """iNaturalist dataset indexed from the JSON annotation files under ``root``.

    Integer labels are assigned to the distinct values of the chosen
    ``category`` field in order of first appearance in the *training*
    annotations, so the train and val splits share one label mapping.

    ``ImageFolder.__init__`` is deliberately not called (no directory scan is
    performed); sample access relies on the ``__getitem__`` and ``__len__``
    inherited from ``ImageFolder``.
    """

    def __init__(self, root, train=True, year=2018, transform=None, target_transform=None,
                 category='name', loader=default_loader):
        """Index the requested split.

        Args:
            root: Directory containing ``train{year}.json``, ``val{year}.json``,
                ``categories.json`` and the image tree.
            train: Index the training split when true, the validation split otherwise.
            year: Year embedded in the annotation file names.
            transform: Stored on the instance for the inherited ``__getitem__``.
            target_transform: Stored on the instance for the inherited ``__getitem__``.
            category: Key of each ``categories.json`` entry whose value defines the class.
            loader: Image loader stored on the instance for the inherited ``__getitem__``.

        Raises:
            OSError: If one of the annotation files cannot be opened.
            KeyError: If ``category`` is missing from a category entry, or a
                sample's category value never occurs in the training annotations.
        """
        def load_json(filename):
            # JSON interchange files are UTF-8; do not depend on the locale default.
            with open(os.path.join(root, filename), encoding='utf-8') as json_file:
                return json.load(json_file)

        # ``ImageFolder.__init__`` is skipped, so set the attributes it would
        # normally provide. Base-class helpers such as ``__repr__`` read ``self.root``.
        self.root = root
        self.transform = transform
        self.loader = loader
        self.target_transform = target_transform
        self.year = year
        # assert category in ['kingdom','phylum','class','order','supercategory','family','genus','name']

        train_file = f'train{year}.json'
        data = load_json(train_file if train else f'val{year}.json')
        data_catg = load_json('categories.json')
        # The label mapping always comes from the training annotations; reuse the
        # already-parsed file instead of reading it a second time when train=True.
        data_for_targeter = data if train else load_json(train_file)

        # Map each distinct category value to a contiguous index, first-seen order.
        targeter = {}
        for elem in data_for_targeter['annotations']:
            label_name = data_catg[int(elem['category_id'])][category]
            if label_name not in targeter:
                targeter[label_name] = len(targeter)
        self.nb_classes = len(targeter)

        self.samples = []
        for elem in data['images']:
            # NOTE(review): assumes file_name looks like
            # <split_dir>/<group>/<category_id>/<image>; the second component is
            # dropped when building the on-disk path -- confirm against the data layout.
            cut = elem['file_name'].split('/')
            target_current = int(cut[2])
            path_current = os.path.join(root, cut[0], cut[2], cut[3])
            categors = data_catg[target_current]
            self.samples.append((path_current, targeter[categors[category]]))
        # Per-sample labels, mirroring the ``targets`` attribute of DatasetFolder.
        self.targets = [target for _, target in self.samples]

    # __getitem__ and __len__ inherited from ImageFolder
def build_dataset(is_train, args):
    """Create the dataset selected by ``args.data_set`` and report its class count.

    Args:
        is_train: Build the training split when true, the evaluation split otherwise.
        args: Namespace providing ``data_set``, ``data_path``, the fields read by
            ``build_transform`` and, for the iNaturalist variants, ``inat_category``.

    Returns:
        A ``(dataset, nb_classes)`` tuple.

    Raises:
        ValueError: If ``args.data_set`` is not one of the supported names.
    """
    # The two iNaturalist variants differ only in the annotation year.
    inat_years = {'INAT': 2018, 'INAT19': 2019}
    supported = ('CIFAR', 'CIFAR10', 'IMNET') + tuple(inat_years)

    name = args.data_set
    if name not in supported:
        # Fail early and clearly; previously an unknown name fell through every
        # branch and surfaced as an UnboundLocalError on the return statement.
        raise ValueError(f"Unknown data_set {name!r}; expected one of {supported}")

    transform = build_transform(is_train, args)

    if name == 'CIFAR':
        dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform)
        nb_classes = 100
    elif name == 'CIFAR10':
        dataset = datasets.CIFAR10(args.data_path, train=is_train, transform=transform)
        nb_classes = 10
    elif name == 'IMNET':
        root = os.path.join(args.data_path, 'train' if is_train else 'val')
        dataset = datasets.ImageFolder(root, transform=transform)
        nb_classes = 1000
    else:
        dataset = INatDataset(args.data_path, train=is_train, year=inat_years[name],
                              category=args.inat_category, transform=transform)
        # The class count depends on the chosen category level, so ask the dataset.
        nb_classes = dataset.nb_classes

    return dataset, nb_classes
def build_transform(is_train, args):
    """Build the image preprocessing pipeline for the train or eval split.

    Args:
        is_train: Select the augmenting training pipeline when true, the
            deterministic evaluation pipeline otherwise.
        args: Namespace providing ``input_size`` and, for training,
            ``color_jitter``, ``aa``, ``train_interpolation``, ``reprob``,
            ``remode`` and ``recount``.

    Returns:
        The transform produced by ``create_transform`` (training) or a
        ``transforms.Compose`` of resize/crop/tensor/normalize steps (eval).
    """
    # Inputs of 32 px or smaller skip the resize-based steps.
    resize_im = args.input_size > 32

    if not is_train:
        eval_steps = []
        if resize_im:
            # Resize to 256/224 of the target size first,
            # to maintain same ratio w.r.t. 224 images.
            resized_edge = int((256 / 224) * args.input_size)
            # interpolation=3 is Pillow's integer code for bicubic.
            eval_steps.append(transforms.Resize(resized_edge, interpolation=3))
        eval_steps += [
            transforms.CenterCrop(args.input_size),
            transforms.ToTensor(),
            transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
        ]
        return transforms.Compose(eval_steps)

    # this should always dispatch to transforms_imagenet_train
    train_kwargs = dict(
        input_size=args.input_size,
        is_training=True,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        interpolation=args.train_interpolation,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
    )
    pipeline = create_transform(**train_kwargs)
    if not resize_im:
        # replace RandomResizedCropAndInterpolation with RandomCrop
        pipeline.transforms[0] = transforms.RandomCrop(args.input_size, padding=4)
    return pipeline
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化