加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
config.py 6.52 KB
一键复制 编辑 原始数据 按行查看 历史
peviroy 提交于 2021-11-28 19:41 . Merge lyk's commit
import os
import logging
import pandas as pd
# ==============================================================================================================
# HyperParameters
# ==============================================================================================================
# Training hyper-parameters. Their consumers live outside this file; EPOCHS is
# additionally embedded in the NN checkpoint file name defined further down.
BATCH_SIZE = 64  # batch size
EPOCHS = 800  # number of epochs
LR = 0.00005  # learning rate, i.e. 5e-5
# ==============================================================================================================
# Logger
# ==============================================================================================================
# Configure the root handler once at import time: INFO and above, with a
# timestamp / logger-name / level prefix. (basicConfig is a no-op if the root
# logger already has handlers.)
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# getLogger() with no name returns the root logger itself.
logger = logging.getLogger()
# ==============================================================================================================
# Generated filedir & filepath
# ==============================================================================================================
# Resolve to an absolute path first: os.path.dirname() of a relative __file__ can be
# "" and would make every derived path depend on the current working directory.
root = os.path.dirname(os.path.abspath(__file__))
origin_file_dir = os.path.join(root, "origin")
data_file_dir = os.path.join(root, "data")
image_file_dir = os.path.join(root, "image")
# file directories
boxplot_image_file_dir = os.path.join(image_file_dir, "boxplot")
histogram_image_file_dir = os.path.join(image_file_dir, "histogram")
correlation_image_file_dir = os.path.join(image_file_dir, "correlation")
regression_image_file_dir = os.path.join(image_file_dir, "regression")
# data file save paths
extract_data_file_path = os.path.join(data_file_dir, "extract_data.npy")
feature_data_file_path = os.path.join(data_file_dir, "feature_data.npy")
washed_extract_data_file_path = os.path.join(data_file_dir, "washed_extract_data.npy")
washed_feature_data_file_path = os.path.join(data_file_dir, "washed_feature_data.npy")
excel_result_file_path = os.path.join(data_file_dir, "results.xlsx")
excel_washed_result_file_path = os.path.join(data_file_dir, "washed_results.xlsx")
# model file save paths (the epoch count is part of the file name)
nn_model_checkpoint_file_path = os.path.join(
    data_file_dir, f"nn_checkpoint_E{EPOCHS}.pth"
)
# svm_model_checkpoint_file_path = os.path.join(data_file_dir, f'svm_checkpoint.pth')
# elm_model_checkpoint_file_path = os.path.join(data_file_dir, f'elm_checkpoint.pth')
# Initialization: create every output directory at import time.
# os.makedirs(..., exist_ok=True) replaces the exists()/mkdir() pairs: it has no
# check-then-create race and also creates any missing parent directories.
# (origin_file_dir is input data and is intentionally not created here.)
for _generated_dir in (
    data_file_dir,
    image_file_dir,
    boxplot_image_file_dir,
    histogram_image_file_dir,
    correlation_image_file_dir,
    regression_image_file_dir,
):
    os.makedirs(_generated_dir, exist_ok=True)
# ==============================================================================================================
# Data filedir & filepath
# ==============================================================================================================
# Per-year sub-folders of origin_file_dir ("2020年" = "year 2020"); every entry
# inside them is collected into data_file_path_list below.
data_file_dir_list = [
    os.path.join(origin_file_dir, "2020年"),
    os.path.join(origin_file_dir, "2021年"),
]
# Initialization
# os.listdir() returns entries in arbitrary, filesystem-dependent order. Sort the
# names so the resulting list (and anything built from it) is reproducible across
# machines. A missing year folder raises FileNotFoundError, as before.
data_file_path_list = []
for date_file_dir in data_file_dir_list:
    data_file_path_list.extend(
        os.path.join(date_file_dir, file_name)
        for file_name in sorted(os.listdir(date_file_dir))
    )
# ==============================================================================================================
# Features
# ==============================================================================================================
# Per-feature specification, keyed by the feature's (Chinese) label. Insertion order
# matters: the name/index lists derived from washed_feature_dict later in this file
# follow it.
#
# Fields of each spec:
#   "type"     -- "in_feature" or "out_feature"; used later in this file to split the
#                 (washed) features into input and output lists/indices.
#   "start", "interval", "total"
#              -- consumed outside this file. NOTE(review): presumably the first
#                 row/offset, the sampling stride, and the number of samples read from
#                 each source file. Confirm against the extraction code. In every entry
#                 with a non-None interval, interval * total == 24, which suggests the
#                 samples span a 24-row (likely 24-hour) window. Entries whose start and
#                 interval are None have total == 1 (apparently a single value).
#
# The trailing comments give English translations of the key labels only. Units and
# measurement details are not visible here. "1#"/"2#" presumably denote unit/line
# No. 1 and No. 2.
feature_dict = {
    "原矿": {"type": "in_feature", "start": 20, "interval": 2, "total": 12},  # raw ore
    "FeO": {"type": "in_feature", "start": 20, "interval": 2, "total": 12},  # FeO
    "精矿": {"type": "in_feature", "start": 20, "interval": 4, "total": 6},  # concentrate
    "综尾": {"type": "in_feature", "start": 20, "interval": 2, "total": 12},  # combined tailings
    "二段混磁精": {"type": "out_feature", "start": 20, "interval": 4, "total": 6},  # second-stage mixed magnetic concentrate
    "强尾": {"type": "in_feature", "start": 20, "interval": 4, "total": 6},  # lit. "strong tailings" (presumably strong-magnetic separation tailings)
    "粒度1#": {"type": "in_feature", "start": 22, "interval": 4, "total": 6},  # particle size No. 1
    "粒度2#": {"type": "in_feature", "start": 22, "interval": 4, "total": 6},  # particle size No. 2
    "正浮精": {"type": "out_feature", "start": 20, "interval": 2, "total": 12},  # direct flotation concentrate
    "正浮尾": {"type": "out_feature", "start": 20, "interval": 4, "total": 6},  # direct flotation tailings
    "反浮精": {"type": "out_feature", "start": None, "interval": None, "total": 1},  # reverse flotation concentrate
    "反浮尾": {"type": "out_feature", "start": None, "interval": None, "total": 1},  # reverse flotation tailings
    "反一扫精": {"type": "out_feature", "start": 20, "interval": 2, "total": 12},  # reverse flotation first-scavenger concentrate
    "浓度1#": {"type": "in_feature", "start": 20, "interval": 2, "total": 12},  # concentration No. 1
    "浓度2#": {"type": "in_feature", "start": 20, "interval": 2, "total": 12},  # concentration No. 2
    "流量1#": {"type": "in_feature", "start": 20, "interval": 2, "total": 12},  # flow rate No. 1
    "流量2#": {"type": "in_feature", "start": 20, "interval": 2, "total": 12},  # flow rate No. 2
    "KS-2粗1#": {"type": "in_feature", "start": 20, "interval": 2, "total": 12},  # KS-2, rougher No. 1
    "KS-2粗2#": {"type": "in_feature", "start": 20, "interval": 2, "total": 12},  # KS-2, rougher No. 2
    "KS-2精选": {"type": "in_feature", "start": 20, "interval": 2, "total": 12},  # KS-2, cleaner
    "NaOH1#": {"type": "in_feature", "start": 20, "interval": 2, "total": 12},  # NaOH No. 1
    "NaOH2#": {"type": "in_feature", "start": 20, "interval": 2, "total": 12},  # NaOH No. 2
    "淀粉1#": {"type": "in_feature", "start": 20, "interval": 2, "total": 12},  # starch No. 1
    "淀粉2#": {"type": "in_feature", "start": 20, "interval": 2, "total": 12},  # starch No. 2
    "CaO1#": {"type": "in_feature", "start": 20, "interval": 2, "total": 12},  # CaO No. 1
    "CaO2#": {"type": "in_feature", "start": 20, "interval": 2, "total": 12},  # CaO No. 2
    "激磁电流": {"type": "in_feature", "start": None, "interval": None, "total": 1},  # excitation current
}
# Features excluded from the "washed" feature set.
dropped_feature = ["粒度1#", "粒度2#", "激磁电流"]
# Shallow copy of feature_dict without the dropped entries. Insertion order is kept,
# the nested spec dicts are shared with feature_dict, and a dropped name that is
# missing from feature_dict raises KeyError.
washed_feature_dict = dict(feature_dict)
for dropped_featname in dropped_feature:
    del washed_feature_dict[dropped_featname]
# Initialization
# Every feature name, in declaration order.
feature_name_list = list(feature_dict)
# Names that survive the drop, in declaration order (same keys as washed_feature_dict).
washed_feature_name_list = [
    name for name in feature_dict if name not in dropped_feature
]
# Positions, within washed_feature_dict's order, of input and output features.
in_feature_index = [
    idx
    for idx, spec in enumerate(washed_feature_dict.values())
    if spec["type"] == "in_feature"
]
out_feature_index = [
    idx
    for idx, spec in enumerate(washed_feature_dict.values())
    if spec["type"] == "out_feature"
]
# Names of input and output features, in washed_feature_dict's order.
in_feature_name_list = [
    name
    for name, spec in washed_feature_dict.items()
    if spec["type"] == "in_feature"
]
out_feature_name_list = [
    name
    for name, spec in washed_feature_dict.items()
    if spec["type"] == "out_feature"
]
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化