加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
helper.py 2.42 KB
一键复制 编辑 原始数据 按行查看 历史
邹吉华 提交于 2023-04-04 17:20 . 1.3.1
import numpy as np
from MyTT import *
from define import *
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.utils import set_random_seed
from training_env import TrainingEnv
from stable_baselines3.common.vec_env import SubprocVecEnv,DummyVecEnv
# Scaling constants. In this file get_obs divides price differences by
# MAX_PIRICE and volume figures by MAX_VOLUME; MAX_MONEY is not referenced
# in the visible code.
MAX_MONEY = 10_000_000
MAX_VOLUME = 1_000_000
# (sic) The misspelled name is kept because other code refers to it.
MAX_PIRICE = 100_000

# Indicator limits, all equal to 100. None of them is referenced in the
# visible part of this file -- presumably used by other modules; verify
# before changing.
MACD_LIMIT = KDJ_LIMIT = CCI_LIMIT = WR_LIMIT = PRY_LIMIT = BRAR_LIMIT = 100
def get_obs(tick_data, holder_data):
    """Build the normalized observation vector for a single tick.

    Slots written (any remaining slots up to OBS_COUNT stay zero):

    ====  =====================================================
    0     (price - standard) / MAX_PIRICE
    1     (volume - holder_data["last_valume"]) / MAX_VOLUME
    2     delta_hold / MAX_VOLUME
    3     (buy_price - standard) / MAX_PIRICE
    4     (sell_price - standard) / MAX_PIRICE
    5     buy_volume / MAX_VOLUME
    6     sell_volume / MAX_VOLUME
    7     buy_order / MAX_ORDER
    8     sell_order / MAX_ORDER
    9     (MAX_ORDER - buy_order - sell_order) / MAX_ORDER
    ====  =====================================================

    :param tick_data: mapping providing "standard", "price", "volume",
        "delta_hold", "buy_price", "sell_price", "buy_volume" and
        "sell_volume".
    :param holder_data: mapping providing "last_valume" (sic -- the key is
        spelled this way by the callers), "buy_order" and "sell_order".
    :return: float32 numpy array of length OBS_COUNT.
    """
    obs = np.zeros(shape=(OBS_COUNT,), dtype=np.float32)
    # Reference price that all price features are expressed relative to.
    standard = np.float32(tick_data["standard"])

    # --- last trade: price, traded volume since the last step, position delta
    offset = 0
    obs[offset + 0] = (np.float32(tick_data["price"]) - standard) / MAX_PIRICE
    obs[offset + 1] = (
        np.float32(tick_data["volume"] - holder_data["last_valume"]) / MAX_VOLUME
    )
    obs[offset + 2] = np.float32(tick_data["delta_hold"]) / MAX_VOLUME
    offset += 3

    # --- order book: quoted prices relative to the reference, quoted volumes
    obs[offset + 0] = (np.float32(tick_data["buy_price"]) - standard) / MAX_PIRICE
    obs[offset + 1] = (np.float32(tick_data["sell_price"]) - standard) / MAX_PIRICE
    # Volumes are quantities, not prices: scale them directly. Subtracting the
    # price reference from them (as was done before) mixed unrelated units.
    obs[offset + 2] = np.float32(tick_data["buy_volume"]) / MAX_VOLUME
    obs[offset + 3] = np.float32(tick_data["sell_volume"]) / MAX_VOLUME
    offset += 4

    # --- own orders: used and still-available order capacity
    obs[offset + 0] = holder_data["buy_order"] / MAX_ORDER
    obs[offset + 1] = holder_data["sell_order"] / MAX_ORDER
    usable_order = MAX_ORDER - holder_data["buy_order"] - holder_data["sell_order"]
    obs[offset + 2] = usable_order / MAX_ORDER
    return obs
def make_env(rank, trade_day, seed=0):
    """
    Create an environment factory suitable for a vectorized env.

    :param rank: index of the environment; added to ``seed`` when seeding the
        env and used as the last component of the Monitor log path
    :param trade_day: passed unchanged to ``TrainingEnv``
    :param seed: (int) base seed; also given to ``set_random_seed`` once, at
        the time this factory is created
    :return: a zero-argument callable that builds and returns the
        Monitor-wrapped ``TrainingEnv``
    """
    # Runs when the factory is created, not when the env itself is built.
    set_random_seed(seed)

    def _init():
        monitor_path = MODEL_PATH + "/" + str(rank)
        wrapped = Monitor(TrainingEnv(trade_day), monitor_path)
        # NOTE(review): env.seed() is the legacy Gym seeding call -- confirm
        # the installed gym / stable-baselines3 versions still provide it.
        wrapped.seed(seed + rank)
        return wrapped

    return _init
def get_subproc_env(num, is_evaluate=False):
    """
    Build a ``SubprocVecEnv`` holding ``num`` environments.

    :param num: number of environments; each gets rank ``0 .. num - 1``
    :param is_evaluate: when true the envs are created with
        ``EVALUATE_TRADE_DAY``, otherwise with ``TRAINING_TRADE_DAY``
    :return: the constructed ``SubprocVecEnv``
    """
    trade_days = EVALUATE_TRADE_DAY if is_evaluate else TRAINING_TRADE_DAY
    env_fns = [make_env(rank, trade_days) for rank in range(num)]
    return SubprocVecEnv(env_fns)
def get_dummy_env(num, is_evaluate=False):
    """
    Build a ``DummyVecEnv`` holding ``num`` environments.

    :param num: number of environments; each gets rank ``0 .. num - 1``
    :param is_evaluate: when true the envs are created with a single
        hard-coded trade-day file, otherwise with ``TRAINING_TRADE_DAY``
    :return: the constructed ``DummyVecEnv``
    """
    if is_evaluate:
        # NOTE(review): evaluation here uses this fixed file instead of the
        # EVALUATE_TRADE_DAY list that get_subproc_env uses -- confirm that
        # the difference is intentional.
        trade_days = ["rb次主力连续_20220804.csv"]
    else:
        trade_days = TRAINING_TRADE_DAY
    env_fns = [make_env(rank, trade_days) for rank in range(num)]
    return DummyVecEnv(env_fns)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化