加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
lstm.py 1.24 KB
一键复制 编辑 原始数据 按行查看 历史
jaungiers 提交于 2017-07-15 16:57 . Initial code and data commit
import os
import time
import json
import warnings
import numpy as np
from numpy import newaxis
from keras.layers.core import Dense, Activation, Dropout
from keras.layers.recurrent import LSTM
from keras.models import Sequential
from keras.models import load_model
configs = json.loads(open(os.path.join(os.path.dirname(__file__), 'configs.json')).read())
warnings.filterwarnings("ignore") #Hide messy Numpy warnings
def build_network(layers):
model = Sequential()
model.add(LSTM(
input_dim=layers[0],
output_dim=layers[1],
return_sequences=True))
model.add(Dropout(0.2))
model.add(LSTM(
layers[2],
return_sequences=False))
model.add(Dropout(0.2))
model.add(Dense(
output_dim=layers[3]))
model.add(Activation("tanh"))
start = time.time()
model.compile(
loss=configs['model']['loss_function'],
optimizer=configs['model']['optimiser_function'])
print("> Compilation Time : ", time.time() - start)
return model
def load_network(filename):
#Load the h5 saved model and weights
if(os.path.isfile(filename)):
return load_model(filename)
else:
print('ERROR: "' + filename + '" file does not exist as a h5 model')
return None
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化