加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
main.py 1.81 KB
一键复制 编辑 原始数据 按行查看 历史
taliux 提交于 2024-01-09 10:45 . rename arguments
import json
import random
import sys

from OpenAIUtils.PromptTemplate import PromptTemplate
from GA.Evaluator import Evaluator
from data_loader import *
from fitness import *
from arguments import get_args
from Logger import logger
from GA.GAOptimiser import GAOptimiser
def save_prompt_template(prompt_template, path):
    """Write the string form of ``prompt_template`` to ``path`` as UTF-8 text.

    Args:
        prompt_template: any object; ``str(prompt_template)`` is what gets written.
        path: destination file path. An existing file is overwritten.
    """
    # Render before opening so a failing __str__ does not truncate an existing file;
    # the context manager guarantees the handle is closed even if write() raises.
    text = str(prompt_template)
    with open(path, 'w', encoding='utf-8') as f:
        f.write(text)
def prepare_data(args):
    """Load the JSONL dataset named by ``args.data_path`` and split it.

    The split is delegated to ``k_shot_split`` with ``args.k_shot``; its result
    must unpack into exactly two parts, which are returned as a
    ``(train, test)`` tuple.
    """
    records = load_jsonl(args.data_path)
    train_split, test_split = k_shot_split(records, args.k_shot)
    return train_split, test_split
def do_test(prompt_template, args, data, data_name=None):
    """Evaluate ``prompt_template`` on ``data`` and log the accuracy report.

    Args:
        prompt_template: template handed to ``Evaluator.run``.
        args: parsed command-line arguments used to construct the ``Evaluator``.
        data: examples to evaluate; passed to ``Evaluator.run`` and to the
            ``accuracy()`` metric alongside the predictions.
        data_name: optional label (e.g. ``"train"``). When truthy, a header
            ``Evaluate on <LABEL>:`` (label upper-cased) is logged first.

    Returns:
        The score produced by the ``accuracy()`` metric. Callers that ignore
        the return value are unaffected.
    """
    evaluator = Evaluator(args)
    predictions = evaluator.run(data, prompt_template)
    if data_name:
        logger.info("Evaluate on {}:".format(data_name.upper()))
    # accuracy() returns a callable yielding (score, extra-details).
    score, extra = accuracy()(data, predictions)
    # NOTE(review): ``extra`` must be JSON-serialisable for this log line.
    logger.info(json.dumps(extra, indent=4))
    return score
def set_seed(seed):
    """Seed the module-level ``random`` generator with ``seed``.

    Only the stdlib ``random`` state is seeded here; any other random number
    generators are left untouched.
    """
    random.seed(a=seed)
def main():
    """Run the GA prompt-optimisation pipeline end to end.

    Steps:
      1. Parse arguments and seed ``random``.
      2. Load the data and split it.
      3. Evaluate the initial prompt template on both splits.
      4. Evolve the template with ``GAOptimiser`` on the training split, using
         the ``accuracy`` and ``mean_hinge_probs`` fitness functions.
      5. Evaluate the best individual on the test split and save it to
         ``args.save``.
    """
    args = get_args()
    set_seed(args.seed)

    # Prepare data.
    train_data, test_data = prepare_data(args)

    # Load the initial prompt template from the path given on the command line.
    prompt_template = PromptTemplate().from_file(args.prompt_template)

    # Baseline: evaluate the initial template before optimisation.
    logger.info("Before GA")
    do_test(prompt_template, args, train_data, "train")
    do_test(prompt_template, args, test_data, "test")

    # Run the genetic algorithm on the training split only.
    logger.info("Start GA")
    optimiser = GAOptimiser(args)
    best_individual = optimiser.evolve(
        prompt_template,
        train_data,
        [accuracy(), mean_hinge_probs()],
    )
    logger.info("Final scores on TRAIN: {}".format(best_individual.scores))

    # Convert once so the template evaluated on TEST is exactly the one saved.
    best_template = best_individual.to_prompt_template()
    do_test(best_template, args, test_data, "test")

    # Save the optimised prompt.
    save_prompt_template(best_template, args.save)


if __name__ == "__main__":
    main()
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化