加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
train.py 5.47 KB
一键复制 编辑 原始数据 按行查看 历史
赵泽伟 提交于 2017-11-26 10:58 . 1、修改了神经网络结构
import tensorflow as tf
import numpy as np
import code_utils
from model import Model
from image_utils import ImageUtils
import os
# Seeding is left disabled; uncomment to fix TensorFlow's graph-level random seed.
# tf.set_random_seed(1)

# Input placeholder: a batch of single-channel 26 x 70 images, shape (batch, 26, 70, 1).
image = tf.placeholder(tf.float32, [None, 26, 70, 1])

# One 36-wide target placeholder per character position, keyed 'digit1' .. 'digit4'
# and created in that order. The 36 presumably enumerates the symbol alphabet —
# TODO confirm against code_utils.
labels = {
    'digit{}'.format(position): tf.placeholder(tf.float32, [None, 36])
    for position in range(1, 5)
}

# Hyper-parameters and I/O settings. The whole dict is also handed to
# model.build_network, which may read keys this script itself never touches.
training_options = {
    # Forwarded to build_network as `drop_rate`; whether it is a keep or a drop
    # probability is decided in model.py — NOTE(review): confirm there.
    'drop_rate': 0.9,
    'learning_rate': 1e-3,        # learning rate
    'decay_steps': 10000,         # lower the learning rate every this many steps (presumably applied in build_network)
    'decay_rate': 1,              # per-decay step the rate drops by (1 - decay_rate); 1 means no decay
    'batch_size': 32,             # training images per step
    'show_loss': 20,              # not read by this script (original note: "apparently unused")
    'total_episode': 9999999999,  # number of training iterations; effectively "run until interrupted"
    'show_test': 1000,            # every N steps, print test predictions, labels and accuracy
    'output_board': True,         # whether to write TensorBoard summaries
    'logs_step': 10000,           # every N steps, write a TensorBoard summary
    'save_step': 10000,           # every N steps (never at step 0), save a checkpoint
    'log_path': "logs/",          # TensorBoard log directory
    'model_path': 'net_model/',   # checkpoint directory
    'model_name': 'model.ckpt',   # checkpoint file name prefix
}
model = Model()  # builds the network and its training ops via build_network()
config = tf.ConfigProto()
config.gpu_options.allow_growth = True  # grab GPU memory on demand instead of reserving it all up front


def _make_feed(batch_images, batch_labels):
    """Build the feed_dict for one batch.

    ``batch_images`` feeds the ``image`` placeholder and ``batch_labels[:, k]`` feeds
    ``labels['digit<k+1>']`` for k = 0..3. Because every digit placeholder is
    ``[None, 36]``, ``batch_labels`` must be indexable as (batch, 4, 36) — presumably
    the layout returned by ``ImageUtils.sample``; confirm there.
    """
    feed = {image: batch_images}
    for position in range(4):
        feed[labels['digit%d' % (position + 1)]] = batch_labels[:, position]
    return feed


def _match_rates(predicted, expected):
    """Return ``(full_match_pct, char_match_pct)`` for predictions vs. labels.

    ``full_match_pct`` is the percentage of items whose whole prediction equals its
    label; ``char_match_pct`` is the percentage of predicted characters that equal the
    label character at the same position. Both denominators come from the data
    (previously hard-coded as 10 samples / 40 characters), so the numbers stay right
    if the test sample size or sequence length changes. An empty batch yields
    ``(0.0, 0.0)``.
    """
    if not predicted:
        return 0.0, 0.0
    pairs = list(zip(predicted, expected))
    full_hits = sum(1 for guess, truth in pairs if guess == truth)
    char_hits = sum(1 for guess, truth in pairs
                    for guess_char, truth_char in zip(guess, truth)
                    if guess_char == truth_char)
    char_total = sum(len(guess) for guess in predicted)
    full_pct = 100.0 * full_hits / len(predicted)
    char_pct = 100.0 * char_hits / char_total if char_total else 0.0
    return full_pct, char_pct


model_dir = training_options['model_path']
checkpoint_path = os.path.join(model_dir, training_options['model_name'])
# Optional override for how many test images are scored each `show_test` steps.
test_batch_size = training_options.get('test_batch_size', 10)

with tf.Session(config=config) as sess:
    # Make sure the checkpoint directory exists (makedirs also creates missing parents).
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    # `net` holds the per-digit output tensors ('digit1'..'digit4');
    # `train` holds the training tensors ('global_step', 'train', 'loss').
    net, train = model.build_network(training_options=training_options, image=image,
                                     drop_rate=training_options['drop_rate'], labels=labels)
    merged = None
    writer = None
    if training_options['output_board']:
        merged = tf.summary.merge_all()  # tensorflow >= 0.12; None if the graph has no summaries
        writer = tf.summary.FileWriter(training_options['log_path'], sess.graph)  # tensorflow >= 0.12
    # Saver used both to resume from and to write checkpoints.
    saver = tf.train.Saver()
    # Resume from an existing checkpoint if one is present, otherwise start from fresh weights.
    if os.path.exists(os.path.join(model_dir, 'checkpoint')):
        saver.restore(sess, checkpoint_path)
    else:
        sess.run(tf.global_variables_initializer())
    # Load all training and test data up front.
    imageUtils = ImageUtils()
    train_data = imageUtils.train_data
    train_label = imageUtils.train_label
    test_data = imageUtils.test_data
    test_label = imageUtils.test_label
    try:
        for episode in range(training_options['total_episode']):
            # Random training batch of `batch_size` images.
            sample_datas, sample_labels = imageUtils.sample(len(train_data), training_options['batch_size'],
                                                            train_data, train_label)
            feed = _make_feed(sample_datas, sample_labels)
            global_step, _, loss = sess.run(
                [train['global_step'], train['train'], train['loss']], feed_dict=feed)
            # Print [total steps trained so far (survives restores), loss].
            print('total episode: {0:}\t\tloss: {1:.4f}'.format(global_step, loss))
            if writer is not None and merged is not None and episode % training_options['logs_step'] == 0:
                writer.add_summary(sess.run(merged, feed_dict=feed), global_step)
            if episode != 0 and episode % training_options['save_step'] == 0:
                saver.save(sess, checkpoint_path)
            # Periodically score a random test sample and show predictions vs. labels.
            if episode % training_options['show_test'] == 0:
                t_sample_datas, t_sample_labels = imageUtils.sample(len(test_data), test_batch_size,
                                                                    test_data, test_label)
                result = sess.run([net['digit1'], net['digit2'], net['digit3'], net['digit4']],
                                  feed_dict={image: t_sample_datas})
                result = code_utils.batch_out_transition(result)
                # Join the four per-position predictions into one sequence per image.
                predicted = [result[0][index] + result[1][index] + result[2][index] + result[3][index]
                             for index in range(len(result[0]))]
                label = code_utils.batch_out_transition(t_sample_labels)
                four_match_pct, one_match_pct = _match_rates(predicted, label)
                print('predicted:\t{} \nlabel:\t\t{}'.format(predicted, label))
                print('4 match: {0:.2f}%\t\t1 match: {1:.2f}%'.format(four_match_pct, one_match_pct))
    except KeyboardInterrupt:
        # total_episode is effectively unbounded, so Ctrl-C is the normal way to stop;
        # checkpoint now so up to `save_step` steps of progress are not thrown away.
        print('training interrupted, saving model to {}'.format(checkpoint_path))
        saver.save(sess, checkpoint_path)
    finally:
        if writer is not None:
            writer.close()  # flush pending TensorBoard events to disk
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化