代码拉取完成,页面将自动刷新
同步操作将从 baiHR17/nlp_research 强制同步,此操作会覆盖自 Fork 仓库以来所做的任何修改,且无法恢复!!!
确定后同步将在后台操作,完成时将刷新页面,请耐心等待。
#-*- coding:utf-8 -*-
from tests import tests
import yaml
import os,sys
import pdb
import time
import logging
# Directory containing this script; conf/ files and relative *_path entries
# are resolved against it. os.path.dirname is separator-aware, unlike
# splitting on '/', which yields '' on Windows.
ROOT_PATH = os.path.dirname(os.path.abspath(__file__))
# Make modules that live next to this script importable regardless of the
# current working directory.
sys.path.append(ROOT_PATH)
import tensorflow as tf
# Show TensorFlow log messages at INFO level and above.
# NOTE(review): tf.logging is the TensorFlow 1.x API (tf.compat.v1.logging in 2.x) — confirm target TF version.
tf.logging.set_verbosity(tf.logging.INFO)
class Run():
    """Assemble a task configuration from YAML files and optionally set up logging.

    The configuration is built from ``conf/model/base.yml`` plus a
    task-specific file under ``conf/model/`` (both relative to ``ROOT_PATH``),
    then overridden by ``key=value`` command-line arguments.
    """

    def __init__(self, init_log = False):
        """Create a runner.

        Args:
            init_log: when true, configure logging to the file ``log`` and to
                the console (see ``init_logging``).
        """
        if init_log:
            self.init_logging('log')

    def init_logging(self, logFilename):
        """Configure root logging to *logFilename* and to the console.

        The file handler is set up through ``logging.basicConfig`` at DEBUG
        level and opened with mode ``'w'`` (truncated on each run); an extra
        console handler emits INFO and above with the same message format.
        """
        logging.basicConfig(
            level = logging.DEBUG,
            format = '%(asctime)s\t%(levelname)s: %(message)s',
            datefmt = '%Y/%m/%d %H:%M:%S',
            filename = logFilename,
            filemode = 'w')
        console = logging.StreamHandler()
        console.setLevel(logging.INFO)
        formatter = logging.Formatter('%(asctime)s\t%(levelname)s: %(message)s')
        console.setFormatter(formatter)
        logging.getLogger('').addHandler(console)

    def _change_path(self, conf):
        """Rewrite every ``*_path`` value of *conf* in place as an absolute path.

        Values are joined onto ``ROOT_PATH/conf['path_root']``; per
        ``os.path.join`` semantics a value that is already absolute is kept.

        Raises:
            KeyError: if *conf* has no ``path_root`` entry.
        """
        path_root = os.path.join(ROOT_PATH, conf['path_root'])
        for key in list(conf):
            if key.endswith('_path'):
                conf[key] = os.path.join(path_root, conf[key])

    def read_config_type(self, conf):
        """Flatten the selected ``config`` variant into the top level of *conf*.

        When *conf* has both ``config`` and ``config_type``, the entries of
        ``conf['config'][conf['config_type']]`` are copied into *conf*
        (overwriting existing keys) and ``config`` is removed. Otherwise
        *conf* is left unchanged.
        """
        # Both keys must be checked explicitly: `"config" and "config_type" in
        # conf` evaluates to just `"config_type" in conf`.
        if 'config' in conf and 'config_type' in conf:
            conf.update(conf['config'][conf['config_type']])
            del conf['config']

    @staticmethod
    def _load_yaml(path):
        """Parse the YAML file at *path*, closing the file afterwards."""
        # safe_load: a bare yaml.load(stream) without a Loader warns on
        # PyYAML 5.x and is a TypeError on PyYAML >= 6. Binary mode lets
        # PyYAML detect the encoding (UTF-8/UTF-16) itself.
        # NOTE(review): safe_load rejects !!python/* tags — assumed unused in conf files.
        with open(path, 'rb') as f:
            return yaml.safe_load(f)

    def _apply_cmd_args(self, conf, argv=None):
        """Apply ``key=value`` overrides to *conf* in place.

        Args:
            conf: configuration dict to update.
            argv: arguments to scan; defaults to ``sys.argv[1:]``. Arguments
                without ``=`` are ignored.

        Only the first ``=`` splits key from value, so values may contain
        ``=``. Values made solely of decimal digits are converted to ``int``;
        all other values stay strings.
        """
        if argv is None:
            argv = sys.argv[1:]
        for arg in argv:
            if '=' not in arg:
                continue
            key, value = arg.split('=', 1)
            # isdecimal (not isdigit): isdigit accepts e.g. superscripts that int() rejects.
            if value.isdecimal():
                value = int(value)
            conf[key] = value

    def read_conf(self, conf_name):
        """Build and return the full configuration dict for task file *conf_name*.

        Steps, in order:
          1. load ``conf/model/<conf_name>`` and flatten its ``config`` variant;
          2. load ``conf/model/base.yml``;
          3. make the ``*_path`` entries of both absolute;
          4. copy base entries whose keys the task conf does not define;
          5. replace ``{encoder_type}`` in string values when ``encoder_type`` is set;
          6. create the parent directory of ``model_path`` and, if present,
             the ``tfrecords_path`` directory;
          7. apply ``key=value`` overrides from ``sys.argv`` — these come
             last, so overridden paths are used verbatim (not made absolute,
             no directory created).

        Raises:
            AssertionError: if the task YAML file does not exist.
            KeyError: if the merged conf lacks ``model_path`` (or a loaded
                file lacks ``path_root``).
        """
        base_yml = os.path.join(ROOT_PATH, "conf/model/base.yml")
        task_yml = os.path.join(ROOT_PATH, "conf/model/{}".format(conf_name))
        assert os.path.exists(task_yml),'yml conf [%s] does not exists!'%task_yml
        conf = self._load_yaml(task_yml)
        self.read_config_type(conf)
        base = self._load_yaml(base_yml)

        # Relative paths -> absolute paths.
        self._change_path(base)
        self._change_path(conf)

        # Base values only fill keys the task conf leaves unset.
        for key, value in base.items():
            conf.setdefault(key, value)

        # Substitute the encoder name into templated string values.
        if 'encoder_type' in conf:
            encoder_type = conf['encoder_type']
            for key, value in conf.items():
                if isinstance(value, str) and '{encoder_type}' in value:
                    conf[key] = value.replace("{encoder_type}", encoder_type)

        # Create output directories up front.
        model_dir = os.path.dirname(conf['model_path'])
        if model_dir and not os.path.exists(model_dir):
            os.makedirs(model_dir, exist_ok=True)
        if 'tfrecords_path' in conf:
            tfrecords_path = conf['tfrecords_path']
            if not os.path.exists(tfrecords_path):
                os.makedirs(tfrecords_path, exist_ok=True)

        # Extra key=value parameters from the command line override everything.
        self._apply_cmd_args(conf)
        return conf
if __name__ == '__main__':
    # Usage: <script> <task yml under conf/model/> [key=value ...]
    # Explicit check rather than assert so it still runs under `python -O`.
    if len(sys.argv) < 2:
        sys.exit("task type missed, classify, match, ner...?")
    conf_path = sys.argv[1]
    run = Run(init_log = True)
    conf = run.read_conf(conf_path)
    logging.info(conf)
    task_type = conf['task_type']

    # prepare_data may be a YAML bool/int or a string (e.g. from a
    # `prepare_data=false` command-line override); only strings have .lower().
    prepare_data = conf.get('prepare_data', False)
    if isinstance(prepare_data, str):
        prepare_data = prepare_data.lower() != 'false'

    if prepare_data:
        from tasks import tasks
        cl = tasks[task_type](conf)
        cl.prepare()
    else:
        mode = conf['mode']
        if mode in ('train', 'dev', 'test', 'save'):
            # Modes handled by the task classes registered in `tasks`.
            from tasks import tasks
            cl = tasks[task_type](conf)
            if mode == 'train':        # train
                if hasattr(cl, "train_and_evaluate"):
                    cl.train_and_evaluate()
                else:
                    cl.train()
                # NOTE(review): confirm whether this dev evaluation is meant to
                # follow train_and_evaluate() too; here it runs after either path.
                cl.test('dev')
            elif mode == 'dev':        # evaluate on the dev set
                cl.test('dev')
            elif mode == 'test':       # evaluate on the labelled test set
                cl.test('test')
            else:                      # 'save': export the model to pb
                cl.save()
        elif mode in ('predict', 'test_one', 'test_unit'):
            # Modes handled by the inference classes registered in `tests`.
            ts = tests[task_type](conf)
            if mode == 'predict':      # unlabelled prediction over a file
                ts.test_file(conf['test_path'] if 'test_path' in conf else conf['ori_path'])
            else:                      # interactive, one input at a time
                while True:
                    try:
                        a = input('input:')
                    except (EOFError, KeyboardInterrupt):
                        # Ctrl-D / Ctrl-C ends the session instead of a traceback.
                        break
                    start = time.time()
                    ret = ts(a)
                    print(ret)
                    end = time.time()
                    consume = end-start
                    print('consume: {}'.format(consume))
        else:
            logging.error('unknown mode!')
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。