加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
utils.py 5.07 KB
一键复制 编辑 原始数据 按行查看 历史
lear 提交于 2021-02-19 10:21 . Add crnn scripts of pytorch
#!/usr/bin/python
# encoding: utf-8
import collections
import collections.abc
import numbers

import numpy as np
import torch
import torch.nn as nn
from PIL import Image, ImageFile
from torch.autograd import Variable

from config import config1, label_dict
# Alphabet as a list of single characters, in label_dict order; a character's
# position in this list is the label used by labels_to_text / text_to_labels.
letters = list(label_dict)
def labels_to_text(labels):
    """Turn a sequence of integer labels into a string.

    Each label is converted with ``int()`` and used as an index into the
    module-level ``letters`` list; the resulting characters are concatenated.
    """
    chars = [letters[int(label)] for label in labels]
    return ''.join(chars)
def text_to_labels(text):
    """Turn a string into a list of integer labels.

    Every character is lower-cased and replaced by its first position in the
    module-level ``letters`` list (``ValueError`` if it is not present).
    """
    return [letters.index(ch.lower()) for ch in text]
class resizeNormalize(object):
    """Resize a PIL image and return it as a normalized CHW float32 array.

    Pixel values are mapped from [0, 255] to [-1, 1] via ``(x / 255 - 0.5) / 0.5``.

    Args:
        size: target size forwarded to ``PIL.Image.Image.resize``
            (PIL expects ``(width, height)``).
        interpolation: PIL resampling filter, bilinear by default.
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        """Resize ``img`` and return a ``(channels, height, width)`` float32 array.

        Single-channel images (e.g. mode ``"L"``) yield a 2-D array from
        ``np.array``; a channel axis is added so they produce a ``(1, H, W)``
        result instead of failing in ``np.transpose``.
        """
        img = img.resize(self.size, self.interpolation)
        arr = np.array(img)
        # Normalize in float64 first, then cast, to keep the original numerics.
        arr = ((arr / 255.0) - 0.5) / 0.5
        arr = arr.astype(np.float32)
        if arr.ndim == 2:
            arr = arr[:, :, np.newaxis]
        # HWC -> CHW
        return np.transpose(arr, (2, 0, 1))
class strLabelConverter(object):
    """Convert between strings and integer label sequences for CTC.

    NOTE:
        ``decode`` treats label 0 as the CTC ``blank`` and maps any other
        label ``c`` to ``alphabet[c - 1]``.

    NOTE(review): ``encode`` maps ``alphabet[i]`` to ``i`` (no +1 offset), so
    ``encode``/``decode`` are not inverses and ``alphabet[0]`` encodes to the
    blank label. Confirm which convention the model's training labels use
    before relying on a round trip.

    Args:
        label_dict (str): the possible characters (the alphabet).
        ignore_case (bool, default=True): lower-case the alphabet and the
            input text before encoding.
        config: accepted for backward compatibility; not used by this class.
    """

    def __init__(self, label_dict=label_dict, ignore_case=True, config=config1):
        self._ignore_case = ignore_case
        if self._ignore_case:
            label_dict = label_dict.lower()
        self.alphabet = label_dict
        # char -> position in the alphabet; a repeated char keeps its last position.
        self.dict = {char: i for i, char in enumerate(label_dict)}

    def encode(self, text):
        """Support batch or single str.

        Args:
            text (str or iterable of str): texts to convert.

        Returns:
            torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
            torch.IntTensor [n]: length of each text.

        Raises:
            KeyError: if a character is not in the alphabet.
            TypeError: if ``text`` is neither a str nor an iterable of str.
        """
        if isinstance(text, str):
            labels = [
                self.dict[char.lower() if self._ignore_case else char]
                for char in text
            ]
            return torch.IntTensor(labels), torch.IntTensor([len(labels)])
        # collections.Iterable was removed in Python 3.10; use collections.abc.
        if isinstance(text, collections.abc.Iterable):
            # Materialize first: a one-shot iterable (e.g. a generator) would
            # otherwise be exhausted by the length pass and join to "".
            texts = list(text)
            lengths = [len(s) for s in texts]
            labels, _ = self.encode(''.join(texts))
            return labels, torch.IntTensor(lengths)
        raise TypeError(
            "expected str or iterable of str, got {}".format(type(text).__name__))

    def decode(self, t, length, raw=False):
        """Decode encoded texts back into strs.

        Args:
            t: torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
            length: torch.IntTensor [n]: length of each text.
            raw (bool): if True, map every label without removing blanks or
                collapsing repeats.

        Raises:
            AssertionError: when the texts and their lengths do not match.
                Raised explicitly so the check survives ``python -O``.

        Returns:
            str when ``length`` has a single element, otherwise a list of str.
        """
        if length.numel() == 1:
            length = int(length.reshape(-1)[0])
            if t.numel() != length:
                raise AssertionError(
                    "text with length: {} does not match declared length: {}".format(t.numel(), length))
            # Work on plain ints rather than indexing 0-d tensors per element.
            codes = t.reshape(-1).tolist()
            if raw:
                return ''.join([self.alphabet[c - 1] for c in codes])
            char_list = []
            for i, c in enumerate(codes):
                # CTC best-path: drop blanks (0) and collapse consecutive repeats.
                if c != 0 and not (i > 0 and codes[i - 1] == c):
                    char_list.append(self.alphabet[c - 1])
            return ''.join(char_list)

        # batch mode
        total = int(length.sum())
        if t.numel() != total:
            raise AssertionError(
                "texts with length: {} does not match declared length: {}".format(t.numel(), total))
        texts = []
        index = 0
        for l in length.reshape(-1).tolist():
            texts.append(
                self.decode(
                    t[index:index + l], torch.IntTensor([l]), raw=raw))
            index += l
        return texts
class averager(object):
    """Running mean over every element passed to :meth:`add`.

    Accepts tensors (legacy ``Variable`` objects are plain tensors since
    PyTorch 0.4) and real Python/NumPy numbers.
    """

    def __init__(self):
        self.reset()

    def add(self, v):
        """Accumulate the elements of ``v`` into the running sum and count.

        Args:
            v: a ``torch.Tensor`` (every element counts once) or a real number
                (counts as one element).

        Raises:
            TypeError: if ``v`` is neither a tensor nor a real number.
        """
        if isinstance(v, torch.Tensor):
            # Detach so the running sum never keeps an autograd graph alive.
            v = v.detach()
            count = v.numel()
            v = v.sum()
        elif isinstance(v, numbers.Real):
            count = 1
        else:
            raise TypeError(
                "averager.add expects a tensor or a real number, got {}".format(type(v).__name__))
        self.n_count += count
        self.sum += v

    def reset(self):
        """Forget everything accumulated so far."""
        self.n_count = 0
        self.sum = 0

    def val(self):
        """Return the mean of all added elements, or 0 if nothing was added."""
        if self.n_count == 0:
            return 0
        return self.sum / float(self.n_count)
def oneHot(v, v_length, nc):
    """Scatter concatenated label sequences into a padded one-hot tensor.

    Args:
        v: 1-D tensor holding all label sequences concatenated back to back.
        v_length: 1-D tensor with the length of each sequence in ``v``.
        nc: number of classes (size of the one-hot dimension).

    Returns:
        float32 tensor of shape ``(len(v_length), max(v_length), nc)``;
        positions past a sequence's own length stay all-zero. An empty
        ``v_length`` yields shape ``(0, 0, nc)``.
    """
    # Plain ints: 0-d tensors are not reliable size arguments, and max() of an
    # empty tensor raises.
    lengths = [int(n) for n in v_length.reshape(-1).tolist()]
    batch_size = len(lengths)
    max_length = max(lengths, default=0)
    v_onehot = torch.zeros(batch_size, max_length, nc, dtype=torch.float32)
    offset = 0
    for i, length in enumerate(lengths):
        label = v[offset:offset + length].reshape(-1, 1).long()
        v_onehot[i, :length].scatter_(1, label, 1.0)
        offset += length
    return v_onehot
def loadData(v, data):
    """Overwrite the contents of tensor ``v`` with ``data``, in place.

    The copy goes through ``v.data``, so it is not recorded by autograd and
    also works on leaf tensors that require grad.
    """
    target = v.data
    target.copy_(data)
def prettyPrint(v):
    """Print the size and type of tensor ``v``, then its max, min and mean."""
    print('Size {0}, Type: {1}'.format(str(v.size()), v.data.type()))
    # The reductions return 0-dim tensors: indexing them with [0] raises
    # IndexError on modern PyTorch, so extract Python scalars with .item().
    # mean() is undefined for integer tensors, so compute it in double there.
    mean_src = v if v.is_floating_point() else v.double()
    print('| Max: %f | Min: %f | Mean: %f' % (v.max().item(), v.min().item(),
                                             mean_src.mean().item()))
def assureRatio(img):
    """Ensure imgH <= imgW.

    Args:
        img: 4-D tensor ``(batch, channels, height, width)``.

    Returns:
        ``img`` itself when height <= width; otherwise ``img`` bilinearly
        resized to ``(height, height)``.
    """
    _, _, h, w = img.size()
    if h > w:
        # Equivalent to nn.UpsamplingBilinear2d(size=(h, h)), which is
        # interpolate(mode='bilinear', align_corners=True), without building a
        # deprecated module on every call.
        img = nn.functional.interpolate(
            img, size=(h, h), mode='bilinear', align_corners=True)
    return img
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化