加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
load_data.py 4.31 KB
一键复制 编辑 原始数据 按行查看 历史
xuanlei 提交于 2017-08-04 14:25 . 模型训练
# -*- coding: utf-8 -*-
"""
Created on Fri Jul 28 13:32:26 2017
@author: xuanlei
"""
import pandas as pd
import numpy as np
import os
from tqdm import tqdm
import math
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import normalize
import matplotlib as mpl
#==============================================================================
# Data Preprocessing
#==============================================================================
# Directory holding the data files. The trailing backslash is kept because
# file names may be appended to it directly.
data_dir = 'C:\\Users\\www\\Desktop\\SDATA\\code\\data\\'
def data_load(data_path=None):
    """Read every file in the data directory into a module-level DataFrame.

    Each regular file is parsed as ';'-separated text with a header row and
    bound as a module global named after the file without its extension
    (e.g. ``gy_contest_link_traveltime_training_data.txt`` becomes the global
    ``gy_contest_link_traveltime_training_data``).  The raw directory listing
    is stored in the module global ``wtno_num``.

    Args:
        data_path: directory to read; defaults to the module-level
            ``data_dir``.

    Returns:
        dict mapping each global name to the DataFrame bound to it.
    """
    path = data_dir if data_path is None else data_path
    entries = os.listdir(path)
    globals()['wtno_num'] = entries
    frames = {}
    for item in sorted(entries):
        full_path = os.path.join(path, item)
        # os.listdir also returns sub-directories, which read_csv cannot parse.
        if not os.path.isfile(full_path):
            continue
        print('>>>>>>>>>>>>开始读取'+item+'<<<<<<<<<<<<<<<<<')
        # read_table and read_csv behave identically once sep is given
        # explicitly, so one call handles the training file and the rest.
        # splitext strips the extension whatever its length (item[:-4]
        # assumed exactly four characters).
        name = os.path.splitext(item)[0]
        frames[name] = pd.read_csv(full_path, sep=';', header=0)
        globals()[name] = frames[name]
        print('>>>>>>>>>>>>读取完成<<<<<<<<<<<<<<<<<')
    return frames
#==============================================================================
# Data Extract:
# row: time
# columns: link_time
# shape(len(time),132)
#==============================================================================
def get_all_data(travel_data):
    """Pivot the long-format table into one row per time interval.

    Args:
        travel_data: DataFrame with at least the columns ``time_interval``,
            ``date`` and ``link_ID``.  After ``time_interval`` and ``date``
            are dropped, the column at position 1 of what remains supplies
            the cell values (presumably ``travel_time`` -- confirm against
            the data file).  The caller's frame is not modified.

    Returns:
        (df, columns): ``df`` has one row per sorted unique
        ``time_interval`` and one column per sorted unique ``link_ID``
        (labels converted with ``str``).  A cell holds the value from the
        first matching record in time-sorted order, or the string ``'nan'``
        when there is no such record.  ``columns`` is ``df.columns``.

    Side effects:
        Sets the module globals ``temp_data`` (the time-sorted input indexed
        by ``time_interval``) and ``temp`` (the row lists behind ``df``).
        Prints one progress line per time interval.
    """
    # Sort on the column first and only then index by it.  Sorting by a
    # label that is both the index name and a column raises ValueError in
    # pandas >= 1.0.  Working on the sorted copy also leaves the caller's
    # index untouched.
    temp_data = travel_data.sort_values(by='time_interval')
    temp_data.index = temp_data['time_interval']
    globals()['temp_data'] = temp_data
    sort_data = temp_data.drop(['time_interval', 'date'], axis=1)
    linkid = sorted(set(sort_data['link_ID']))
    times = sorted(set(sort_data.index))
    # A single pass records the first value seen for each (time, link) pair.
    # This replaces re-masking the frame for every cell.  The old per-time
    # ``.loc[t]`` also broke when a time had only one record, because it
    # returned a Series.  ``sort_data.values`` applies the same dtype upcast
    # that the old per-cell ``.values`` lookup did.
    first_value = {}
    for t, link, value in zip(sort_data.index, sort_data['link_ID'],
                              sort_data.values[:, 1]):
        first_value.setdefault((t, link), value)
    all_data = []
    for i, t in enumerate(times, start=1):
        all_data.append([first_value.get((t, link), 'nan') for link in linkid])
        print('>>>>>>>>>>>>完成第{0}条,{1}提取<<<<<<<<<<<<<<<<<'.format(i, t))
    globals()['temp'] = all_data
    print('>>>>>>>>>>>>完成全部提取:%d<<<<<<<<<<<<<<<<<' % len(times))
    df = pd.DataFrame(all_data)
    df.columns = [str(x) for x in linkid]
    df.index = times
    return df, df.columns
#==============================================================================
# Data Extract: fill na
#==============================================================================
def deal_na(df):
    """Turn ``'nan'`` placeholders into NaN, forward-fill, then drop leftovers.

    Each column is forward-filled down the rows, so a gap takes the last
    preceding observed value in that column.  Rows that still contain NaN
    afterwards (those before some column's first observation) are dropped.

    Args:
        df: DataFrame whose missing cells are the string ``'nan'`` or NaN.

    Returns:
        A new DataFrame; ``df`` itself is not modified.
    """
    dfna = df.replace('nan', float('nan'))
    # ffill() is the supported spelling of fillna(method='pad'), which is
    # deprecated since pandas 2.1.
    dfna = dfna.ffill()
    return dfna.dropna()
#==============================================================================
# Data Extract:
# 生成时间序列的滑窗数据,格式为[[[],[]],[[],[]],...[[],[]]]
# 滑窗参数有两个一个是步长gap,决定滑窗之间的间隔;另一个是num,决定滑窗内的数据条数
#==============================================================================
def get_window_data(df, num, gap):
    """Cut ``df`` into overlapping (data, label) windows along its rows.

    A window starting at row ``s`` spans rows ``s .. s+num``.  Its data part
    is rows ``s .. s+num-1``, and its label part is the same span shifted
    down one row (``s+1 .. s+num``).  Starts advance by ``gap`` rows, and only
    windows that fit entirely inside ``df`` are produced.

    Args:
        df: DataFrame ordered along its rows.
        num: number of rows in each data part and each label part; >= 1.
        gap: step between consecutive window starts; >= 1.

    Returns:
        list of ``[data, label]`` pairs, each a ``num``-row slice of ``df``.

    Raises:
        ValueError: if ``num`` or ``gap`` is less than 1.
    """
    if num < 1 or gap < 1:
        raise ValueError('num and gap must both be >= 1')
    result_list = []
    # The last start that still leaves num + 1 rows is rows - num - 1.  This
    # bound yields every complete window for any gap.  It replaces the old
    # floor(rows / gap) - 1 iteration count and the check against a
    # hard-coded 30 instead of num.
    for begin in range(0, df.shape[0] - num, gap):
        window_data = df.iloc[begin:begin + num + 1, :]
        result_list.append([window_data.iloc[:-1, :], window_data.iloc[1:, :]])
    return result_list
#==============================================================================
# start function
#==============================================================================
def start(num=30, gap=2):
    """Run the whole preprocessing pipeline and return sliding-window samples.

    The pipeline loads the data files, pivots the training table into one
    row per time interval, fills gaps, and cuts the result into windows.

    Args:
        num: rows per window's data/label part (default 30, the original
            hard-coded value).
        gap: step between consecutive window starts (default 2, the
            original hard-coded value).

    Returns:
        (result, columns): the ``[data, label]`` window pairs produced by
        ``get_window_data`` and the column labels of the pivoted table as
        a list.
    """
    data_load()
    print('>>>>>>>>>>>>完成数据加载<<<<<<<<<<<<<<<<<')
    # This module global is not defined statically: data_load() binds it,
    # naming it after the training data file.
    df, fl = get_all_data(gy_contest_link_traveltime_training_data)
    print('>>>>>>>>>>>>完成数据抽取<<<<<<<<<<<<<<<<<')
    dfna = deal_na(df)
    print('>>>>>>>>>>>>完成数据填充<<<<<<<<<<<<<<<<<')
    result = get_window_data(dfna, num, gap)
    print('>>>>>>>>>>>>完成数据滑窗生成<<<<<<<<<<<<<<<<<')
    return result, list(fl)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化