加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
text_model.py 4.97 KB
一键复制 编辑 原始数据 按行查看 历史
Jinchengxin 提交于 2020-05-16 13:21 . 完成第四次训练
#encoding:utf-8
import tensorflow as tf
class TextConfig():
embedding_size=100 #dimension of word embedding
vocab_size=23051 #number of vocabulary
pre_trianing = None #use vector_char trained by word2vec
seq_length=256 #max length of sentence
num_classes=34 #number of labels
num_filters=128 #number of convolution kernel
filter_sizes=[2,3,4] #size of convolution kernel
keep_prob=0.5 #droppout
lr= 1e-3 #learning rate
lr_decay= 0.9 #learning rate decay
clip= 6.0 #gradient clipping threshold
l2_reg_lambda=0.01 #l2 regularization lambda
num_epochs=30 #epochs
batch_size=64 #batch_size
print_per_batch =100 #print result
train_filename='./data/v2.0Train.txt' #train data
test_filename='./data/v2.0Test.txt' #test data
val_filename='./data/v2.0Val.txt' #validation data
vocab_filename='./data/vocab.txt' #vocabulary
vector_word_filename='./data/vector_word.txt' #vector_word trained by word2vec
vector_word_npz='./data/vector_word.npz' # save vector_word to numpy file
class TextCNN(object):
def __init__(self,config):
self.config=config
self.input_x=tf.placeholder(tf.int32,shape=[None,self.config.seq_length],name='input_x')
self.input_y=tf.placeholder(tf.float32,shape=[None,self.config.num_classes],name='input_y')
self.keep_prob=tf.placeholder(tf.float32,name='dropout')
self.global_step = tf.Variable(0, trainable=False, name='global_step')
self.l2_loss = tf.constant(0.0)
self.cnn()
def cnn(self):
with tf.device('/cpu:0'):
self.embedding = tf.get_variable("embeddings", shape=[self.config.vocab_size, self.config.embedding_size],
initializer=tf.constant_initializer(self.config.pre_trianing))
self.embedding_inputs= tf.nn.embedding_lookup(self.embedding, self.input_x)
self.embedding_inputs_expanded = tf.expand_dims(self.embedding_inputs, -1)
with tf.name_scope('cnn'):
pooled_outputs = []
for i, filter_size in enumerate(self.config.filter_sizes):
with tf.name_scope("conv-maxpool-%s" % filter_size):
filter_shape = [filter_size, self.config.embedding_size, 1, self.config.num_filters]
W = tf.Variable(tf.truncated_normal(filter_shape, stddev=0.1), name="W")
b = tf.Variable(tf.constant(0.1, shape=[self.config.num_filters]), name="b")
conv = tf.nn.conv2d(
self.embedding_inputs_expanded,
W,
strides=[1, 1, 1, 1],
padding="VALID",
name="conv")
h = tf.nn.relu(tf.nn.bias_add(conv, b), name="relu")
pooled = tf.nn.max_pool(
h,
ksize=[1, self.config.seq_length - filter_size + 1, 1, 1],
strides=[1, 1, 1, 1],
padding='VALID',
name="pool")
pooled_outputs.append(pooled)
num_filters_total = self.config.num_filters * len(self.config.filter_sizes)
self.h_pool = tf.concat(pooled_outputs, 3)
self.outputs= tf.reshape(self.h_pool, [-1, num_filters_total])
with tf.name_scope("dropout"):
self.final_output = tf.nn.dropout(self.outputs, self.keep_prob)
with tf.name_scope('output'):
fc_w = tf.get_variable('fc_w', shape=[self.final_output.shape[1].value, self.config.num_classes],
initializer=tf.contrib.layers.xavier_initializer())
fc_b = tf.Variable(tf.constant(0.1, shape=[self.config.num_classes]), name='fc_b')
self.logits = tf.matmul(self.final_output, fc_w) + fc_b
self.prob=tf.nn.softmax(self.logits)
self.y_pred_cls = tf.argmax(self.logits, 1, name='predictions')
with tf.name_scope('loss'):
cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=self.input_y)
self.l2_loss += tf.nn.l2_loss(fc_w)
self.l2_loss += tf.nn.l2_loss(fc_b)
self.loss = tf.reduce_mean(cross_entropy) + self.config.l2_reg_lambda * self.l2_loss
self.loss = tf.reduce_mean(cross_entropy)
with tf.name_scope('optimizer'):
optimizer = tf.train.AdamOptimizer(self.config.lr)
gradients, variables = zip(*optimizer.compute_gradients(self.loss))
gradients, _ = tf.clip_by_global_norm(gradients, self.config.clip)
self.optim = optimizer.apply_gradients(zip(gradients, variables), global_step=self.global_step)
with tf.name_scope('accuracy'):
correct_pred=tf.equal(tf.argmax(self.input_y,1),self.y_pred_cls)
self.acc=tf.reduce_mean(tf.cast(correct_pred,tf.float32))
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化