加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
tfrecords.py 3.61 KB
一键复制 编辑 原始数据 按行查看 历史
宝贝龙 提交于 2019-09-17 21:57 . sss
#encoding=utf-8
import os
import tensorflow as tf
from PIL import Image,ImageFilter
import numpy as np
import matplotlib.pyplot as plt
import re
type2 = input("请输入格式化模式:1=训练样本,2=测试样本")
print('type',type(type2))
if type2 == '1':
imgPath = './data/img/'
filename = './data/train.tfrecords'
else:
imgPath = './data/testimg/'
filename = './data/test.tfrecords'
#文件名格式 时间-目标坐标x-目标坐标y-黑人物坐标x-黑人物坐标y.png
def getLabel(name):
m = re.match( r'(.*?)-(.*?)-(.*?)-(.*?)-(.*?).png', name, re.M|re.I)
if m:
label = np.array([m.group(2),m.group(3),m.group(4),m.group(5)]).astype(int).tolist()
return label
else:
print("No match!!")
#制作二进制数据
def create_record():
writer = tf.python_io.TFRecordWriter(filename)
for i in os.listdir(imgPath):
img = Image.open(imgPath+i)
img = img.convert('L')
img_raw = img.tobytes() #将图片转化为原生bytes
label = getLabel(i)
label_raw = bytes(label)
example = tf.train.Example(
features=tf.train.Features(feature={
"label": tf.train.Feature(bytes_list=tf.train.BytesList(value=[label_raw])),
'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
}))
writer.write(example.SerializeToString())
writer.close()
data = create_record()
#读取二进制数据
def read_and_decode():
# 创建文件队列,不限读取的数量
filename_queue = tf.train.string_input_producer([filename])
# create a reader from file queue
reader = tf.TFRecordReader()
# reader从文件队列中读入一个序列化的样本
_, serialized_example = reader.read(filename_queue)
# get feature from serialized example
# 解析符号化的样本
features = tf.parse_single_example(
serialized_example,
features={
'label': tf.FixedLenFeature([], tf.string),
'img_raw': tf.FixedLenFeature([], tf.string)
}
)
label = features['label']
img = features['img_raw']
img = tf.decode_raw(img, tf.uint8)
img = tf.reshape(img, [224, 224, 1])
img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
label = tf.decode_raw(label, tf.uint8)
label = tf.reshape(label, [4])
label = tf.cast(label, tf.float32)
return img, label
img, label = read_and_decode()
print("tengxing",img,label)
#使用shuffle_batch可以随机打乱输入 next_batch挨着往下取
# shuffle_batch才能实现[img,label]的同步,也即特征和label的同步,不然可能输入的特征和label不匹配
# 比如只有这样使用,才能使img和label一一对应,每次提取一个image和对应的label
# shuffle_batch返回的值就是RandomShuffleQueue.dequeue_many()的结果
# Shuffle_batch构建了一个RandomShuffleQueue,并不断地把单个的[img,label],送入队列中
img_batch, label_batch = tf.train.shuffle_batch([img, label],
batch_size=1, capacity=2000,
min_after_dequeue=1000)
# 初始化所有的op
init = tf.initialize_all_variables()
with tf.Session() as sess:
sess.run(init)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
fig = plt.figure()
for i in range(1):
val, l = sess.run([img_batch, label_batch])
print(val.shape, l.shape)
plt.subplot(1,1,i+1)
plt.imshow(val[0].reshape(224,224),cmap='gray')
plt.show()
coord.request_stop()
coord.join(threads)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化