加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
utils.py 2.14 KB
一键复制 编辑 原始数据 按行查看 历史
ping 提交于 2023-11-02 09:26 . init
#!/usr/bin/env python
# -*- coding: utf-8 -*-
__author__ = 'Stefan Jansen'
import numpy as np
# Import-time side effect: seed NumPy's global random state so that later
# np.random calls in this process start from a fixed, repeatable state.
np.random.seed(42)
def format_time(t):
    """Return a duration formatted as 'HH:MM:SS'.

    Args:
        t: Number of seconds (int or float), e.g. the difference between
            two ``time()`` readings. Fractional seconds are rounded to the
            nearest whole second.

    Returns:
        A string with zero-padded, at-least-two-digit hour, minute and
        second fields. Hours are not wrapped at 24.

    Raises:
        ValueError, OverflowError: If ``t`` is NaN or infinite, because it
            cannot be rounded to a whole number of seconds.
    """
    # Round once, before splitting into fields. Rounding each field on its
    # own (as '.0f' formatting does) turns 59.6 s into '00:00:60' instead of
    # carrying into the minutes.
    total_seconds = int(round(t))
    minutes, seconds = divmod(total_seconds, 60)
    hours, minutes = divmod(minutes, 60)
    return f'{hours:0>2d}:{minutes:0>2d}:{seconds:0>2d}'
class MultipleTimeSeriesCV:
    """Walk-forward cross-validation splitter for multi-symbol time series.

    Generates ``(train_idx, test_idx)`` pairs of positional row numbers.
    ``X`` must have a MultiIndex containing a level named ``date_idx``
    (the original design assumes levels 'symbol' and 'date').

    Test windows are laid out backwards from the most recent date; split
    ``i`` tests on the ``test_period_length`` unique dates that precede the
    test window of split ``i - 1``. To purge overlapping outcomes,
    ``lookahead - 1`` unique dates are skipped between the end of each
    training window and the start of its test window. Each training window
    covers ``train_period_length + lookahead - 1`` unique dates.

    Args:
        n_splits: Number of train/test pairs to generate.
        train_period_length: Base length of the training window, in
            unique dates.
        test_period_length: Length of each test window, in unique dates.
        lookahead: Forecast horizon in unique dates (an int). Must be set
            before calling ``split``; the default ``None`` is rejected
            there.
        date_idx: Name of the index level that holds the dates.
        shuffle: If True, the training row numbers of every split are
            returned in random order (drawn from NumPy's global RNG).
    """

    def __init__(self,
                 n_splits=3,
                 train_period_length=126,
                 test_period_length=21,
                 lookahead=None,
                 date_idx='date',
                 shuffle=False):
        self.n_splits = n_splits
        self.lookahead = lookahead
        self.test_length = test_period_length
        self.train_length = train_period_length
        self.shuffle = shuffle
        self.date_idx = date_idx

    def split(self, X, y=None, groups=None):
        """Yield ``(train_idx, test_idx)`` NumPy arrays of row positions.

        Args:
            X: pandas object whose MultiIndex has a level named
                ``self.date_idx``.
            y: Unused.
            groups: Unused.

        Yields:
            One ``(train_idx, test_idx)`` tuple per split, most recent test
            window first. Values are row numbers of ``X.reset_index()``.

        Raises:
            TypeError: If ``lookahead`` was left at ``None``.
            IndexError: If ``X`` has too few unique dates for the requested
                windows.
        """
        if self.lookahead is None:
            # Fail with an explicit message instead of the opaque
            # "unsupported operand type(s) for +: 'int' and 'NoneType'".
            raise TypeError('lookahead must be set to an integer number of '
                            'periods before calling split(); got None')

        unique_dates = X.index.get_level_values(self.date_idx).unique()
        # Newest date first, so offset 0 is the most recent observation.
        days = sorted(unique_dates, reverse=True)

        # Offsets into `days`; a larger offset means an earlier date.
        split_idx = []
        for i in range(self.n_splits):
            test_end_idx = i * self.test_length
            test_start_idx = test_end_idx + self.test_length
            train_end_idx = test_start_idx + self.lookahead - 1
            train_start_idx = (train_end_idx + self.train_length
                               + self.lookahead - 1)
            split_idx.append([train_start_idx, train_end_idx,
                              test_start_idx, test_end_idx])

        # After reset_index() the index is 0..n-1, so the surviving labels
        # of a boolean selection are positional row numbers.
        dates = X.reset_index()[self.date_idx]
        for train_start, train_end, test_start, test_end in split_idx:
            # Windows are half-open: (start date, end date].
            train_mask = ((dates > days[train_start])
                          & (dates <= days[train_end]))
            test_mask = ((dates > days[test_start])
                         & (dates <= days[test_end]))
            train_idx = dates[train_mask].index.to_numpy()
            test_idx = dates[test_mask].index.to_numpy()
            if self.shuffle:
                # permutation() returns a shuffled copy; shuffling a
                # temporary list(...) would discard the result.
                train_idx = np.random.permutation(train_idx)
            yield train_idx, test_idx

    def get_n_splits(self, X=None, y=None, groups=None):
        """Return the configured number of splits; all arguments are unused."""
        return self.n_splits
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化