代码拉取完成,页面将自动刷新
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import os
def _build_cifar_loaders(args, dataset_cls, mean, std):
    """Build train/test DataLoaders for a torchvision CIFAR-style dataset.

    Args:
        args: Namespace-like object; ``args.batch_size`` and ``args.workers``
            are forwarded to both DataLoaders.
        dataset_cls: Dataset class to instantiate (called with
            ``datasets.CIFAR10`` or ``datasets.CIFAR100`` by ``load_data``).
        mean: Per-channel means handed to ``transforms.Normalize``.
        std: Per-channel standard deviations handed to
            ``transforms.Normalize``.

    Returns:
        Tuple ``(train_loader, test_loader)``.
    """
    normalize = transforms.Normalize(mean, std)

    # Training pipeline: upscale to 224, then pad-and-crop plus random flip
    # as augmentation before tensor conversion and normalization.
    transform_train = transforms.Compose([
        transforms.Resize(224),
        transforms.RandomCrop(224, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    # Evaluation pipeline: same resize and normalization, no augmentation.
    transform_test = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        normalize,
    ])

    # Only the training split requests a download; the test split is
    # constructed afterwards from the same 'data' root without download=True.
    train_loader = torch.utils.data.DataLoader(
        dataset_cls('data', train=True, download=True, transform=transform_train),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
    )
    test_loader = torch.utils.data.DataLoader(
        dataset_cls('data', train=False, transform=transform_test),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
    )
    return train_loader, test_loader


def _build_imagenet_loaders(args):
    """Build train/test DataLoaders from an ImageFolder directory layout.

    Reads ``<args.data>/train`` and ``<args.data>/val`` via
    ``datasets.ImageFolder``.

    Args:
        args: Namespace-like object providing ``data`` (dataset root),
            ``distributed`` (truthy to use a ``DistributedSampler`` for the
            training set), ``batch_size`` and ``workers``.

    Returns:
        Tuple ``(train_loader, test_loader)``.
    """
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]),
    )
    # To inspect the class labels while debugging: train_dataset.classes

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None

    # DataLoader does not accept shuffle=True together with a sampler, so
    # shuffling is enabled only when no sampler is in use.
    # NOTE(review): the sampler is not returned; callers that need it
    # (e.g. for per-epoch reseeding) can reach it via train_loader.sampler.
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=(train_sampler is None),
        num_workers=args.workers,
        pin_memory=True,
        sampler=train_sampler,
    )
    test_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
    )
    return train_loader, test_loader


def load_data(args):
    """Create the training and test DataLoaders selected by ``args.dataset_mode``.

    Args:
        args: Namespace-like object. Always read: ``dataset_mode``
            (one of ``"CIFAR100"``, ``"CIFAR10"``, ``"IMAGENET"``),
            ``batch_size`` and ``workers``. For ``"IMAGENET"`` the attributes
            ``data`` and ``distributed`` are read as well.

    Returns:
        Tuple ``(train_loader, test_loader)`` of
        ``torch.utils.data.DataLoader`` objects.

    Raises:
        ValueError: If ``args.dataset_mode`` is not one of the supported
            values. (Previously this fell through to the ``return`` and
            failed with an opaque ``UnboundLocalError``.)
    """
    mode = args.dataset_mode

    if mode == "CIFAR100":
        return _build_cifar_loaders(
            args,
            datasets.CIFAR100,
            (0.5071, 0.4865, 0.4409),
            (0.2673, 0.2564, 0.2762),
        )
    if mode == "CIFAR10":
        return _build_cifar_loaders(
            args,
            datasets.CIFAR10,
            (0.4914, 0.4822, 0.4465),
            (0.2023, 0.1994, 0.2010),
        )
    if mode == "IMAGENET":
        return _build_imagenet_loaders(args)

    raise ValueError(
        "Unsupported dataset_mode {!r}; expected one of "
        "'CIFAR100', 'CIFAR10', 'IMAGENET'".format(mode)
    )
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。