加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
similarity.py 10.78 KB
一键复制 编辑 原始数据 按行查看 历史
stopit 提交于 2019-10-24 19:38 . Add files via upload
"""
进行文本相似度预测的示例。可以直接运行进行预测。
参考了项目:https://github.com/chdd/bert-utils
"""
import tensorflow as tf
import args
import tokenization
import modeling
from run_classifier import InputFeatures, InputExample, DataProcessor, create_model, convert_examples_to_features
# os.environ['CUDA_VISIBLE_DEVICES'] = '1'
class SimProcessor(DataProcessor):
    """Data processor that turns raw sentence pairs into prediction examples."""

    def get_sentence_examples(self, questions):
        """Build one `InputExample` per sentence pair.

        Args:
          questions: iterable of pairs; element 0 becomes `text_a` and
            element 1 becomes `text_b` (each passed through `str()` and
            `tokenization.convert_to_unicode`).

        Returns:
          list of `InputExample` with guids 'test-0', 'test-1', ... and a
          placeholder label of '0' (labels are not used for prediction).
        """
        return [
            InputExample(
                guid='test-%d' % position,
                text_a=tokenization.convert_to_unicode(str(pair[0])),
                text_b=tokenization.convert_to_unicode(str(pair[1])),
                label='0',
            )
            for position, pair in enumerate(questions)
        ]

    def get_labels(self):
        """Return the binary label set: '0' (not similar) / '1' (similar)."""
        labels = ['0', '1']
        return labels
"""
模型类,负责载入checkpoint初始化模型
"""
class BertSim:
    """BERT sentence-pair similarity predictor.

    `start_model()` builds a `tf.estimator.Estimator` around the fine-tuned
    classifier; `predict_sentences()` runs prediction on (text_a, text_b)
    pairs and calls `start_model()` itself if the estimator is not built yet.
    """

    def __init__(self, batch_size=args.batch_size):
        # Controls which checkpoint `get_estimator` initializes from: TRAIN uses
        # `args.ckpt_name`, anything else (including the default None) uses
        # `args.output_dir`.
        self.mode = None
        self.max_seq_length = args.max_seq_len
        self.tokenizer = tokenization.FullTokenizer(vocab_file=args.vocab_file, do_lower_case=True)
        self.batch_size = batch_size
        # Created by `start_model()`.
        self.estimator = None
        self.processor = SimProcessor()
        tf.logging.set_verbosity(tf.logging.INFO)

    def start_model(self):
        """Load the estimator and build the model graph definition."""
        self.estimator = self.get_estimator()

    def model_fn_builder(self, bert_config, num_labels, init_checkpoint, learning_rate,
                         num_train_steps, num_warmup_steps,
                         use_one_hot_embeddings):
        """Return a `model_fn` closure for `tf.estimator.Estimator`.

        The returned `model_fn` only emits predictions (the `probabilities`
        output of `create_model`) and defines no loss or train op, so it is
        only usable for PREDICT. `learning_rate`, `num_train_steps` and
        `num_warmup_steps` are accepted for signature compatibility but unused.

        Args:
          bert_config: `modeling.BertConfig` for the network.
          num_labels: number of output classes.
          init_checkpoint: checkpoint to initialize trainable variables from;
            falsy to skip initialization.
          use_one_hot_embeddings: forwarded to `create_model`.
        """

        def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
            tf.logging.info("*** Features ***")
            for name in sorted(features.keys()):
                tf.logging.info(" name = %s, shape = %s", name, features[name].shape)

            input_ids = features["input_ids"]
            input_mask = features["input_mask"]
            segment_ids = features["segment_ids"]
            label_ids = features["label_ids"]
            is_training = (mode == tf.estimator.ModeKeys.TRAIN)

            # Only the probabilities are used; loss and logits are discarded.
            (_total_loss, _per_example_loss, _logits, probabilities) = create_model(
                bert_config, is_training, input_ids, input_mask, segment_ids, label_ids,
                num_labels, use_one_hot_embeddings)

            tvars = tf.trainable_variables()
            initialized_variable_names = {}
            if init_checkpoint:
                (assignment_map, initialized_variable_names) = \
                    modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

            tf.logging.info("**** Trainable Variables ****")
            for var in tvars:
                init_string = ""
                if var.name in initialized_variable_names:
                    init_string = ", *INIT_FROM_CKPT*"
                tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape,
                                init_string)

            return tf.estimator.EstimatorSpec(mode=mode, predictions=probabilities)

        return model_fn

    def get_estimator(self):
        """Build the `tf.estimator.Estimator` for the similarity model.

        Initializes from `args.ckpt_name` when `self.mode` is TRAIN, otherwise
        from `args.output_dir`, which is also used as the estimator's
        `model_dir`.
        """
        bert_config = modeling.BertConfig.from_json_file(args.config_name)
        label_list = self.processor.get_labels()
        if self.mode == tf.estimator.ModeKeys.TRAIN:
            init_checkpoint = args.ckpt_name
        else:
            init_checkpoint = args.output_dir
        model_fn = self.model_fn_builder(
            bert_config=bert_config,
            num_labels=len(label_list),
            init_checkpoint=init_checkpoint,
            learning_rate=args.learning_rate,
            num_train_steps=None,
            num_warmup_steps=None,
            use_one_hot_embeddings=False)

        config = tf.ConfigProto()
        # Allocate GPU memory on demand and limit this process to the
        # configured fraction of GPU memory.
        config.gpu_options.allow_growth = True
        config.gpu_options.per_process_gpu_memory_fraction = args.gpu_memory_fraction
        config.log_device_placement = False

        return tf.estimator.Estimator(
            model_fn=model_fn,
            config=tf.estimator.RunConfig(session_config=config),
            model_dir=args.output_dir,
            params={'batch_size': self.batch_size})

    def predict_sentences(self, sentences):
        """Predict similarity for sentence pairs, printing each result batch.

        Builds the estimator first if `start_model()` has not been called
        (previously this crashed with AttributeError on None).

        Args:
          sentences: list of (text_a, text_b) pairs.

        Returns:
          list of the prediction batches yielded by `Estimator.predict` with
          `yield_single_examples=False` (the `probabilities` output).
        """
        if self.estimator is None:
            self.start_model()
        results = []
        for batch in self.estimator.predict(input_fn=input_fn_builder(self, sentences),
                                            yield_single_examples=False):
            # Print the prediction results as they arrive.
            print(batch)
            results.append(batch)
        return results

    def _truncate_seq_pair(self, tokens_a, tokens_b, max_length):
        """Truncates a sequence pair in place to the maximum length."""
        # This is a simple heuristic which will always truncate the longer sequence
        # one token at a time. This makes more sense than truncating an equal percent
        # of tokens from each, since if one sequence is very short then each token
        # that's truncated likely contains more information than a longer sequence.
        while True:
            total_length = len(tokens_a) + len(tokens_b)
            if total_length <= max_length:
                break
            if len(tokens_a) > len(tokens_b):
                tokens_a.pop()
            else:
                tokens_b.pop()

    def convert_single_example(self, ex_index, example, label_list, max_seq_length, tokenizer):
        """Converts a single `InputExample` into a single `InputFeatures`.

        Not called elsewhere in this module; the prediction path uses
        `convert_examples_to_features` from `run_classifier`.

        Args:
          ex_index: index of the example; the first 5 are logged.
          example: `InputExample` with `text_a`, optional `text_b`, `label`.
          label_list: labels; each maps to its position as the label id.
          max_seq_length: length every feature list is padded to.
          tokenizer: tokenizer with `tokenize` and `convert_tokens_to_ids`.
        """
        label_map = {label: i for (i, label) in enumerate(label_list)}

        tokens_a = tokenizer.tokenize(example.text_a)
        tokens_b = None
        if example.text_b:
            tokens_b = tokenizer.tokenize(example.text_b)

        if tokens_b:
            # Modifies `tokens_a` and `tokens_b` in place so that the total
            # length is at most `max_seq_length - 3`.
            # Account for [CLS], [SEP], [SEP] with "- 3"
            self._truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
        else:
            # Account for [CLS] and [SEP] with "- 2"
            if len(tokens_a) > max_seq_length - 2:
                tokens_a = tokens_a[0:(max_seq_length - 2)]

        # The convention in BERT is:
        # (a) For sequence pairs:
        #  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
        #  type_ids: 0     0  0    0    0     0       0 0     1  1  1  1   1 1
        # (b) For single sequences:
        #  tokens:   [CLS] the dog is hairy . [SEP]
        #  type_ids: 0     0   0   0  0     0 0
        #
        # Where "type_ids" are used to indicate whether this is the first
        # sequence or the second sequence. The embedding vectors for `type=0` and
        # `type=1` were learned during pre-training and are added to the wordpiece
        # embedding vector (and position vector). This is not *strictly* necessary
        # since the [SEP] token unambiguously separates the sequences, but it makes
        # it easier for the model to learn the concept of sequences.
        #
        # For classification tasks, the first vector (corresponding to [CLS]) is
        # used as the "sentence vector". Note that this only makes sense because
        # the entire model is fine-tuned.
        tokens = []
        segment_ids = []
        tokens.append("[CLS]")
        segment_ids.append(0)
        for token in tokens_a:
            tokens.append(token)
            segment_ids.append(0)
        tokens.append("[SEP]")
        segment_ids.append(0)

        if tokens_b:
            for token in tokens_b:
                tokens.append(token)
                segment_ids.append(1)
            tokens.append("[SEP]")
            segment_ids.append(1)

        input_ids = tokenizer.convert_tokens_to_ids(tokens)

        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        input_mask = [1] * len(input_ids)

        # Zero-pad up to the sequence length (no-op if already at/over it).
        pad_count = max(0, max_seq_length - len(input_ids))
        input_ids.extend([0] * pad_count)
        input_mask.extend([0] * pad_count)
        segment_ids.extend([0] * pad_count)

        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length

        label_id = label_map[example.label]
        if ex_index < 5:
            tf.logging.info("*** Example ***")
            tf.logging.info("guid: %s" % (example.guid))
            tf.logging.info("tokens: %s" % " ".join(
                [tokenization.printable_text(x) for x in tokens]))
            tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
            tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
            tf.logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
            tf.logging.info("label: %s (id = %d)" % (example.label, label_id))

        feature = InputFeatures(
            input_ids=input_ids,
            input_mask=input_mask,
            segment_ids=segment_ids,
            label_id=label_id)
        return feature
def input_fn_builder(bertSim, sentences):
    """Build an Estimator `input_fn` that feeds all `sentences` as one batch.

    Args:
      bertSim: `BertSim` instance supplying `processor`, `tokenizer` and
        `max_seq_length`.
      sentences: list of (text_a, text_b) pairs.

    Returns:
      A zero-argument `predict_input_fn` returning a `tf.data.Dataset` with a
      single element: a dict of int32 tensors where `input_ids`, `input_mask`
      and `segment_ids` have shape (num_pairs, max_seq_length) and
      `label_ids` has shape (num_pairs,).
    """

    def generate_from_input():
        # Convert every pair up front and yield them together as one batch.
        processor = bertSim.processor
        predict_examples = processor.get_sentence_examples(sentences)
        # Pad with the same length declared in `output_shapes` below.
        features = convert_examples_to_features(predict_examples, processor.get_labels(),
                                                bertSim.max_seq_length, bertSim.tokenizer)
        yield {
            'input_ids': [f.input_ids for f in features],
            'input_mask': [f.input_mask for f in features],
            'segment_ids': [f.segment_ids for f in features],
            'label_ids': [f.label_id for f in features]
        }

    def predict_input_fn():
        seq_shape = (None, bertSim.max_seq_length)
        return (tf.data.Dataset.from_generator(
            generate_from_input,
            output_types={
                'input_ids': tf.int32,
                'input_mask': tf.int32,
                'segment_ids': tf.int32,
                'label_ids': tf.int32},
            output_shapes={
                'input_ids': seq_shape,
                'input_mask': seq_shape,
                'segment_ids': seq_shape,
                # One label per pair. The previous (1,) made the shape check
                # fail whenever more than one pair was predicted.
                'label_ids': (None,)}).prefetch(10))

    return predict_input_fn
if __name__ == '__main__':
    # Demo: score a single Chinese sentence pair with the fine-tuned model.
    demo_pairs = [("我喜欢妈妈做的汤", "妈妈做的汤我很喜欢喝")]
    predictor = BertSim()
    predictor.start_model()
    predictor.predict_sentences(demo_pairs)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化