代码拉取完成,页面将自动刷新
同步操作将从 taliux/prompt-tune 强制同步,此操作会覆盖自 Fork 仓库以来所做的任何修改,且无法恢复!!!
确定后同步将在后台操作,完成时将刷新页面,请耐心等待。
import random
import sys
from OpenAIUtils.PromptTemplate import PromptTemplate
from GA.Evaluator import Evaluator
from data_loader import *
from fitness import *
from arguments import get_args
from Logger import logger
from GA.GAOptimiser import GAOptimiser
def save_prompt_template(prompt_template, path):
f = open(path, 'w', encoding='utf-8')
f.write(str(prompt_template))
f.close()
def prepare_data(args):
raw_data = load_jsonl(args.data_path)
train_data, test_data = k_shot_split(raw_data, args.k_shot)
return train_data, test_data
def do_test(prompt_template, args, data, data_name=None):
evaluator = Evaluator(args)
predictions = evaluator.run(data, prompt_template)
if data_name:
logger.info("Evaluate on {}:".format(data_name.upper()))
score, extra = accuracy()(data, predictions)
logger.info(json.dumps(extra, indent=4))
def set_seed(seed):
random.seed(seed)
def main():
args = get_args()
set_seed(args.seed)
# prepare data
train_data, test_data = prepare_data(args)
# prepare initial prompt template
prompt_template = PromptTemplate().from_file(args.prompt_template)
# test the initial prompt template
logger.info("Before GA")
do_test(prompt_template, args, train_data, "train")
do_test(prompt_template, args, test_data, "test")
# run GA
logger.info("Start GA")
optimiser = GAOptimiser(args)
best_individual = optimiser.evolve(
prompt_template,
train_data,
[accuracy(), mean_hinge_probs()]
)
logger.info("Final scores on TRAIN: {}".format(best_individual.scores))
do_test(best_individual.to_prompt_template(), args, test_data, "test")
# save the optimised prompt
save_prompt_template(best_individual.to_prompt_template(), args.save)
if __name__ == "__main__":
main()
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。