加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
dataset.py 1.58 KB
一键复制 编辑 原始数据 按行查看 历史
GanZhenliang 提交于 2024-10-23 10:44 . commit
import os
import numpy as np
from glob import glob
from PIL import Image, ImageOps
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch
class StegaData(Dataset):
def __init__(self, data_path, secret_size=100, size=(400, 400)):
self.data_path = data_path
self.secret_size = secret_size
self.size = size
self.files_list = glob(os.path.join(self.data_path, '*.jpg'))
self.to_tensor = transforms.ToTensor()
def __getitem__(self, idx):
img_cover_path = self.files_list[idx]
img_cover = Image.open(img_cover_path).convert('RGB')
img_cover = ImageOps.fit(img_cover, self.size)
img_cover = self.to_tensor(img_cover)
# img_cover = np.array(img_cover, dtype=np.float32) / 255.
secret = np.random.binomial(1, 0.5, self.secret_size)
secret = torch.from_numpy(secret).float()
return img_cover, secret
def __len__(self):
return len(self.files_list)
if __name__ == '__main__':
# dataset = StegaData(data_path='F:\\VOCdevkit\\VOC2012\\JPEGImages')
# print(len(dataset))
# img_cover, secret = dataset[10]
# print(type(img_cover), type(secret))
# print(img_cover.shape, secret.shape)
dataset = StegaData(data_path=r'E:\dataset\mirflickr', secret_size=100, size=(400, 400))
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, pin_memory=True)
image_input, secret_input = next(iter(dataloader))
print(type(image_input), type(secret_input))
print(image_input.shape, secret_input.shape)
print(image_input.max())
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化