神经网络课程的作业,一个简单的BP网络。准确率有点低,可能是我算法有点问题,100个训练数据,测试50个数据,只得80%正确。
BP类 //封装bp算法
package arithmetic;
public class BP {
private double[] P;
private double[] T;
private double[][] W1;
private double[][] W2;
private int n_a0;
private int n_a1;
private int n_a2;
private double[] B1;
private double[] B2;
private double[] a1;
private double[] a2;
private double[] q;
private double[] db1;
private double[] db2;
private double[][] dw1;
private double[][] dw2;
private double e;
private double r;
private double e0;
public BP(double[][] W1, double[][] W2, double[] B1, double[] B2) {
this.W1 = W1;
this.W2 = W2;
this.B1 = B1;
this.B2 = B2;
n_a0 = W1[0].length;
n_a1 = W1.length;
n_a2 = W2.length;
init();
}
public void setP(double[] P) {
this.P = P;
}
public void setT(double[] T) {
this.T = T;
}
private void init() {
a1 = new double[n_a1];
a2 = new double[n_a2];
r = 0.4;
e0 = 0.02;
q = new double[n_a2];
db2 = new double[n_a2];
dw2 = new double[n_a2][n_a1];
db1 = new double[n_a1];
dw1 = new double[n_a1][n_a0];
}
public void calA1() {
double temp = 0;
for (int i = 0; i < n_a1; i++) {
for (int j = 0; j < n_a0; j++) {
temp += W1[i][j] * P[j];
}
temp += B1[i];
a1[i] = F.f1(temp);
}
}
public double[] getA1() {
return a1;
}
public double[] getA2() {
return a2;
}
public void calA2() {
double temp = 0;
for (int k = 0; k < n_a2; k++) {
for (int i = 0; i < n_a1; i++) {
temp += W2[k][i] * a1[i];
}
temp += B2[k];
a2[k] = F.f2(temp);
}
}
public void calE() {
e = 0;
for (int k = 0; k < n_a2; k++) {
double ek = T[k] - a2[k];
e += ek * ek;
e /= 2;
}
}
public void calDb2() {
for (int k = 0; k < n_a2; k++) {
q[k] = (T[k] - a2[k]) * F.f2_1(a2[k]);
db2[k] = q[k] * r;
}
}
public void calDw2() {
for (int k = 0; k < n_a2; k++) {
for (int i = 0; i < n_a1; i++) {
dw2[k][i] = db2[k] * a1[i];
}
}
}
public void calDb1() {
for (int i = 0; i < n_a1; i++) {
db1[i] = 0;
for (int k = 0; k < n_a2; k++) {
db1[i] += q[k] * W2[k][i];
}
db1[i] *= r * F.f1_1(a1[i]);
}
}
public void calDw1() {
for (int i = 0; i < n_a1; i++) {
for (int j = 0; j < n_a0; j++) {
dw1[i][j] = db1[i] * P[j];
}
}
}
public void changeDb2() {
for (int i = 0; i < n_a2; i++) {
B2[i] += db2[i];
}
}
public void changeDw2() {
for (int i = 0; i < n_a2; i++) {
for (int j = 0; j < n_a1; j++) {
W2[i][j] += dw2[i][j];
}
}
}
public void changeDb1() {
for (int i = 0; i < n_a1; i++) {
B1[i] += db1[i];
}
}
public void changeDw1() {
for (int i = 0; i < n_a1; i++) {
for (int j = 0; j < n_a0; j++) {
W1[i][j] += dw1[i][j];
}
}
}
public void train(double[][] P, double[][] T) {
while(true){
boolean isChange = false;
for (int n = 0; n < P.length; n++) {
setP(P[n]);
setT(T[n]);
this.calA1();
this.calA2();
this.calE();
if (e < e0)
continue;
this.calDb2();
this.calDw2();
this.calDb1();
this.calDw1();
this.changeDb2();
this.changeDw2();
this.changeDb1();
this.changeDw1();
isChange = true;
// break;
}
if (!isChange) {
System.out.println("train succeed");
break;
}
}
}
public double[] divide(double[] p, double[] t) {
setP(p);
setT(t);
this.calA1();
this.calA2();
return a2;
}
public double getE(){
return e;
}
}
F类 //封装神经元函数
package arithmetic;
public class F {
public static double f1(double x){
return 1/(1+Math.exp(-1*x));
}
public static double f1_1(double y){
return y*(1-y);
}
public static double f2(double x){
return x;
}
public static double f2_1(double y){
return 1;
}
}
Controller类 读入训练数据和测试数据,并创建BP实例进行训练测试
package arithmetic;
import java.io.*;
import java.util.ArrayList;
import javax.swing.JFrame;
public class Controler {
private double[][] p_test;
private double[][] t_test;
private double[][] p_train;
private double[][] t_train;
private BP bp;
private JFrame viwer;
public Controler() throws IOException{
getTestData();
getTrainData();
double[][] w1 = new double[][] { { 0.2, 0.3, 0.4, 0.1 },
{ 0.3, 0.4, 0.2, 0.4 }, { 0.4, 0.8, 0.9, 0.3 }};
double[][] w2 = new double[][] { { 0.3, 0.6, 0.7 }, { 0.1, 0.3, 0.7 } };
double[] b1 = new double[] { 0.2, 0.4, 0.5 };
double[] b2 = new double[] { 0.1, 0.5 };
bp = new BP(w1,w2,b1,b2);
bp.train(p_train, t_train);
int a = 0;
for(int i =0;i<p_test.length;i++){
double[] a2 = bp.divide(p_test[i], t_test[i]);
int t0 = (int)(t_test[i][0])*2+(int)(t_test[i][1]);
int t1 = (int)(a2[0])*2+(int)(a2[1]);
boolean equals = t1==t0;
if(equals)a++;
System.out.println("expected:"+t0+"\t"+"output:"+t1+"\t"+equals);
}
a = (int)(a/50.0*100);
System.out.println(a);
}
private void getTestData() throws IOException{
String fileName = "testData.txt";
BufferedReader br = null;
try {
br = new BufferedReader(new FileReader(fileName));
} catch (FileNotFoundException e) {
br.close();
e.printStackTrace();
}
ArrayList al = new ArrayList();
String s = null;
while((s=br.readLine())!=null){
al.add(s);
}
br.close();
p_test = new double[al.size()][4];
t_test = new double[al.size()][2];
double[] maxData = new double[]{0,0,0,0};
for(int i =0;i<al.size();i++){
String[] temp = al.get(i).toString().split(" ");;
for(int j =0;j<4;j++){
p_test[i][j] = Double.parseDouble(temp[j+1]);
if(p_test[i][j]>maxData[j])maxData[j] = p_test[i][j];
}
int d = Integer.parseInt(temp[0]);
switch (d){
case 0:
t_test[i][0] = 0;
t_test[i][1] = 0;
break;
case 1:
t_test[i][0] = 0;
t_test[i][1] = 1;
break;
case 2:
t_test[i][0] = 1;
t_test[i][1] = 0;
break;
default:
t_test[i][0] = 1;
t_test[i][1] = 1;
break;
}
}
for(int i =0;i<p_test.length;i++){
for(int j =0;j<4;j++){
p_test[i][j] /= maxData[j];
}
}
}
private void getTrainData() throws IOException{
String fileName = "trainData.txt";
BufferedReader br = null;
try {
br = new BufferedReader(new FileReader(fileName));
} catch (FileNotFoundException e) {
br.close();
e.printStackTrace();
}
ArrayList al = new ArrayList();
String s = null;
while((s=br.readLine())!=null){
al.add(s);
}
p_train = new double[al.size()][4];
t_train = new double[al.size()][2];
double[] maxData = new double[]{0,0,0,0};
for(int i =0;i<al.size();i++){
String[] temp = al.get(i).toString().split(" ");;
for(int j =0;j<temp.length-1;j++){
p_train[i][j] = Double.parseDouble(temp[j+1]);
if(p_train[i][j]>maxData[j])maxData[j] = p_train[i][j];
}
int d = Integer.parseInt(temp[0]);
switch (d){
case 0:
t_train[i][0] = 0;
t_train[i][1] = 0;
break;
case 1:
t_train[i][0] = 0;
t_train[i][1] = 1;
break;
case 2:
t_train[i][0] = 1;
t_train[i][1] = 0;
break;
default:
t_train[i][0] = 1;
t_train[i][1] = 1;
break;
}
}
for(int i =0;i<p_train.length;i++){
for(int j =0;j<4;j++){
p_train[i][j] /= maxData[j];
}
}
}
public static void main(String[] args) throws Exception{
Controler ctr =new Controler();
}
}
分享到:
相关推荐
BP神经网络程序,java语言源代码,自己整理的
BP算法的神经网络的源代码, 可以根据向量建立网络,网络的训练结果和初始结构可以用XML保存和载入。 <br>其中 Compressor/TrainerWithDiagram.class , 是一个用于演示的训练器, 产生制定范围内的数,生成...
采用高效快速的粒子群算法对神经网络进行学习,提供所有源代码。语言,Java
典型的BP神经网络学习算法,以及内部讲解,包括神经元个数,层数详细说明具体该怎么设定,简单的事例,快速明白交叉熵函数在反向传播过程的作用
BP神经网络算法源代码,文件为JAVA语言编写的,编译环境为Eclipse
bp神经网络程序和java源代码源代码是eclipse的工程,所以若您使用的是eclipse,可以直接用eclipse的import将源代码的文件夹加入你的工程中
BP神经网络的java源代码,采用动量梯度下降法
Java基于BP神经网络的手写数字识别源代码+训练集 assets/inputHidden.csv 输入层到隐藏层的权重矩阵 assets/hiddenOutput.csv 隐藏层到输出层的矩阵 assets/train-images-idx3-ubyte 训练集图片 assets/train-labels...
多层前向神经网络(MLP)的源代码,具有如下特点: 1、定制网络结构,可以有多个隐含层。 2、隐层节点和输出层可以关联不同的激发函数,实现了线性、tanh、sigmoid三种激发函数。 3、训练方法暂时只实现了BP。代码...
1、资源内容:基于matlab编程的针对uci葡萄酒分类数据集的学习,主要的方法是BP和RBF+源代码+文档说明 2、代码特点:内含运行结果,不会运行可私信,参数化编程、参数可方便更改、代码编程思路清晰、注释明细,都...
1、资源内容:基于matlab-bp神经网络实现的数字图像识别+源代码+文档说明 2、代码特点:内含运行结果,不会运行可私信,参数化编程、参数可方便更改、代码编程思路清晰、注释明细,都经过测试运行成功,功能ok的情况...
JAVA 写的BP神经网络算法,实现了sin三角函数的模拟和数字识别的模拟,有源代码和文档说明
bp神经网络程序和java源代码开发包,适合做网页开发的人
1、资源内容:机器学习大作业,人脸识别-运用BP神经网络实现性别检测+源代码+文档说明+pdf+报告 2、代码特点:内含运行结果,不会运行可私信,参数化编程、参数可方便更改、代码编程思路清晰、注释明细,都经过测试...
基于JAVA毕业设计-JAVA网络通信系统的研究与开发(论文+源代码+开题报告).rar 1.本课题的研究意义,国内外研究现状、水平和发展趋势 网络通信在当今信息社会中起着不可或缺的作用 ,人们可以利用网络通信技术进行即时...
在信息话识,把图片识别成文字或者字符,代替人工处理,是现在大家都在进行的研究,本程序通过神经网络把图片的文字转换为可以使用的文字,实现了文字的识别。
基于BP算法和遗传算法建立武汉市空气质量指数的预测模型,以武汉市8个监测站的1年的空气质量数据为训练数据进行神经网络的建模,近一个月的数据作为测试数据进行模型的准确性测试,平均准确率在75%左右 - 不懂运行,...
对于有一定基础或热衷于研究的人来说,可以在这些基础代码上进行修改和扩展,实现其他功能。【沟通交流】:有任何使用上的问题,欢迎随时与博主沟通,博主会及时解答。鼓励下载和使用,并欢迎大家互相学习,共同进步...
对于有一定基础或热衷于研究的人来说,可以在这些基础代码上进行修改和扩展,实现其他功能。【沟通交流】:有任何使用上的问题,欢迎随时与博主沟通,博主会及时解答。鼓励下载和使用,并欢迎大家互相学习,共同进步...