`
hzxdark
  • 浏览: 77182 次
社区版块
存档分类
最新评论

BP网络JAVA版源代码

阅读更多

 

这是神经网络课程的作业：一个简单的 BP 网络。准确率有点低，可能是我的算法有问题：用 100 个训练样本训练后，在 50 个测试样本上只有 80% 的正确率。

 

 

BP 类：封装 BP 算法

package arithmetic;

public class BP {
 private double[] P;

 private double[] T;

 private double[][] W1;

 private double[][] W2;

 private int n_a0;

 private int n_a1;

 private int n_a2;

 private double[] B1;

 private double[] B2;

 private double[] a1;

 private double[] a2;

 private double[] q;

 private double[] db1;

 private double[] db2;

 private double[][] dw1;

 private double[][] dw2;

 private double e;

 private double r;

 private double e0;

 public BP(double[][] W1, double[][] W2, double[] B1, double[] B2) {
  this.W1 = W1;
  this.W2 = W2;
  this.B1 = B1;
  this.B2 = B2;
  n_a0 = W1[0].length;
  n_a1 = W1.length;
  n_a2 = W2.length;
  init();
 }

 public void setP(double[] P) {
  this.P = P;
 }

 public void setT(double[] T) {
  this.T = T;
 }

 private void init() {
  a1 = new double[n_a1];
  a2 = new double[n_a2];
  r = 0.4;
  e0 = 0.02;
  q = new double[n_a2];
  db2 = new double[n_a2];
  dw2 = new double[n_a2][n_a1];
  db1 = new double[n_a1];
  dw1 = new double[n_a1][n_a0];
 }

 public void calA1() {
  double temp = 0;
  for (int i = 0; i < n_a1; i++) {
   for (int j = 0; j < n_a0; j++) {
    temp += W1[i][j] * P[j];
   }
   temp += B1[i];
   a1[i] = F.f1(temp);
  }
 }

 public double[] getA1() {
  return a1;
 }

 public double[] getA2() {
  return a2;
 }

 public void calA2() {
  double temp = 0;
  for (int k = 0; k < n_a2; k++) {
   for (int i = 0; i < n_a1; i++) {
    temp += W2[k][i] * a1[i];
   }
   temp += B2[k];
   a2[k] = F.f2(temp);
  }
 }

 public void calE() {
  e = 0;
  for (int k = 0; k < n_a2; k++) {
   double ek = T[k] - a2[k];
   e += ek * ek;
   e /= 2;
  }
 }

 public void calDb2() {
  for (int k = 0; k < n_a2; k++) {
   q[k] = (T[k] - a2[k]) * F.f2_1(a2[k]);
   db2[k] = q[k] * r;
  }
 }

 public void calDw2() {
  for (int k = 0; k < n_a2; k++) {
   for (int i = 0; i < n_a1; i++) {
    dw2[k][i] = db2[k] * a1[i];
   }
  }
 }

 public void calDb1() {
  for (int i = 0; i < n_a1; i++) {
   db1[i] = 0;
   for (int k = 0; k < n_a2; k++) {
    db1[i] += q[k] * W2[k][i];
   }
   db1[i] *= r * F.f1_1(a1[i]);
  }
 }

 public void calDw1() {
  for (int i = 0; i < n_a1; i++) {
   for (int j = 0; j < n_a0; j++) {
    dw1[i][j] = db1[i] * P[j];
   }
  }
 }

 public void changeDb2() {
  for (int i = 0; i < n_a2; i++) {
   B2[i] += db2[i];
  }
 }

 public void changeDw2() {
  for (int i = 0; i < n_a2; i++) {
   for (int j = 0; j < n_a1; j++) {
    W2[i][j] += dw2[i][j];
   }
  }
 }

 public void changeDb1() {
  for (int i = 0; i < n_a1; i++) {
   B1[i] += db1[i];
  }
 }

 public void changeDw1() {
  for (int i = 0; i < n_a1; i++) {
   for (int j = 0; j < n_a0; j++) {
    W1[i][j] += dw1[i][j];
   }
  }
 }

 public void train(double[][] P, double[][] T) {
  while(true){
   boolean isChange = false;
   for (int n = 0; n < P.length; n++) {
    setP(P[n]);
    setT(T[n]);
    this.calA1();
    this.calA2();
    this.calE();
    if (e < e0)
     continue;
    this.calDb2();
    this.calDw2();
    this.calDb1();
    this.calDw1();
    this.changeDb2();
    this.changeDw2();
    this.changeDb1();
    this.changeDw1();
    isChange = true;
   // break;

   }
   if (!isChange) {
    System.out.println("train succeed");
    break;
   }
  }
 }

 public double[] divide(double[] p, double[] t) {
  setP(p);
  setT(t);
  this.calA1();
  this.calA2();
  return a2;
 }
 
 public double getE(){
  return e;
 }
}

F 类：封装神经元的激活函数及其导数

package arithmetic;

/**
 * Activation functions used by the BP network, together with their derivatives.
 *
 * <p>{@code f1} is the logistic sigmoid and {@code f2} is the identity. The derivative
 * helpers take the neuron's output value {@code y}, not its net input.
 */
public class F {
 /** Logistic sigmoid 1 / (1 + e^(-x)). */
 public static double f1(double x) {
  double expNeg = Math.exp(-x);
  return 1 / (1 + expNeg);
 }

 /** Sigmoid derivative written in terms of the sigmoid's output: y * (1 - y). */
 public static double f1_1(double y) {
  double complement = 1 - y;
  return y * complement;
 }

 /** Identity (linear) activation. */
 public static double f2(double x) {
  return x;
 }

 /** Derivative of the identity activation: always 1, whatever {@code y} is. */
 public static double f2_1(double y) {
  return 1.0;
 }
}

Controler 类：读入训练数据和测试数据，并创建 BP 实例进行训练和测试

package arithmetic;

import java.io.*;
import java.util.ArrayList;
import java.util.List;

import javax.swing.JFrame;

public class Controler {
 private double[][] p_test;
 private double[][] t_test;
 private double[][] p_train;
 private double[][] t_train; 
 
 private BP bp;
 private JFrame viwer;
 
 
 public Controler() throws IOException{
  getTestData();
  getTrainData();
  double[][] w1 = new double[][] { { 0.2, 0.3, 0.4, 0.1 },
    { 0.3, 0.4, 0.2, 0.4 }, { 0.4, 0.8, 0.9, 0.3 }};
  double[][] w2 = new double[][] { { 0.3, 0.6, 0.7 }, { 0.1, 0.3, 0.7 } };
  double[] b1 = new double[] { 0.2, 0.4, 0.5 };
  double[] b2 = new double[] { 0.1, 0.5 };
  bp = new BP(w1,w2,b1,b2);
  bp.train(p_train, t_train);
  int a = 0;
  for(int i =0;i<p_test.length;i++){
   double[] a2 = bp.divide(p_test[i], t_test[i]);  
   int t0 = (int)(t_test[i][0])*2+(int)(t_test[i][1]);
   int t1 = (int)(a2[0])*2+(int)(a2[1]);
   boolean equals = t1==t0;
   if(equals)a++;
   System.out.println("expected:"+t0+"\t"+"output:"+t1+"\t"+equals);
  }
  a = (int)(a/50.0*100);
  System.out.println(a);
 }
 
 private void getTestData() throws IOException{
  String fileName = "testData.txt";
  BufferedReader br = null;
  try {
   br = new BufferedReader(new FileReader(fileName));
  } catch (FileNotFoundException e) {
   br.close();
   e.printStackTrace();
  }
  ArrayList al = new ArrayList();
  String s = null;
  while((s=br.readLine())!=null){
   al.add(s);
  }
  br.close();
  p_test = new double[al.size()][4];
  t_test = new double[al.size()][2];
  double[] maxData = new double[]{0,0,0,0};
  for(int i =0;i<al.size();i++){
   String[] temp = al.get(i).toString().split(" ");;
   for(int j =0;j<4;j++){
    p_test[i][j] = Double.parseDouble(temp[j+1]);
    if(p_test[i][j]>maxData[j])maxData[j] = p_test[i][j];
   }
   int d = Integer.parseInt(temp[0]);
   switch (d){
   case 0:
    t_test[i][0] = 0;
    t_test[i][1] = 0;
    break;
   case 1:
    t_test[i][0] = 0;
    t_test[i][1] = 1;
    break;
   case 2:
    t_test[i][0] = 1;
    t_test[i][1] = 0;
    break;
   default:
    t_test[i][0] = 1;
    t_test[i][1] = 1;
    break;
   }
  }
  for(int i =0;i<p_test.length;i++){
   for(int j =0;j<4;j++){
    p_test[i][j] /= maxData[j];
   }
  }
 }
 private void getTrainData() throws IOException{
  String fileName = "trainData.txt";
  BufferedReader br = null;
  try {
   br = new BufferedReader(new FileReader(fileName));
  } catch (FileNotFoundException e) {
   br.close();
   e.printStackTrace();
  }
  ArrayList al = new ArrayList();
  String s = null;
  while((s=br.readLine())!=null){
   al.add(s);
  }
  p_train = new double[al.size()][4];
  t_train = new double[al.size()][2];
  double[] maxData = new double[]{0,0,0,0};
  for(int i =0;i<al.size();i++){
   String[] temp = al.get(i).toString().split(" ");;
   for(int j =0;j<temp.length-1;j++){
    p_train[i][j] = Double.parseDouble(temp[j+1]);
    if(p_train[i][j]>maxData[j])maxData[j] = p_train[i][j];
   }
   int d = Integer.parseInt(temp[0]);
   switch (d){
   case 0:
    t_train[i][0] = 0;
    t_train[i][1] = 0;
    break;
   case 1:
    t_train[i][0] = 0;
    t_train[i][1] = 1;
    break;
   case 2:
    t_train[i][0] = 1;
    t_train[i][1] = 0;
    break;
   default:
    t_train[i][0] = 1;
    t_train[i][1] = 1;
    break;
   }
  }
  for(int i =0;i<p_train.length;i++){
   for(int j =0;j<4;j++){
    p_train[i][j] /= maxData[j];
   }
  }
 }
 public static void main(String[] args) throws Exception{
  Controler ctr =new Controler();
  
 }
}

 

分享到:
评论
4 楼 xiaolu_yatou 2012-07-25  
乱 
3 楼 张空空 2010-06-29  
没有注释哦
2 楼 ydsakyclguozi 2009-06-30  
楼主,那你发个什么劲儿啊,改好了再发啊!
1 楼 mating 2008-06-20  
我用了一下怎么有问题啊???

相关推荐

Global site tag (gtag.js) - Google Analytics