代码拉取完成,页面将自动刷新
package rexen;
import java.util.Arrays;
import java.util.Vector;
class Data
{
Vector<Double> x = new Vector<Double>(); //输入数据
Vector<Double> y = new Vector<Double>(); //输出数据
};
public class BPNN {
final int LAYER = 3; //三层神经网络
final int NUM = 10; //每层的最多节点数
float A = (float) 30.0;
float B = (float) 10.0; //A和B是S型函数的参数
int ITERS = 1000; //最大训练次数
float ETA_W = (float) 0.0035; //权值调整率
float ETA_B = (float) 0.001; //阀值调整率
float ERROR = (float) 0.002; //单个样本允许的误差
float ACCU = (float) 0.005; //每次迭代允许的误差
int in_num; //输入层节点数
int ou_num; //输出层节点数
int hd_num; //隐含层节点数
Double[][][] w =new Double[LAYER][NUM][NUM]; //BP网络的权值
Double[][] b = new Double[LAYER][NUM]; //BP网络节点的阀值
Double[][] x= new Double[LAYER][NUM]; //每个神经元的值经S型函数转化后的输出值,输入层就为原值
Double[][] d= new Double[LAYER][NUM]; //记录delta学习规则中delta的值
Vector<Data> data;
//获取训练所有样本数据
void GetData(Vector<Data> _data)
{
data = _data;
}
//开始进行训练
void Train()
{
System.out.printf("Begin to train BP NetWork!\n");
GetNums();
InitNetWork();
int num = data.size();
for(int iter = 0; iter <= ITERS; iter++)
{
for(int cnt = 0; cnt < num; cnt++)
{
//第一层输入节点赋值
for(int i = 0; i < in_num; i++)
x[0][i] = data.get(cnt).x.get(i);
while(true)
{
ForwardTransfer();
if(GetError(cnt) < ERROR) //如果误差比较小,则针对单个样本跳出循环
break;
ReverseTransfer(cnt);
}
}
System.out.printf("This is the %d th trainning NetWork !\n", iter);
Double accu = GetAccu();
System.out.printf("All Samples Accuracy is " + accu);
if(accu < ACCU) break;
}
System.out.printf("The BP NetWork train End!\n");
}
//根据训练好的网络来预测输出值
Vector<Double> ForeCast(Vector<Double> data)
{
int n = data.size();
assert(n == in_num);
for(int i = 0; i < in_num; i++)
x[0][i] = data.get(i);
ForwardTransfer();
Vector<Double> v = new Vector<Double>();
for(int i = 0; i < ou_num; i++)
v.add(x[2][i]);
return v;
}
//获取网络节点数
void GetNums()
{
in_num = data.get(0).x.size(); //获取输入层节点数
ou_num = data.get(0).y.size(); //获取输出层节点数
hd_num = (int)Math.sqrt((in_num + ou_num) * 1.0) + 5; //获取隐含层节点数
if(hd_num > NUM) hd_num = NUM; //隐含层数目不能超过最大设置
}
//初始化网络
void InitNetWork()
{
for(int i = 0; i < LAYER; i++){
for(int j = 0; j < NUM; j++){
for(int k = 0; k < NUM; k++){
w[i][j][k] = 0.0;
}
}
}
for(int i = 0; i < LAYER; i++){
for(int j = 0; j < NUM; j++){
b[i][j] = 0.0;
}
}
}
//工作信号正向传递子过程
void ForwardTransfer()
{
//计算隐含层各个节点的输出值
for(int j = 0; j < hd_num; j++)
{
Double t = 0.0;
for(int i = 0; i < in_num; i++)
t += w[1][i][j] * x[0][i];
t += b[1][j];
x[1][j] = Sigmoid(t);
}
//计算输出层各节点的输出值
for(int j = 0; j < ou_num; j++)
{
Double t = (double) 0;
for(int i = 0; i < hd_num; i++)
t += w[2][i][j] * x[1][i];
t += b[2][j];
x[2][j] = Sigmoid(t);
}
}
//计算单个样本的误差
double GetError(int cnt)
{
Double ans = (double) 0;
for(int i = 0; i < ou_num; i++)
ans += 0.5 * (x[2][i] - data.get(cnt).y.get(i)) * (x[2][i] - data.get(cnt).y.get(i));
return ans;
}
//误差信号反向传递子过程
void ReverseTransfer(int cnt)
{
CalcDelta(cnt);
UpdateNetWork();
}
//计算所有样本的精度
double GetAccu()
{
Double ans = (double) 0;
int num = data.size();
for(int i = 0; i < num; i++)
{
int m = data.get(i).x.size();
for(int j = 0; j < m; j++)
x[0][j] = data.get(i).x.get(j);
ForwardTransfer();
int n = data.get(i).y.size();
for(int j = 0; j < n; j++)
ans += 0.5 * (x[2][j] - data.get(i).y.get(j)) * (x[2][j] - data.get(i).y.get(j));
}
return ans / num;
}
//计算调整量
void CalcDelta(int cnt)
{
//计算输出层的delta值
for(int i = 0; i < ou_num; i++)
d[2][i] = (x[2][i] - data.get(cnt).y.get(i)) * x[2][i] * (A - x[2][i]) / (A * B);
//计算隐含层的delta值
for(int i = 0; i < hd_num; i++)
{
Double t = (double) 0;
for(int j = 0; j < ou_num; j++)
t += w[2][i][j] * d[2][j];
d[1][i] = t * x[1][i] * (A - x[1][i]) / (A * B);
}
}
//根据计算出的调整量对BP网络进行调整
void UpdateNetWork()
{
//隐含层和输出层之间权值和阀值调整
for(int i = 0; i < hd_num; i++)
{
for(int j = 0; j < ou_num; j++)
w[2][i][j] -= ETA_W * d[2][j] * x[1][i];
}
for(int i = 0; i < ou_num; i++)
b[2][i] -= ETA_B * d[2][i];
//输入层和隐含层之间权值和阀值调整
for(int i = 0; i < in_num; i++)
{
for(int j = 0; j < hd_num; j++)
w[1][i][j] -= ETA_W * d[1][j] * x[0][i];
}
for(int i = 0; i < hd_num; i++)
b[1][i] -= ETA_B * d[1][i];
}
//计算Sigmoid函数的值
Double Sigmoid(double x)
{
return A / (1 + Math.exp(-x / B));
}
}
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。