# text_loader.py
import os
import re
import zlib
from random import randint

import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

# Raise Pillow's decompression-bomb threshold so very large images do not
# trigger DecompressionBombError when this loader is used alongside image data.
Image.MAX_IMAGE_PIXELS = 2300000000
class TextDataset(Dataset):
    def __init__(self,
                 text_file=None,
                 num_images=32,
                 text_len=256,
                 tokenizer=None,
                 shuffle=False,
                 backend=None
                 ):
        """
        @param text_file: UTF-8 text file with one caption per line; blank lines are skipped.
        @param num_images: number of dataset entries generated per caption.
        @param text_len: fixed token-sequence length passed to the tokenizer.
        @param tokenizer: object exposing tokenize(texts, text_len) -> token tensor.
        @param shuffle: if True, skip_sample() falls back to a random sample.
        @param backend: storage backend handle (stored for interface compatibility; unused here).
        """
        super().__init__()
        self.shuffle = shuffle
        print('self.shuffle', self.shuffle)
        print('getting text from file', text_file)
        with open(text_file, encoding='utf-8') as text_file_fp:
            text_file_lines = text_file_fp.readlines()
        texts = [line.strip() for line in tqdm(text_file_lines) if line.strip() != ""]
        print("len(texts)", len(texts))
        texts_tokens = tokenizer.tokenize(texts, text_len)
        print("texts_tokens.shape", texts_tokens.shape)
        # Every caption owns `num_images` consecutive keys, so the dataset
        # exposes len(texts) * num_images entries in total.
        keys = range(0, len(texts) * num_images)
        self.keys = keys
        self.text_len = text_len
        self.tokenizer = tokenizer
        self.backend = backend
        self.texts = texts
        self.num_images = num_images
        self.texts_tokens = texts_tokens
    def __len__(self):
        return len(self.keys)

    def random_sample(self):
        return self.__getitem__(randint(0, self.__len__() - 1))

    def sequential_sample(self, ind):
        # Wrap around to the first entry at the end of the dataset.
        if ind >= self.__len__() - 1:
            return self.__getitem__(0)
        return self.__getitem__(ind + 1)

    def skip_sample(self, ind):
        # Fallback for skipping a bad sample: random when shuffling,
        # otherwise the next entry in sequence.
        if self.shuffle:
            return self.random_sample()
        return self.sequential_sample(ind=ind)

    def get_text(self):
        return self.texts if self.texts else None
    def __getitem__(self, ind):
        key = self.keys[ind]
        # Integer division recovers the caption; the remainder selects which
        # of that caption's `num_images` slots this entry represents.
        text = self.texts[key // self.num_images]
        text_token = self.texts_tokens[key // self.num_images]
        index = key % self.num_images
        return text, text_token, index
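    # Illustrative mapping (example values, not from the original file): with
    # num_images == 3, ind = 7 gives caption index 7 // 3 == 2 and image slot
    # 7 % 3 == 1, i.e. the second of the third caption's three entries.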
    def getfiles(self, dirPath, pattern=r'.*\.txt'):
        """List files directly under dirPath whose names match `pattern`
        (a regex tested with re.match); subdirectories are not recursed."""
        fileList = []
        if dirPath is None:
            return fileList
        ptn = re.compile(pattern)
        for f in os.listdir(dirPath):
            if os.path.isfile(os.path.join(dirPath, f)) and ptn.match(f):
                fileList.append(os.path.join(dirPath, f))
        return fileList
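    # Hedged usage sketch (hypothetical directory): getfiles('/data/captions')
    # returns the joined paths of the *.txt files directly under /data/captions.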
    def hash(self, url):
        # Deterministic 32-bit key for a URL: zlib.crc32 masked to an
        # unsigned value in [0, 2**32).
        return zlib.crc32(url.encode('utf-8')) & 0xffffffff
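
# TextDataset assumes an external tokenizer exposing
# tokenize(texts, text_len) that returns a LongTensor of shape
# (len(texts), text_len). The stand-in below is a hypothetical sketch for
# local experimentation only; it is not the project's real tokenizer.
class StubCharTokenizer:
    """Minimal character-level tokenizer matching the assumed interface."""

    def tokenize(self, texts, text_len):
        # Map each character to a nonzero id, truncate to text_len,
        # then right-pad with 0.
        rows = [[ord(c) % 999 + 1 for c in t][:text_len] for t in texts]
        return torch.tensor(
            [row + [0] * (text_len - len(row)) for row in rows],
            dtype=torch.long,
        )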
if __name__ == '__main__':
    text_file = "/share/home/ai_chenzhiyang/DALLE-trainfromcode/test_texts/test_text_file.txt"
    num_images = 3
    text_seq_len = 256
    BATCH_SIZE = 5
    is_shuffle = True

    from tokenizer import tokenizer

    # Keep the instance name distinct from the class name.
    dataset = TextDataset(text_file=text_file,
                          num_images=num_images,
                          text_len=text_seq_len,
                          tokenizer=tokenizer,
                          shuffle=False,
                          backend=None)
    # num_replicas and rank are passed explicitly, so no initialized
    # torch.distributed process group is required. Shuffling is delegated to
    # the sampler, so the DataLoader must be created with shuffle=False.
    data_sampler = torch.utils.data.distributed.DistributedSampler(
        dataset,
        num_replicas=1,
        rank=0,
        shuffle=is_shuffle,
    )
    data_loader = DataLoader(dataset,
                             batch_size=BATCH_SIZE,
                             drop_last=False,
                             sampler=data_sampler,
                             shuffle=False)
    data_sampler.set_epoch(1)
    for i, (texts, text_token, indices) in enumerate(data_loader):
        print("i", i)
        print("texts", texts)
        print("indices", indices)