加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
utils.py 5.17 KB
一键复制 编辑 原始数据 按行查看 历史
JinyuChata 提交于 2022-04-05 10:58 . reset to interactive
import shutil
import requests
from loguru import logger
from tqdm import tqdm
import json
import os
from collections.abc import Set
# Log directory name.
DEBUG_LOG_PATH = "log"

# Scratch locations; these use the platform path separator.
TEMP_ROOT_PATH = "temp"
DOWNLOAD_PATH = os.path.join("temp", "zip")
TSV_PATH = os.path.join("temp", "tsv")
# JSON file that Procedure.load_from_json / save_to_json read and write.
PROCESS_PATH = os.path.join("temp", "process.json")

# Result tree; these paths always use forward slashes and end with one.
RESULT_PATH = "res"

NORMAL_PATH = f"{RESULT_PATH}/normal/"
NORMAL_BASE_PATH = f"{NORMAL_PATH}base/"
NORMAL_STREAM_PATH = f"{NORMAL_PATH}stream/"
NORMAL_SKETCH_PATH = f"{NORMAL_PATH}sketch/"

TEST_PATH = f"{RESULT_PATH}/test/"
TEST_BASE_PATH = f"{TEST_PATH}base/"
TEST_STREAM_PATH = f"{TEST_PATH}stream/"
TEST_SKETCH_PATH = f"{TEST_PATH}sketch/"

RESULT_LOG_PATH = f"{RESULT_PATH}/test_res/"

# The node/edge map locations are the same directories as TEST_PATH / NORMAL_PATH.
TEST_NODES_MAP = f"{RESULT_PATH}/test/"
TEST_EDGES_MAP = f"{RESULT_PATH}/test/"
TRAIN_NODES_MAP = f"{RESULT_PATH}/normal/"
TRAIN_EDGES_MAP = f"{RESULT_PATH}/normal/"
# Sort TSV files by the serial number embedded in their names.
def sort_tsv_filelist(filelist):
    """Return the paths of ``filelist`` ordered by their numeric serial.

    The serial is read from the last path component (forward and back slashes
    are both accepted as separators): the text after the last ``-``, with
    ``.tsv`` removed, and then the text after the last ``_`` must be an
    integer.  Trailing separators are stripped from every returned path.
    Paths with equal serials keep their input order.

    :param filelist: iterable of path strings.
    :return: new list of the (stripped) paths in ascending serial order.
    :raises ConnectorException: if a file name carries no integer serial; the
        original ``ValueError`` is attached as the cause.
    """
    keyed_paths = []
    for filepath in filelist:
        filepath = filepath.rstrip("/\\")
        # Take the basename once so parsing and the error message agree,
        # whichever separator style the path uses.
        filename = filepath.split("\\")[-1].split("/")[-1]
        serial_str = filename.split("-")[-1].replace(".tsv", "").split("_")[-1]
        try:
            serial = int(serial_str)
        except ValueError as err:
            raise ConnectorException(f"tsv文件名 {filename} 不符合规范,无法解析文件顺序!") from err
        keyed_paths.append((serial, filepath))
    # Sort on the serial only; list.sort is stable, so ties keep input order.
    keyed_paths.sort(key=lambda item: item[0])
    return [filepath for _, filepath in keyed_paths]
if __name__ == '__main__':
    # Manual smoke run: print two Windows-style paths in serial order.
    sample_paths = ["a\\b\\c\\2.tsv", "a\\b\\c\\1.tsv"]
    ordered_paths = sort_tsv_filelist(sample_paths)
    print(str(ordered_paths))
# Initialise the working folders.
def set_up_paths():
    """Create the download, TSV and result directories if they are missing.

    Directories that already exist are left untouched.  ``exist_ok=True``
    replaces a separate existence check, so a directory created by another
    process between the check and the creation no longer causes a failure.

    :raises FileExistsError: if one of the paths exists but is not a directory.
    """
    required_dirs = (
        DOWNLOAD_PATH, TSV_PATH, RESULT_PATH,
        NORMAL_BASE_PATH, NORMAL_STREAM_PATH, NORMAL_SKETCH_PATH,
        TEST_BASE_PATH, TEST_STREAM_PATH, TEST_SKETCH_PATH,
        RESULT_LOG_PATH,
    )
    for path in required_dirs:
        os.makedirs(path, exist_ok=True)
def remove_all_in_dir(path):
    """Delete everything inside the directory ``path`` but keep ``path`` itself.

    Does nothing when ``path`` does not exist or is not a directory.
    Sub-directories are removed recursively.  A symbolic link is unlinked
    without touching its target: ``shutil.rmtree`` refuses to operate on a
    symlink, so a linked directory must not be handed to it.

    :param path: directory whose contents are removed.
    """
    # isdir() is already False for a missing path, so one check is enough.
    if not os.path.isdir(path):
        return
    for entry in os.listdir(path):
        c_path = os.path.join(path, entry)
        if os.path.isdir(c_path) and not os.path.islink(c_path):
            shutil.rmtree(c_path)
        else:
            os.remove(c_path)
def reset_history():
    """Empty the result directory and the temp directory, keeping both folders.

    A directory that does not exist is skipped.
    """
    for root_dir in (RESULT_PATH, TEMP_ROOT_PATH):
        if not os.path.exists(root_dir):
            continue
        remove_all_in_dir(root_dir)
class Set2ListEncoder(json.JSONEncoder):
    """JSON encoder that serialises any ``collections.abc.Set`` as a list."""

    def default(self, obj):
        """Return a list for set-like objects; defer everything else to the base encoder."""
        if not isinstance(obj, Set):
            # The base implementation raises TypeError for unsupported types.
            return super().default(obj)
        return list(obj)
class ParseArguments:
    """Plain container of option values, each initialised to a default.

    NOTE(review): nothing in this file reads these attributes, so the
    per-attribute notes below are inferred from the names only and should be
    confirmed against the code that consumes this object.
    """

    def __init__(self) -> None:
        self.input = ''  # presumably the input file path -- TODO confirm
        self.base_size = 0  # presumably the size of the base part -- TODO confirm
        self.base = ''  # presumably the base output path -- TODO confirm
        self.stream = ''  # presumably the stream output path -- TODO confirm
        self.stats = False  # flag, off by default
        self.interval = 0
        self.stats_file = 'ts.txt'  # default file name used with the stats option -- TODO confirm
        self.jiffies = False  # flag, off by default
        self.nodes_map = ''  # presumably the nodes-map path -- TODO confirm
class ProcessType:
    """String constants for the stage stored in ``Procedure.process``."""
    training = "training"  # not trained yet
    testing = "testing"  # already trained; testing can be run
class Procedure:
    """Progress record of a run, persisted as JSON at ``PROCESS_PATH``.

    Only ``process``, ``training_ds``, ``server_host`` and ``user_name`` are
    written to and restored from disk; ``testing_ds`` and ``tested_tsv`` live
    in memory only.
    """

    def __init__(self, training_ds="", testing_ds="", server_host="", user_name=""):
        """Start a record in the training stage with an empty tested-TSV list."""
        self.process = ProcessType.training
        self.training_ds = training_ds
        self.testing_ds = testing_ds
        self.tested_tsv = []
        self.server_host = server_host
        self.user_name = user_name

    @classmethod
    def load_from_json(cls):
        """Rebuild a record from ``PROCESS_PATH``; return a default one if the file is absent."""
        if os.path.exists(PROCESS_PATH):
            with open(PROCESS_PATH, 'r') as fp:
                stored = json.load(fp)
            restored = cls(
                training_ds=stored['training_ds'],
                server_host=stored["server_host"],
                user_name=stored["user_name"],
            )
            # The stage is not a constructor argument, so it is set afterwards.
            restored.process = stored['process']
            return restored
        return cls()

    def save_to_json(self):
        """Write the persisted fields of this record to ``PROCESS_PATH``."""
        persisted_fields = ('process', 'training_ds', 'server_host', 'user_name')
        snapshot = {name: getattr(self, name) for name in persisted_fields}
        with open(PROCESS_PATH, 'w') as fp:
            json.dump(snapshot, fp)
class ConnectorException(Exception):
    """Error raised by this module's helpers; the text is kept in ``msg``.

    :param msg: description of the failure, also available through ``str()``
        and ``args``.
    """

    def __init__(self, msg):
        # Forward the message to Exception so str(), repr() and args carry it
        # even when the exception is built with the keyword form msg=...
        super().__init__(msg)
        self.msg = msg
def download_from_url(url, params, dst):
    """Download the zip archive served at ``url`` into the file ``dst``.

    :param url: URL to request.
    :param params: query parameters handed to ``requests.get``.
    :param dst: path of the file to write.  An existing file is overwritten:
        no ``Range`` request is made, so the server always sends the whole
        body and appending it to old content would corrupt the archive.
    :return: ``True`` when the body was written completely; ``False`` when the
        server answered with a non-zip, non-200 response or the transfer
        failed (the failure is logged).
    :raises ConnectorException: when the server answers 200 with a non-zip
        body; the exception carries the ``msg`` field of the JSON payload.
    """
    # Stream the response so the archive is written chunk by chunk.
    req = requests.get(url, stream=True, params=params)
    try:
        # .get() treats a missing content-type header like any other non-zip reply.
        if req.headers.get('content-type') != 'application/zip':
            if req.status_code == 200:
                raise ConnectorException(req.json()['msg'])  # authentication failed
            logger.error("下载失败, code={}, content={}", str(req.status_code), str(req.content))
            return False
        try:
            # content-length may be absent (e.g. chunked transfer); tqdm then
            # shows a counter without a total instead of aborting the download.
            content_length = req.headers.get('content-length')
            file_size = int(content_length) if content_length is not None else None
            with tqdm(total=file_size, initial=0,
                      unit='B', unit_scale=True, desc=url.split('/')[-1]) as pbar, \
                    open(dst, 'wb') as f:
                for chunk in req.iter_content(chunk_size=1024):
                    if chunk:
                        f.write(chunk)
                        # Advance by the bytes actually received; the last
                        # chunk is usually shorter than the chunk size.
                        pbar.update(len(chunk))
        except Exception as e:
            logger.error("下载失败, 下载{}异常: {}", url, str(e))
            return False
        return True
    finally:
        # A streamed response keeps its connection until it is closed.
        req.close()
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化