加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
Train_Custom_Features.py 2.83 KB
一键复制 编辑 原始数据 按行查看 历史
nreimers 提交于 2018-04-20 12:20 . Update code to Keras 2.1.5
# This script trains the BiLSTM-CNN-CRF architecture with custom-defined features.
# You can specify which features the network should use by changing the featureNames-parameter.
# By default, the network uses tokens and casing as features.
# The input data contains a column with POS data, which we can use for training the network.
# Note: When you use 'POS' as a feature, then this feature must also be in the dataset when you apply the network
# to new data.
from __future__ import print_function
import os
import logging
import sys
from neuralnets.BiLSTM import BiLSTM
from util.preprocessing import perpareDataset, loadDatasetPickle
from keras import backend as K
# :: Run from the directory that contains this script ::
# The relative paths used by this script are then resolved against the
# script's own location instead of the caller's working directory.
scriptDir = os.path.dirname(os.path.abspath(__file__))
os.chdir(scriptDir)

# :: Logging: print bare messages of level INFO and above to stdout ::
logLevel = logging.INFO

stdoutHandler = logging.StreamHandler(sys.stdout)
stdoutHandler.setLevel(logLevel)
stdoutHandler.setFormatter(logging.Formatter('%(message)s'))

# Configure the root logger so that every module logging through it is covered.
rootLogger = logging.getLogger()
rootLogger.setLevel(logLevel)
rootLogger.addHandler(stdoutHandler)
######################################################
#
# Data preprocessing
#
######################################################
# Description of the CoNLL-formatted chunking dataset used for training.
conll2000Chunking = {
    # Column layout of the input data: column 0 contains the tokens,
    # column 1 the POS tags and column 2 the chunk information in BIO encoding.
    'columns': {0: 'tokens', 1: 'POS', 2: 'chunk_BIO'},
    # The column we would like to predict.
    'label': 'chunk_BIO',
    # Whether to evaluate on this task. Always set to True for single-task setups.
    'evaluate': True,
    # Lines in the input data starting with this string are skipped
    # (can be used to skip comments). No comment symbol is configured here.
    'commentSymbol': None,
}

# All datasets, keyed by dataset name.
datasets = {'conll2000_chunking': conll2000Chunking}

# :: Path on your computer to the word embeddings. Embeddings by Komninos et al. will be downloaded automatically ::
embeddingsPath = 'komninos_english_embeddings.gz'

# :: Prepares the dataset to be used with the LSTM-network. Creates and stores cPickle files in the pkl/ folder ::
pickleFile = perpareDataset(embeddingsPath, datasets)
######################################################
#
# The training of the network starts here
#
######################################################
# Load the embeddings, the mappings and the dataset from the pickle file
embeddings, mappings, data = loadDatasetPickle(pickleFile)

# Hyperparameters handed to the BiLSTM network
params = {
    'classifier': ['CRF'],
    'LSTM-Size': [100, 100],
    'dropout': (0.25, 0.25),
    # Features the network should use; 'POS' must then also be present
    # in any data the trained network is later applied to.
    'featureNames': ['tokens', 'casing', 'POS'],
    # NOTE(review): presumably the embedding size for the additional
    # features -- confirm against neuralnets.BiLSTM.
    'addFeatureDimensions': 10,
}

model = BiLSTM(params)
model.setMappings(mappings, embeddings)
model.setDataset(datasets, data)

# File name pattern for stored models.
# NOTE(review): the bracketed placeholders are presumably substituted by BiLSTM -- confirm.
model.modelSavePath = "models/[ModelName]_[DevScore]_[TestScore]_[Epoch].h5"

model.fit(epochs=25)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化