加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
network.h 23.30 KB
一键复制 编辑 原始数据 按行查看 历史
ruki 提交于 2013-07-02 14:09 . ...
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864
#ifndef NETWORK_H
#define NETWORK_H
#include "prefix.h"
#include "imageconvector.h"
#include <QFile>
#include <QTextStream>
#define TRAIN_COUNT (200)
/// Writes a weight collection to a text stream.
///
/// The output is the element count on its own line, followed by the
/// value() of every element, one per line.
///
/// \param fout  the stream to write to
/// \param wgts  an indexable collection whose elements expose value()
template<typename_param_k Wgts>
inline void save_weights(QTextStream& fout, Wgts const& wgts)
{
    // header line: how many entries follow
    fout << wgts.size() << endl;

    // one entry per line
    int idx = 0;
    while (idx < (int)wgts.size())
    {
        fout << wgts[idx].value() << endl;
        ++idx;
    }
}
/// Reads a weight collection from a text stream.
///
/// Expects the layout produced by save_weights(): the element count
/// first, then one value per element.
///
/// \param fin  the stream to read from
/// \return     a collection of the read size whose elements were set via value(val)
template<typename_param_k Wgts>
inline Wgts load_weights(QTextStream& fin)
{
    typedef Wgts                                        weights_type;
    typedef typename_type_k weights_type::size_type     size_type;
    typedef typename_type_k weights_type::value_type    weight_type;
    typedef typename_type_k weight_type::float_type     value_type;

    // the leading entry is the number of weights that follow
    size_type count;
    fin >> count;

    // allocate the collection, then fill it entry by entry
    weights_type result(count);
    int idx = 0;
    while (idx < (int)count)
    {
        value_type raw;
        fin >> raw;
        result[idx].value(raw);
        ++idx;
    }
    return result;
}
/// Writes an indexable vector to a text stream.
///
/// The output is the element count on its own line, followed by every
/// element, one per line.
///
/// \param fout  the stream to write to
/// \param v     the vector to serialize
template<typename_param_k Vtr>
inline void save_vector(QTextStream& fout, Vtr const& v)
{
    // header line: how many entries follow
    fout << v.size() << endl;

    // one entry per line
    int idx = 0;
    while (idx < (int)v.size())
    {
        fout << v[idx] << endl;
        ++idx;
    }
}
/// Reads an indexable vector from a text stream.
///
/// Expects the layout produced by save_vector(): the element count
/// first, then one value per element.
///
/// \param fin  the stream to read from
/// \return     a vector constructed with the read size and filled from the stream
template<typename_param_k Vtr>
inline Vtr load_vector(QTextStream& fin)
{
    typedef Vtr                                     vector_type;
    typedef typename_type_k vector_type::size_type  size_type;

    // the leading entry is the number of elements that follow
    size_type count;
    fin >> count;

    // allocate the vector, then read straight into its slots
    vector_type result(count);
    int idx = 0;
    while (idx < (int)count)
    {
        fin >> result[idx];
        ++idx;
    }
    return result;
}
// bp_policy
struct bp_policy
{
public:
typedef bp_network<NETWORK_INPUT_N, NETWORK_OUTPUT_N> network_type;
typedef network_type::layers_type layers_type;
typedef network_type::sample_type sample_type;
typedef network_type::samples_type samples_type;
typedef network_type::weights_type weights_type;
typedef network_type::size_type size_type;
typedef network_type::float_type float_type;
typedef basic_network_validator<network_type> validator_type;
private:
network_type m_network;
validator_type m_validator;
public:
bp_policy()
: m_network(layers_type(10), FACTOR_BP_LRATE, FACTOR_BP_MF)
{}
public:
network_type& network() { return m_network; }
network_type const& network() const { return m_network; }
validator_type& validator() { return m_validator; }
validator_type const& validator() const { return m_validator; }
/// \name Methods
/// @{
public:
void train(samples_type& sps)
{
network().train(sps, TRAIN_COUNT);
}
void run(sample_type& sp)
{
network().run(sp);
}
void save()
{
QFile file(NETWORK_DB_PATH);
if (!file.open(QIODevice::WriteOnly | QIODevice::Text))
return;
QTextStream fout(&file);
// save signature
fout << "bp_network" << endl;
// save weights
save_weights(fout, network().weights());
}
bool load()
{
QFile file(NETWORK_DB_PATH);
if (!file.open(QIODevice::ReadOnly))
return false;
QTextStream fin(&file);
// read signature
QString sig;
fin >> sig;
if (sig != "bp_network") return false;
// load weights
network().weights(load_weights<weights_type>(fin));
return true;
}
QString validate(samples_type& sps)
{
validator().validate(network(), sps);
return QObject::tr("bp network:: mse:%1 erate:%2").arg(validator().mse()).arg(validator().erate());
}
/// @}
};
// wga_policy
struct wga_policy
{
public:
typedef bp_network<NETWORK_INPUT_N, NETWORK_OUTPUT_N> bp_network_type;
typedef wga_network<bp_network_type> network_type;
typedef bp_network_type::layers_type layers_type;
typedef bp_network_type::sample_type sample_type;
typedef bp_network_type::samples_type samples_type;
typedef bp_network_type::weights_type weights_type;
typedef bp_network_type::size_type size_type;
typedef bp_network_type::float_type float_type;
typedef basic_network_validator<network_type> validator_type;
private:
bp_network_type m_bp;
network_type m_network;
validator_type m_validator;
public:
wga_policy()
: m_bp(layers_type(10), FACTOR_BP_LRATE, FACTOR_BP_MF)
, m_network(m_bp)
{}
public:
network_type& network() { return m_network; }
network_type const& network() const { return m_network; }
validator_type& validator() { return m_validator; }
validator_type const& validator() const { return m_validator; }
/// \name Methods
/// @{
public:
void train(samples_type& sps)
{
network().train(sps, TRAIN_COUNT);
}
void run(sample_type& sp)
{
network().run(sp);
}
void save()
{
QFile file(NETWORK_DB_PATH);
if (!file.open(QIODevice::WriteOnly | QIODevice::Text))
return;
QTextStream fout(&file);
// save signature
fout << "wga_network" << endl;
// save weights
save_weights(fout, m_bp.weights());
}
bool load()
{
QFile file(NETWORK_DB_PATH);
if (!file.open(QIODevice::ReadOnly))
return false;
QTextStream fin(&file);
// read signature
QString sig;
fin >> sig;
if (sig != "wga_network") return false;
// load weights
m_bp.weights(load_weights<weights_type>(fin));
return true;
}
QString validate(samples_type& sps)
{
validator().validate(network(), sps);
return QObject::tr("wga network:: mse:%1 erate:%2 generation:%3").arg(validator().mse()).arg(validator().erate()).arg(network().generation());
}
/// @}
};
// bagging_policy
/// Policy that owns ten bp_network instances and a bagging_networks
/// object constructed from pointers to them.
///
/// Only network() is exposed; this struct defines no train / run /
/// save / load / validate members.
struct bagging_policy
{
public:
    typedef bp_network<NETWORK_INPUT_N, NETWORK_OUTPUT_N>   bp_network_type;
    typedef bagging_networks<bp_network_type>               network_type;
    typedef bp_network_type::layers_type                    layers_type;
    typedef bp_network_type::sample_type                    sample_type;
    typedef bp_network_type::samples_type                   samples_type;

private:
    // the base networks; they are declared before m_network because
    // m_network is constructed from their addresses
    bp_network_type m_bp_0, m_bp_1, m_bp_2, m_bp_3, m_bp_4;
    bp_network_type m_bp_5, m_bp_6, m_bp_7, m_bp_8, m_bp_9;

    network_type    m_network;

public:
    /// Builds every base network from layers_type(10) and the FACTOR_BP_*
    /// constants, then hands their addresses to the bagging network.
    bagging_policy()
        : m_bp_0(layers_type(10), FACTOR_BP_LRATE, FACTOR_BP_MF)
        , m_bp_1(layers_type(10), FACTOR_BP_LRATE, FACTOR_BP_MF)
        , m_bp_2(layers_type(10), FACTOR_BP_LRATE, FACTOR_BP_MF)
        , m_bp_3(layers_type(10), FACTOR_BP_LRATE, FACTOR_BP_MF)
        , m_bp_4(layers_type(10), FACTOR_BP_LRATE, FACTOR_BP_MF)
        , m_bp_5(layers_type(10), FACTOR_BP_LRATE, FACTOR_BP_MF)
        , m_bp_6(layers_type(10), FACTOR_BP_LRATE, FACTOR_BP_MF)
        , m_bp_7(layers_type(10), FACTOR_BP_LRATE, FACTOR_BP_MF)
        , m_bp_8(layers_type(10), FACTOR_BP_LRATE, FACTOR_BP_MF)
        , m_bp_9(layers_type(10), FACTOR_BP_LRATE, FACTOR_BP_MF)
        , m_network ( &m_bp_0, &m_bp_1, &m_bp_2, &m_bp_3, &m_bp_4
                    , &m_bp_5, &m_bp_6, &m_bp_7, &m_bp_8, &m_bp_9
                    )
    {
    }

public:
    network_type&       network()       { return m_network; }
    network_type const& network() const { return m_network; }
};
// ada_boosting_policy
struct ada_boosting_policy
{
public:
typedef bp_network < NETWORK_INPUT_N
, NETWORK_OUTPUT_N
, sample_selector<NETWORK_INPUT_N, NETWORK_OUTPUT_N>::float_sample_type
> bp_network_type;
typedef ada_boosting_n_networks<bp_network_type> network_type;
typedef bp_network_type::layers_type layers_type;
typedef bp_network_type::sample_type sample_type;
typedef bp_network_type::samples_type samples_type;
typedef bp_network_type::weights_type weights_type;
typedef bp_network_type::size_type size_type;
typedef bp_network_type::float_type float_type;
typedef basic_network_validator<network_type> validator_type;
typedef network_type::floats_type floats_type;
private:
bp_network_type m_bp_0;
bp_network_type m_bp_1;
bp_network_type m_bp_2;
bp_network_type m_bp_3;
bp_network_type m_bp_4;
bp_network_type m_bp_5;
bp_network_type m_bp_6;
bp_network_type m_bp_7;
bp_network_type m_bp_8;
bp_network_type m_bp_9;
network_type m_network;
validator_type m_validator;
public:
ada_boosting_policy()
{
for (e_size_t i = 0; i < 10; ++i)
{
m_network.push_back(bp_network_type(layers_type(10), FACTOR_BP_LRATE, FACTOR_BP_MF));
}
}
public:
network_type& network() { return m_network; }
network_type const& network() const { return m_network; }
validator_type& validator() { return m_validator; }
validator_type const& validator() const { return m_validator; }
/// \name Methods
/// @{
public:
void train(samples_type& sps)
{
network().train(sps, TRAIN_COUNT);
}
void run(sample_type& sp)
{
network().run(sp);
}
void save()
{
QFile file(NETWORK_DB_PATH);
if (!file.open(QIODevice::WriteOnly | QIODevice::Text))
return;
QTextStream fout(&file);
// save signature
fout << "ada_boosting_network" << endl;
size_type i;
// save the weights of the whole network
floats_type wgts = network().weights();
fout << wgts.size() << endl;
for (i = 0; i < wgts.size(); ++i)
{
fout << wgts[i] << endl;
}
// save the weights of the every base network
size_type nets_n = network().networks().size();
fout << nets_n << endl;
for (i = 0; i < nets_n; ++i)
save_weights(fout, network().network(i).weights());
}
bool load()
{
QFile file(NETWORK_DB_PATH);
if (!file.open(QIODevice::ReadOnly))
return false;
QTextStream fin(&file);
// read signature
QString sig;
fin >> sig;
if (sig != "ada_boosting_network") return false;
// read the number of the weights
size_type wgts_n;
fin >> wgts_n;
floats_type wgts(wgts_n);
// load the weights of the whole network
size_type i;
for (i = 0; i < wgts_n; ++i)
{
float_type val;
fin >> val;
wgts[i] = val;
}
network().weights(wgts);
// read the number of the networks
size_type nets_n;
fin >> nets_n;
// load the weights of the every base network
for (i = 0; i < nets_n; ++i)
network().network(i).weights(load_weights<weights_type>(fin));
return true;
}
QString validate(samples_type& sps)
{
validator().validate(network(), sps);
return QObject::tr("ada-boosting network:: mse:%1 erate:%2").arg(validator().mse()).arg(validator().erate());
}
/// @}
};
#if 0
// NOTE: everything up to the matching #endif is compiled out by the
// "#if 0" above; it is kept as a disabled experiment only.
// wga_ada_boosting_policy
// Disabled policy that wraps a bp_network in a wga_network and hands its
// address to an ada_boosting_networks object. Only the "_0" instances are
// active; the "_1" .. "_9" instances are commented out below.
struct wga_ada_boosting_policy
{
public:
typedef bp_network<NETWORK_INPUT_N, 4> bp_network_type;
typedef wga_network<bp_network_type> wga_network_type;
typedef ada_boosting_networks<wga_network_type> network_type;
typedef bp_network_type::layers_type layers_type;
typedef bp_network_type::sample_type sample_type;
typedef bp_network_type::samples_type samples_type;
private:
// the single active base network (the others are commented out)
bp_network_type m_bp_0;
/*bp_network_type m_bp_1;
bp_network_type m_bp_2;
bp_network_type m_bp_3;
bp_network_type m_bp_4;
bp_network_type m_bp_5;
bp_network_type m_bp_6;
bp_network_type m_bp_7;
bp_network_type m_bp_8;
bp_network_type m_bp_9;*/
// the single active wga network, constructed from m_bp_0
wga_network_type m_wga_0;
/*wga_network_type m_wga_1;
wga_network_type m_wga_2;
wga_network_type m_wga_3;
wga_network_type m_wga_4;
wga_network_type m_wga_5;
wga_network_type m_wga_6;
wga_network_type m_wga_7;
wga_network_type m_wga_8;
wga_network_type m_wga_9;*/
network_type m_network;
public:
// Builds m_bp_0 from layers_type(10) and the literal 0.5, wraps it in
// m_wga_0, and passes the address of m_wga_0 to the boosting network.
wga_ada_boosting_policy()
: m_bp_0(layers_type(10), 0.5)
/*, m_bp_1(layers_type(10), 0.5)
, m_bp_2(layers_type(10), 0.5)
, m_bp_3(layers_type(10), 0.5)
, m_bp_4(layers_type(10), 0.5)
, m_bp_5(layers_type(10), 0.5)
, m_bp_6(layers_type(10), 0.5)
, m_bp_7(layers_type(10), 0.5)
, m_bp_8(layers_type(10), 0.5)
, m_bp_9(layers_type(10), 0.5)*/
, m_wga_0(m_bp_0)
/*, m_wga_1(m_bp_1)
, m_wga_2(m_bp_2)
, m_wga_3(m_bp_3)
, m_wga_4(m_bp_4)
, m_wga_5(m_bp_5)
, m_wga_6(m_bp_6)
, m_wga_7(m_bp_7)
, m_wga_8(m_bp_8)
, m_wga_9(m_bp_9)*/
, m_network(&m_wga_0/*, &m_wga_1, &m_wga_2, &m_wga_3, &m_wga_4, &m_wga_5, &m_wga_6, &m_wga_7, &m_wga_8, &m_wga_9*/)
{}
public:
// accessors for the boosting network
network_type& network() { return m_network; }
network_type const& network() const { return m_network; }
};
#endif
#ifdef USING_PCA
// pca_mixed_policy
/// Policy that combines a ghia_network (the "pnet") with the network of
/// another policy \p Py (the "onet") inside a pca_mixed_network, and offers
/// train / run / save / load / validate on top of it.
///
/// Persistence is split over two files: the pnet vectors and the
/// converter vectors go to PCA_DB_PATH (signature "pca"), while the inner
/// network is saved and loaded through a temporary instance of \p Py.
///
/// \tparam Py  the policy that supplies the inner network and its
///             save()/load() implementation
template<typename_param_k Py>
struct pca_mixed_policy
{
public:
    typedef Py                                                  basic_policy_type;
    typedef typename_type_k basic_policy_type::network_type     basic_network_type;
    typedef ghia_network<PCA_INPUT_N, PCA_OUTPUT_N>             pca_type;
    typedef pca_mixed_network<pca_type, basic_network_type>     network_type;
    typedef typename_type_k network_type::sample_type           sample_type;
    typedef typename_type_k network_type::samples_type          samples_type;
    typedef typename_type_k network_type::size_type             size_type;
    typedef typename_type_k network_type::float_type            float_type;
    typedef basic_network_validator<network_type>               validator_type;
    typedef typename_type_k pca_type::vector_type               vector_type;

private:
    network_type    m_network;
    validator_type  m_validator;

public:
    /// Builds the mixed network from a default pca_type and the network
    /// of a temporary, default-constructed \p Py.
    pca_mixed_policy()
        : m_network(pca_type(), basic_policy_type().network())
    {
    }

public:
    network_type&           network()           { return m_network;     }
    network_type const&     network() const     { return m_network;     }
    validator_type&         validator()         { return m_validator;   }
    validator_type const&   validator() const   { return m_validator;   }

    /// \name Methods
    /// @{
public:
    /// Trains the mixed network on \p sps, forwarding TRAIN_COUNT and 10
    /// as the second and third arguments.
    void train(samples_type& sps)
    {
        // train
        network().train(sps, TRAIN_COUNT, 10);
    }

    /// Runs the mixed network on a single sample.
    void run(sample_type& sp)
    {
        network().run(sp);
    }

    /// Writes the pnet vectors and the converter vectors to PCA_DB_PATH,
    /// then saves the inner network through a temporary \p Py.
    /// Does nothing (the inner network included) when PCA_DB_PATH cannot
    /// be opened for writing.
    void save()
    {
        QFile file(PCA_DB_PATH);
        if (!file.open(QIODevice::WriteOnly | QIODevice::Text))
            return;
        QTextStream fout(&file);

        // save signature
        fout << "pca" << endl;

        // save the weights
        size_type vecs_n = network().pnet().vectors_size();
        for (size_type i = 0; i < vecs_n; ++i)
            save_vector(fout, network().pnet().vector(i));

        // save converters
        save_vector(fout, network().pnet().converter().avgs());
        save_vector(fout, network().pnet().converter().sds());
        save_vector(fout, network().converter().mins());
        save_vector(fout, network().converter().maxs());

        // save network data
        basic_policy_type py;
        py.network() = network().onet();
        py.save();
    }

    /// Restores the pnet vectors, the converter vectors and the inner
    /// network.
    ///
    /// The inner network is loaded into a scratch policy before anything
    /// in this object is modified, so a failure of that load no longer
    /// overwrites the PCA data or replaces the inner network with a
    /// default-constructed one.
    ///
    /// \return false when PCA_DB_PATH cannot be opened, its signature
    ///         differs, or the inner policy fails to load; true otherwise
    bool load()
    {
        QFile file(PCA_DB_PATH);
        if (!file.open(QIODevice::ReadOnly))
            return false;
        QTextStream fin(&file);

        // read signature
        QString sig;
        fin >> sig;
        if (sig != "pca") return false;

        // load the inner network data first, into a scratch policy;
        // bail out before touching our own state if that fails
        basic_policy_type py;
        if (!py.load()) return false;

        // load the weights
        size_type vecs_n = network().pnet().vectors_size();
        for (size_type i = 0; i < vecs_n; ++i)
            network().pnet().vector(i) = (load_vector<vector_type>(fin));

        // load converters
        network().pnet().converter().avgs() = load_vector<vector_type>(fin);
        network().pnet().converter().sds() = load_vector<vector_type>(fin);
        network().converter().mins() = load_vector<vector_type>(fin);
        network().converter().maxs() = load_vector<vector_type>(fin);

        // commit the inner network
        network().onet() = py.network();
        return true;
    }

    /// Validates the mixed network against \p sps and formats the
    /// validator's mse() and (1 - erate()) * 100 into a report string.
    QString validate(samples_type& sps)
    {
        validator().validate(network(), sps);
        //return QObject::tr("pca mixed network:: mse:%1 erate:%2").arg(validator().mse()).arg(validator().erate());
        return QObject::tr("mse:%1 rate:%2%").arg(validator().mse()).arg((1 - validator().erate()) * 100);
    }
    /// @}
};
// pca_mixed_policy_with_bayes_classifier
/// Policy that combines a ghia_network (the "pnet") with the network of
/// another policy \p Py (the "onet") inside a pca_mixed_network, and puts a
/// bayes_classifier on top of that network.
///
/// run() and validate() go through the classifier; train() trains the
/// network and then initializes the classifier from the same samples.
/// Persistence is split over two files: the pnet vectors and the
/// converter vectors go to PCA_DB_PATH (signature "pca"), while the inner
/// network is saved and loaded through a temporary instance of \p Py.
///
/// NOTE(review): load() does not call classifier().init(); presumably
/// callers are expected to call init_classifier() after a successful
/// load and before run() - confirm against the callers.
///
/// \tparam Py  the policy that supplies the inner network and its
///             save()/load() implementation
template<typename_param_k Py>
struct pca_mixed_policy_with_bayes_classifier
{
public:
    typedef Py                                                  basic_policy_type;
    typedef typename_type_k basic_policy_type::network_type     basic_network_type;
    typedef ghia_network<PCA_INPUT_N, PCA_OUTPUT_N>             pca_type;
    typedef pca_mixed_network<pca_type, basic_network_type>     network_type;
    typedef typename_type_k network_type::sample_type           sample_type;
    typedef typename_type_k network_type::samples_type          samples_type;
    typedef typename_type_k network_type::size_type             size_type;
    typedef typename_type_k network_type::float_type            float_type;
    typedef bayes_classifier<network_type>                      classifier_type;
    typedef basic_classifier_validator<classifier_type>         validator_type;
    typedef typename_type_k pca_type::vector_type               vector_type;

private:
    // m_network must stay declared before m_classifier: the classifier is
    // constructed from the address of m_network
    network_type    m_network;
    validator_type  m_validator;
    classifier_type m_classifier;

public:
    /// Builds the mixed network from a default pca_type and the network
    /// of a temporary, default-constructed \p Py, then hands the address
    /// of that network to the classifier.
    pca_mixed_policy_with_bayes_classifier()
        : m_network(pca_type(), basic_policy_type().network())
        , m_classifier(&m_network)
    {
    }

public:
    network_type&           network()           { return m_network;     }
    network_type const&     network() const     { return m_network;     }
    validator_type&         validator()         { return m_validator;   }
    validator_type const&   validator() const   { return m_validator;   }
    classifier_type&        classifier()        { return m_classifier;  }
    classifier_type const&  classifier() const  { return m_classifier;  }

    /// \name Methods
    /// @{
public:
    /// Encodes the samples (currently a no-op), trains the mixed network
    /// with TRAIN_COUNT and 10 as the second and third arguments, and
    /// initializes the classifier from \p sps.
    void train(samples_type& sps)
    {
        // encode samples
        encode_samples(sps);

        // train
        network().train(sps, TRAIN_COUNT, 10);

        // init classifier
        classifier().init(sps);
    }

    /// Classifies a single sample, then decodes it (currently a no-op).
    void run(sample_type& sp)
    {
        classifier().classify(sp);
        decode_sample(sp);
    }

    /// Writes the pnet vectors and the converter vectors to PCA_DB_PATH,
    /// then saves the inner network through a temporary \p Py.
    /// Does nothing (the inner network included) when PCA_DB_PATH cannot
    /// be opened for writing.
    void save()
    {
        QFile file(PCA_DB_PATH);
        if (!file.open(QIODevice::WriteOnly | QIODevice::Text))
            return;
        QTextStream fout(&file);

        // save signature
        fout << "pca" << endl;

        // save the weights
        size_type vecs_n = network().pnet().vectors_size();
        for (size_type i = 0; i < vecs_n; ++i)
            save_vector(fout, network().pnet().vector(i));

        // save converters
        save_vector(fout, network().pnet().converter().avgs());
        save_vector(fout, network().pnet().converter().sds());
        save_vector(fout, network().converter().mins());
        save_vector(fout, network().converter().maxs());

        // save network data
        basic_policy_type py;
        py.network() = network().onet();
        py.save();
    }

    /// Restores the pnet vectors, the converter vectors and the inner
    /// network.
    ///
    /// The inner network is loaded into a scratch policy before anything
    /// in this object is modified, so a failure of that load no longer
    /// overwrites the PCA data or replaces the inner network with a
    /// default-constructed one.
    ///
    /// \return false when PCA_DB_PATH cannot be opened, its signature
    ///         differs, or the inner policy fails to load; true otherwise
    bool load()
    {
        QFile file(PCA_DB_PATH);
        if (!file.open(QIODevice::ReadOnly))
            return false;
        QTextStream fin(&file);

        // read signature
        QString sig;
        fin >> sig;
        if (sig != "pca") return false;

        // load the inner network data first, into a scratch policy;
        // bail out before touching our own state if that fails
        basic_policy_type py;
        if (!py.load()) return false;

        // load the weights
        size_type vecs_n = network().pnet().vectors_size();
        for (size_type i = 0; i < vecs_n; ++i)
            network().pnet().vector(i) = (load_vector<vector_type>(fin));

        // load converters
        network().pnet().converter().avgs() = load_vector<vector_type>(fin);
        network().pnet().converter().sds() = load_vector<vector_type>(fin);
        network().converter().mins() = load_vector<vector_type>(fin);
        network().converter().maxs() = load_vector<vector_type>(fin);

        // commit the inner network
        network().onet() = py.network();
        return true;
    }

    /// Hook for re-coding the sample labels before training.
    /// Currently does nothing: the two experimental code tables below are
    /// commented out.
    void encode_samples(samples_type& sps)
    {
        /*size_type h[10];
        h[0] = 0; //"0000";
        h[1] = 3; // "0011";
        h[2] = 6; // "0110";
        h[3] = 12; // "1100";
        h[4] = 9; // "1001";
        h[5] = 5; // "0101";
        h[6] = 10; // "1010";
        h[7] = 15; // "1111";
        h[8] = 1; // "0001";
        h[9] = 14; // "1110";
        for (size_type i = 0; i < sps.size(); ++i)
        sps[i].dreal(h[sps[i].dreal()]);
        h[0] = 0; //"0000000000";
        h[1] = 31; // "0000011111";
        h[2] = 62; // "0000111110";
        h[3] = 124; // "0001111100";
        h[4] = 248; // "0011111000";
        h[5] = 496; // "0111110000";
        h[6] = 992; // "1111100000";
        h[7] = 220; // "0011011100";
        h[8] = 803; // "1100100011";
        h[9] = 1023; // "1111111111";
        for (size_type i = 0; i < sps.size(); ++i)
        sps[i].dreal(h[sps[i].dreal()]);*/
    }

    /// Hook for mapping a classified output back to a label.
    /// Currently does nothing: the two experimental code tables below are
    /// commented out.
    void decode_sample(sample_type& sp)
    {
        /*typedef typename_type_k hash_selector<size_type, size_type>::hash_type hash_type;
        hash_type h;
        h[0] = 0; //"0000000000";
        h[31] = 1; // "0000011111";
        h[62] = 2; // "0000111110";
        h[124] = 3; // "0001111100";
        h[248] = 4; // "0011111000";
        h[496] = 5; // "0111110000";
        h[992] = 6; // "1111100000";
        h[220] = 7; // "0011011100";
        h[803] = 8; // "1100100011";
        h[1023] = 9; // "1111111111";
        sp.doutput(h[sp.doutput()]);
        h[0] = 0; //"0000";
        h[3] = 3; // "0011";
        h[6] = 6; // "0110";
        h[12] = 12; // "1100";
        h[9] = 9; // "1001";
        h[5] = 5; // "0101";
        h[10] = 10; // "1010";
        h[15] = 15; // "1111";
        h[1] = 1; // "0001";
        h[14] = 14; // "1110";
        sp.doutput(h[sp.doutput()]);*/
    }

    /// Initializes the classifier from \p trained_sps.
    void init_classifier(samples_type& trained_sps)
    {
        // init classifier
        m_classifier.init(trained_sps);
    }

    /// Validates the classifier against \p sps and formats the
    /// validator's mse() and (1 - erate()) * 100 into a report string.
    QString validate(samples_type& sps)
    {
        validator().validate(classifier(), sps);
        //return QObject::tr("pca mixed network:: mse:%1 erate:%2").arg(validator().mse()).arg(validator().erate());
        return QObject::tr("mse:%1 rate:%2%").arg(validator().mse()).arg((1 - validator().erate()) * 100);
    }
    /// @}
};
#endif
// network
/// Front end that pairs a policy \p Py with a sample collection.
///
/// Images are converted to samples, collected with add(), used for
/// training with train(), and classified with recognize(). The
/// constructor tries to load previously saved network data through the
/// policy; is_prepared() reports whether that load, or a later train(),
/// succeeded.
///
/// \tparam Py  a policy providing network_type, sample_type, samples_type
///             and the train / run / save / load / validate members
template<typename_param_k Py>
class Network
{
public:
    typedef Py                                          policy_type;
    typedef typename_type_k policy_type::network_type   network_type;
    typedef typename_type_k policy_type::sample_type    sample_type;
    typedef typename_type_k policy_type::samples_type   samples_type;

private:
    policy_type     m_policy;
    samples_type    m_samples;
    bool            m_is_prepared;

public:
    /// Default-constructs the policy and the samples, then marks the
    /// network as prepared when the policy manages to load saved data.
    Network()
        : m_policy()
        , m_samples()
        , m_is_prepared(false)
    {
        // load previously saved network data, if any
        if (policy().load())
            m_is_prepared = true;
    }

public:
    samples_type&       samples()           { return m_samples;     }
    samples_type const& samples() const     { return m_samples;     }
    policy_type&        policy()            { return m_policy;      }
    policy_type const&  policy() const      { return m_policy;      }

    /// \return true once saved data was loaded or train() has completed
    bool                is_prepared() const { return m_is_prepared; }

    /// \return the policy's validation report for the collected samples
    QString             info()              { return policy().validate(samples()); }

    /// Converts \p image to a sample labelled \p digit and appends it to
    /// the collected samples. Digits outside 0..9 are ignored.
    void add(int digit, QImage const& image)
    {
        if (digit < 0 || digit > 9) return ;

        // add sample
        samples().push_back(convert(digit, image));
    }

    /// Removes all collected samples.
    void clear()
    {
        samples().clear();
    }

    /// Converts \p image to a sample (labelled -1), runs the policy on it
    /// and returns the sample's doutput().
    int recognize(QImage const& image)
    {
        sample_type sp = convert(-1, image);
        policy().run(sp);
        return sp.doutput();
    }

    /// Trains the policy on the collected samples, saves the result and
    /// returns the policy's validation report for those samples.
    /// is_prepared() is false while training runs and true afterwards.
    QString train()
    {
        // train
        m_is_prepared = false;
        policy().train(samples());
        m_is_prepared = true;

        // save network data
        policy().save();

        // validate
        return policy().validate(samples());
    }

private:
    /// Builds a sample from \p image: every pixel of the converted image
    /// becomes one binary input (set when the pixel is black), indexed
    /// column by column, and \p digit is stored as the sample's dreal().
    sample_type convert(int digit, QImage const& image) const
    {
        image_type img = ImageConvector::convect(image);
        sample_type sp;
        for (int px = 0; px < img.width(); ++px)
        {
            for (int py = 0; py < img.height(); ++py)
            {
                sp.set_binput(px * img.height() + py, img.at(px, py).is_black());
            }
        }

        // The label does not depend on the pixel, so set it once here
        // instead of once per pixel; this also labels the sample when the
        // converted image has no pixels at all.
        sp.dreal(digit);
        return sp;
    }
};
#endif // NETWORK_H
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化