加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
unicorn_main.py 12.40 KB
一键复制 编辑 原始数据 按行查看 历史
JinyuChata 提交于 2022-04-05 14:49 . output result
"""
TODO: train/test方法
- arg: tsv_list 均为tsv文件路径列表
- return: 返回是否成功
- 如果return False, 程序直接结束
- 有训练或者测试结果,可以保存到 RESULT_PATH
- 中间可以print结果/异常
- 直接崩溃退出的异常:
logger.error("message")
return False
- 不影响结果的异常:
logger.error("message")
...
return True
- 输出日志用logger.info, 输出test结果可以用logger.success
"""
import math
import os
import time
import numpy as np
import xxhash
from loguru import logger
from tqdm import tqdm
from unicorn.unicorn_parser import compare_edges
from utils import NORMAL_SKETCH_PATH, NORMAL_BASE_PATH, NORMAL_STREAM_PATH, TRAIN_NODES_MAP
from utils import ParseArguments
from utils import TEST_NODES_MAP, TRAIN_EDGES_MAP
from utils import TEST_SKETCH_PATH, TEST_BASE_PATH, TEST_STREAM_PATH, RESULT_LOG_PATH
# Parsed command-line arguments, shared module-wide; run() overwrites the
# .stats/.jiffies attributes on this object before each parse.
CONSOLE_ARGUMENTS = ParseArguments()
# Windows subprocess creation flag (hide console window) -- not referenced
# anywhere in this chunk; presumably consumed elsewhere. TODO confirm.
CREATE_NO_WINDOW = 0x08000000
def hashgen(l):
    """Fold every element of *l* into a single 64-bit xxHash digest."""
    digest = xxhash.xxh64()
    for item in l:
        digest.update(item)
    return digest.intdigest()
def edgegen(edge):
    """Hash the edge-type field (edge[3]) of an Auditd edge record."""
    assert (edge[3])
    return hashgen([edge[3]])
def nodegen(line):
    """Hash the source and destination node descriptors of a split Auditd line.

    Source is built from (line[-2], line[1]), destination from
    (line[-1], line[2]). Returns (source_hash, destination_hash).
    """
    assert (line[-2])
    src_parts = [line[-2], line[1]]
    dst_parts = [line[-1], line[2]]
    return hashgen(src_parts), hashgen(dst_parts)
def read_single_graph(file_name, nodes_map_file=None, edges_map_file=None):
    """Parsing edgelist from the output of prepare.py.
    The format from prepare.py looks like:
    <source_node_id> \t <destination_node_id> \t <hashed_source_type>:<hashed_destination_type>:<hashed_edge_type>:<edge_logical_timestamp>[:<timestamp_stats>]
    The last '<timestamp_stats>' may or may not exist depending on whether the -s/-t option is set when running prepare.py.
    Returned from this funtion is a list of edges, each of which is itself a list containing:
    [source_node_id, destination_node_id, source_node_type, destination_node_type, edge_type, logical_timestamp, [timestamp,] source_node_seen, destination_node_seen]
    The `timestamp` may or may not exist.

    nodes_map_file/edges_map_file, when given, are .npy paths used to
    persist (and resume from) the node renumbering map and the adjacency
    map across calls.
    """
    map_id = dict()     # maps original IDs to new dense IDs starting from 0
    graph = list()      # list of parsed edges
    edges_map = dict()  # node_id -> {'in': {(src, ts)}, 'out': {(dst, ts)}}
    # Resume state from a previous run when the map files already exist.
    if nodes_map_file and os.path.exists(nodes_map_file):
        map_id = np.load(nodes_map_file, allow_pickle=True).item()
    if edges_map_file and os.path.exists(edges_map_file):
        edges_map = np.load(edges_map_file, allow_pickle=True).item()
    # BUG FIX: continue numbering after any resumed IDs; the original reset
    # new_id to 0, colliding fresh IDs with ones loaded from nodes_map_file.
    # (Assumes previous runs assigned dense IDs 0..len(map_id)-1.)
    new_id = len(map_id)
    with open(file_name, 'r', encoding='utf-8', errors='ignore') as f:
        for line in f:
            line = line.split()
            if 'nametype' in line:  # header row from the tsv export
                continue
            try:
                edge = [line[1], line[2]]  # edge[0]=src id, edge[1]=dst id
                source_node_type, destination_node_type = nodegen(line)
                edge_type = edgegen(line)
                edge_order = line[0]                    # edge_logical_timestamp
                edge.append(source_node_type)           # edge[2] = hashed_source_type
                edge.append(destination_node_type)      # edge[3] = hashed_destination_type
                edge.append(edge_type)                  # edge[4] = hashed_edge_type
                edge.append(edge_order)                 # edge[5] = edge_logical_timestamp
                if CONSOLE_ARGUMENTS.stats or CONSOLE_ARGUMENTS.jiffies:
                    edge.append(line[0])                # optional edge[6] = ts / jiffies
                graph.append(edge)
            except Exception:
                # Malformed line: log it and keep going (best-effort parse).
                # Narrowed from a bare `except:` so Ctrl-C still works.
                logger.debug("{}".format(line))
    # sort the graph edges based on logical timestamps
    graph.sort(key=compare_edges)
    for edge in graph:
        if edge[0] in map_id:  # check if source ID has been seen before
            edge[0] = map_id[edge[0]]
            edge.append("0")   # edge[6/7] = whether source node has been seen
        else:
            edge.append("1")
            map_id[edge[0]] = str(new_id)
            edge[0] = str(new_id)
            new_id = new_id + 1
        if edge[1] in map_id:  # check if destination ID has been seen before
            edge[1] = map_id[edge[1]]
            edge.append("0")   # edge[7/8] = whether destination node has been seen
        else:
            edge.append("1")
            map_id[edge[1]] = str(new_id)
            edge[1] = str(new_id)
            new_id = new_id + 1
        # record the adjacency of each node as (neighbor_id, timestamp) sets
        edges_map.setdefault(edge[0], {'in': set(), 'out': set()})
        edges_map.setdefault(edge[1], {'in': set(), 'out': set()})
        edges_map[edge[0]]['out'].add((edge[1], edge[5]))
        edges_map[edge[1]]['in'].add((edge[0], edge[5]))
    if nodes_map_file:
        np.save(nodes_map_file, map_id)
    if edges_map_file:
        # BUG FIX: the original saved edges_map into nodes_map_file,
        # clobbering the node map and never persisting the edge map.
        np.save(edges_map_file, edges_map)
    return graph
def run(args_input, args_base, args_stream,
        args_jiffies=False, args_nodes_map=None, args_edges_map=None,
        args_base_size=None, args_stats=False, args_stats_file='ts.txt', args_interval=0):
    """Parse *args_input* and split its time-sorted edges into a base-graph
    file (*args_base*) and a stream file (*args_stream*).

    - args_jiffies: edges carry a CamFlow jiffies field in edge[6].
    - args_nodes_map / args_edges_map: optional .npy paths forwarded to
      read_single_graph to persist node renumbering / adjacency.
    - args_base_size: explicit base-graph edge count; defaults to 10% of
      the total edges.
    - args_stats / args_stats_file / args_interval: when stats are enabled,
      a timestamp is appended to *args_stats_file* every *args_interval*
      edges (plus once when the base graph completes and once at the end).

    Exits the process (exit(1)) when stats are requested without -I.
    """
    # The parsing helpers read these flags from the module-global argparse
    # object, so patch it before parsing (original design kept as-is).
    CONSOLE_ARGUMENTS.jiffies = args_jiffies
    CONSOLE_ARGUMENTS.stats = args_stats
    if args_stats and not args_interval:
        # Validate up front -- the original only checked after it had
        # already created/truncated the base and stream output files.
        logger.debug("You must set -I if you choose to record runtime graph generation performance")
        exit(1)
    graph = read_single_graph(args_input, args_nodes_map, args_edges_map)
    # default to 10% of the total edges in the graph
    base_graph_size = int(math.ceil(len(graph) * 0.1))
    if args_base_size is not None:
        base_graph_size = args_base_size
    stream_graph_size = len(graph) - base_graph_size
    # we use this flag to make sure we record the time it takes to create
    # the base graph only once.
    recorded_once = False
    edge_cnt = 0
    # Context managers ensure the handles close even if a write fails
    # (the original leaked all three files on any exception).
    with open(args_base, "w+") as base_file, open(args_stream, "w+") as stream_file:
        ts_file = open(args_stats_file, "w") if args_stats else None
        try:
            for num, edge in enumerate(graph):
                if num < base_graph_size:
                    if args_jiffies:
                        base_file.write(
                            "{} {} {}:{}:{}:{}:{}\n".format(edge[0], edge[1], edge[2], edge[3], edge[4], edge[5], edge[6]))
                    else:
                        base_file.write("{} {} {}:{}:{}:{}\n".format(edge[0], edge[1], edge[2], edge[3], edge[4], edge[5]))
                else:
                    if args_stats or args_jiffies:
                        # edge[6] holds ts/jiffies, so seen-flags sit at [7]/[8];
                        # both branches emitted the identical format string.
                        stream_file.write(
                            "{} {} {}:{}:{}:{}:{}:{}:{}\n".format(edge[0], edge[1], edge[2], edge[3], edge[4],
                                                                  edge[7], edge[8], edge[5], edge[6]))
                    else:
                        stream_file.write(
                            "{} {} {}:{}:{}:{}:{}:{}\n".format(edge[0], edge[1], edge[2], edge[3], edge[4],
                                                               edge[6], edge[7], edge[5]))
                if args_stats:
                    edge_cnt += 1
                    if not recorded_once and edge_cnt == base_graph_size:
                        ts_file.write("{}\n".format(edge[6]))
                        edge_cnt = 0
                        recorded_once = True
                    if edge_cnt == args_interval:
                        ts_file.write("{}\n".format(edge[6]))
                        edge_cnt = 0
                    # record time for the last round of edges
                    if num == len(graph) - 1:
                        ts_file.write("{}\n".format(edge[6]))
        finally:
            if ts_file is not None:
                ts_file.close()
    logger.debug("\x1b[6;30;42m[SUCCESS]\x1b[0m Graph {} is processed.".format(args_input))
    logger.debug("\x1b[6;30;42m[SUCCESS]\x1b[0m Base graph of size {} is located at {}".format(base_graph_size, args_base))
    logger.debug(
        "\x1b[6;30;42m[SUCCESS]\x1b[0m Stream graph of size {} is located at {}".format(stream_graph_size, args_stream))
    if args_stats:
        logger.debug("\x1b[6;30;42m[SUCCESS]\x1b[0m Time information is located at {}".format(args_stats_file))
def train_tsvs(tsv_list, training_ds, debug_log_file) -> bool:
    """Train on every tsv in *tsv_list*: build its base/stream graph files,
    then run ./sketch.sh to produce the normal sketch for each one.

    Always returns True; per-file failures only surface in the logs.
    """
    progress = tqdm(total=len(tsv_list), initial=0, unit='it', unit_scale=True, desc="训练中")
    for tsv_file in tsv_list:
        logger.debug("训练: {} -> {}", tsv_file, debug_log_file)
        # Derive a numeric suffix from the file name, tolerating both
        # '/' and '\\' separators, e.g. ".../foo-bar_7.tsv" -> "7".
        trimmed = tsv_file.rstrip("/\\")
        basename = trimmed.split("\\")[-1].split("/")[-1]
        file_idx = basename.split("-")[-1].replace(".tsv", "").split("_")[-1]
        base_path = NORMAL_BASE_PATH + training_ds + f'-{file_idx}.txt'
        stream_path = NORMAL_STREAM_PATH + training_ds + f'-{file_idx}.txt'
        sketch_path = NORMAL_SKETCH_PATH + training_ds + f'-{file_idx}.txt'
        nodes_map_path = TRAIN_NODES_MAP + training_ds + f'{file_idx}-nodes_map.npy'
        edges_map_path = TRAIN_EDGES_MAP + training_ds + f'{file_idx}-edges_map.npy'
        run(tsv_file, base_path, stream_path,
            args_nodes_map=nodes_map_path, args_edges_map=edges_map_path)
        os.system(f"./sketch.sh {base_path} {stream_path} {sketch_path} {debug_log_file}")
        time.sleep(0.1)
        progress.update(1)
    progress.close()
    return True
def test_tsvs(tsv_list, testing_ds, debug_log_file) -> bool:
    """Test every tsv in *tsv_list*: build its base/stream graphs, sketch
    them with ./sketch.sh, then score the sketch against the normal
    (training) sketches with ./model.sh, appending to one result log.

    Always returns True; per-file failures only surface in the logs.
    """
    progress = tqdm(total=len(tsv_list), initial=0, unit='it', unit_scale=True, desc="测试中")
    for tsv_file in tsv_list:
        logger.debug("测试: {}", tsv_file)
        # Derive a numeric suffix from the file name, tolerating both
        # '/' and '\\' separators, e.g. ".../foo-bar_7.tsv" -> "7".
        trimmed = tsv_file.rstrip("/\\")
        basename = trimmed.split("\\")[-1].split("/")[-1]
        file_idx = basename.split("-")[-1].replace(".tsv", "").split("_")[-1]
        base_path = TEST_BASE_PATH + testing_ds + f'-{file_idx}.txt'
        stream_path = TEST_STREAM_PATH + testing_ds + f'-{file_idx}.txt'
        sketch_path = TEST_SKETCH_PATH + testing_ds + f'-{file_idx}.txt'
        result_log_path = RESULT_LOG_PATH + testing_ds + '.txt'
        nodes_map_path = TEST_NODES_MAP + testing_ds + f'{file_idx}-nodes_map.npy'
        # NOTE(review): this prefixes the edges map with TRAIN_EDGES_MAP --
        # likely meant to be a TEST_ counterpart, but no TEST_EDGES_MAP is
        # imported from utils; confirm intent before changing.
        edges_map_path = TRAIN_EDGES_MAP + testing_ds + f'{file_idx}-edges_map.npy'
        run(tsv_file, base_path, stream_path,
            args_nodes_map=nodes_map_path, args_edges_map=edges_map_path)
        os.system(f"./sketch.sh {base_path} {stream_path} {sketch_path} {debug_log_file}")
        os.system(f"./model.sh {NORMAL_SKETCH_PATH} {sketch_path} {result_log_path}")
        time.sleep(0.1)
        progress.update(1)
    progress.close()
    return True
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化