加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
此仓库是为了提升国内下载速度的镜像仓库,每日同步一次。 原始仓库: https://github.com/ConnorJL/GPT2
克隆/下载
predict_fns.py 926 Bytes
一键复制 编辑 原始数据 按行查看 历史
connor 提交于 2019-05-23 07:45 . Lots of cleanup
import logging
from functools import partial
import tensorflow as tf
from inputs import gpt2_pred_input
from models.gpt2 import encoder
# Takes in the user supplied text and generates output texts. Outputs to log/console and a file
def gpt2_predict(network, text, params):
logger = logging.getLogger('tensorflow')
enc = encoder.get_encoder(params["encoder_path"])
predictions = network.predict(input_fn=partial(gpt2_pred_input, text=text))
with tf.gfile.Open(params["predict_path"], "a") as f:
for i, p in enumerate(predictions):
p = p["tokens"]
text = enc.decode(p)
f.write("=" * 40 + " SAMPLE " + str(i) + " " + "=" * 40 + "\n")
f.write(text)
f.write("\n" + "=" * 80 + "\n")
logger.info("=" * 40 + " SAMPLE " + str(i) + " " + "=" * 40 + "\n")
logger.info(text)
logger.info("\n" + "=" * 80 + "\n")
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化