加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
logger.py 1.65 KB
一键复制 编辑 原始数据 按行查看 历史
程磊 提交于 2022-04-01 15:37 . Initial commit
import torch
from config import cfg
import os
import json
import numpy as np
class MetricsRecorder():
def __init__(self):
self.rec = {}
def add(self, pairs):
for key, val in pairs.items():
if key not in self.rec:
self.rec[key] = []
self.rec[key].append(val)
def mean(self):
r = {}
for key, val in self.rec.items():
r[key] = np.mean(val)
return r
class Logger():
def __init__(self):
self.base_path = './logs/' + cfg.base.task_name
self.logfile = self.base_path + '/log.json'
self.cfgfile = self.base_path + '/cfg.json'
if not os.path.isdir(self.base_path):
os.makedirs(self.base_path, exist_ok=True)
with open(self.logfile, 'w') as fp:
json.dump({}, fp)
with open(self.cfgfile, 'w') as fp:
json.dump(cfg, fp)
def save_record(self, epoch, record):
with open(self.logfile) as fp:
log = json.load(fp)
log[str(epoch)] = record
with open(self.logfile, 'w') as fp:
json.dump(log, fp)
def save_network(self, epoch, network):
saving_path = self.base_path + '/ckp.%d.torch' % epoch
print('saving model ...')
if type(network) is torch.nn.DataParallel:
torch.save(network.module.state_dict(), saving_path)
else:
torch.save(network.state_dict(), saving_path)
cfg.base.epoch = epoch
cfg.base.checkpoint_path = saving_path
with open(self.cfgfile, 'w') as fp:
json.dump(cfg, fp)
logger = None
if logger is None:
logger = Logger()
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化