代码拉取完成,页面将自动刷新
import logging
from functools import partial
import tensorflow as tf
from inputs import gpt2_pred_input
from models.gpt2 import encoder
# Takes in the user supplied text and generates output texts. Outputs to log/console and a file
def gpt2_predict(network, text, params):
logger = logging.getLogger('tensorflow')
enc = encoder.get_encoder(params["encoder_path"])
predictions = network.predict(input_fn=partial(gpt2_pred_input, text=text))
with tf.gfile.Open(params["predict_path"], "a") as f:
for i, p in enumerate(predictions):
p = p["tokens"]
text = enc.decode(p)
f.write("=" * 40 + " SAMPLE " + str(i) + " " + "=" * 40 + "\n")
f.write(text)
f.write("\n" + "=" * 80 + "\n")
logger.info("=" * 40 + " SAMPLE " + str(i) + " " + "=" * 40 + "\n")
logger.info(text)
logger.info("\n" + "=" * 80 + "\n")
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。