加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
datasets.py 1.60 KB
一键复制 编辑 原始数据 按行查看 历史
import h5py
import math
import pandas as pd
from tensorflow.keras.utils import Sequence
import numpy as np
class ECGSequence(Sequence):
    """Keras ``Sequence`` serving consecutive batches of rows from an HDF5 dataset.

    Rows ``start_idx`` (inclusive) to ``end_idx`` (exclusive) of the HDF5
    dataset ``hdf5_dset`` are split into batches of ``batch_size`` rows; the
    last batch may be shorter. When a label CSV is given, each batch is an
    ``(x, y)`` tuple where ``y`` holds the CSV rows with the same indices as
    ``x``; otherwise each batch is just ``x``.

    The HDF5 file is opened read-only in ``__init__`` and stays open for the
    object's lifetime; it is closed in ``__del__``.
    """

    @classmethod
    def get_train_and_val(cls, path_to_hdf5, hdf5_dset, path_to_csv, batch_size=8,
                          val_split=0.02):
        """Build a training and a validation sequence over the same files.

        With ``n_samples`` the number of rows in the label CSV, the first
        ``ceil(n_samples * (1 - val_split))`` rows form the training sequence
        and the remaining labelled rows form the validation sequence.

        Args:
            path_to_hdf5: Path of the HDF5 file holding the tracings.
            hdf5_dset: Name of the dataset inside the HDF5 file.
            path_to_csv: Path of the label CSV (one row per HDF5 row).
            batch_size: Number of rows per batch.
            val_split: Fraction of labelled rows reserved for validation.

        Returns:
            A ``(train_seq, valid_seq)`` tuple.

        Raises:
            ValueError: If ``val_split`` is not within ``[0, 1]``.
        """
        if not 0 <= val_split <= 1:
            raise ValueError(f"val_split must be within [0, 1], got {val_split!r}")
        n_samples = len(pd.read_csv(path_to_csv))
        n_train = math.ceil(n_samples * (1 - val_split))
        train_seq = cls(path_to_hdf5, hdf5_dset, path_to_csv, batch_size, end_idx=n_train)
        # Stop at the last labelled row: if the HDF5 file holds more rows than
        # the CSV, the surplus rows have no label and must not be served.
        valid_seq = cls(path_to_hdf5, hdf5_dset, path_to_csv, batch_size,
                        start_idx=n_train, end_idx=n_samples)
        return train_seq, valid_seq

    def __init__(self, path_to_hdf5, hdf5_dset, path_to_csv=None, batch_size=8,
                 start_idx=0, end_idx=None):
        """Open the HDF5 dataset and, optionally, load the label CSV.

        Args:
            path_to_hdf5: Path of the HDF5 file holding the tracings.
            hdf5_dset: Name of the dataset inside the HDF5 file.
            path_to_csv: Optional label CSV; when ``None`` batches carry no labels.
            batch_size: Number of rows per batch (must be >= 1).
            start_idx: First row (inclusive) served by this sequence.
            end_idx: Row after the last one served. Defaults to the number of
                HDF5 rows, or to the number of labelled rows if that is smaller.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        # Newer Keras (PyDataset) expects subclasses to call the base initializer.
        super().__init__()
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")
        self.y = None if path_to_csv is None else pd.read_csv(path_to_csv).values
        # Get tracings. NOTE(review): presumably a (n_samples, n_timesteps, n_leads)
        # array for ECG data -- confirm; any rank with rows on axis 0 works here.
        self.f = h5py.File(path_to_hdf5, "r")
        self.x = self.f[hdf5_dset]
        self.batch_size = batch_size
        if end_idx is None:
            end_idx = len(self.x)
            if self.y is not None:
                # Never serve tracings that have no matching label row.
                end_idx = min(end_idx, len(self.y))
        self.start_idx = start_idx
        self.end_idx = end_idx

    @property
    def n_classes(self):
        """Number of label columns in the CSV.

        Raises:
            AttributeError: If the sequence was created without a label CSV.
        """
        if self.y is None:
            raise AttributeError("n_classes is unavailable: no label CSV was given")
        return self.y.shape[1]

    def __getitem__(self, idx):
        """Return batch ``idx`` as ``x`` or ``(x, y)`` NumPy arrays.

        Raises:
            IndexError: If ``idx`` is not in ``range(len(self))``.
        """
        if not 0 <= idx < len(self):
            raise IndexError(f"batch index {idx} out of range for {len(self)} batches")
        start = self.start_idx + idx * self.batch_size
        end = min(start + self.batch_size, self.end_idx)
        # Slice only the first axis so datasets of any rank are supported.
        x = np.array(self.x[start:end])
        if self.y is None:
            return x
        return x, np.array(self.y[start:end])

    def __len__(self):
        """Number of batches, counting a final partial batch; never negative."""
        n_rows = max(0, self.end_idx - self.start_idx)
        return math.ceil(n_rows / self.batch_size)

    def __del__(self):
        """Close the HDF5 file, tolerating an ``__init__`` that failed early."""
        f = getattr(self, "f", None)
        if f is not None:
            f.close()
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化