`
cocoIT
  • 浏览: 48770 次
  • 性别: Icon_minigender_1
  • 来自: 福建
文章分类
社区版块
存档分类
最新评论

BP神经网络算法的Java实现

阅读更多
package ann;
/**
 * A single neuron of the back-propagation network.
 *
 * <p>{@code weights[m]} holds the weight of the connection from this node to node m of the next
 * layer (this is how {@code Nnetwork} indexes it). The node also stores its bias
 * ({@code threshold}), its current output ({@code activation}), its back-propagated error term,
 * and the previous weight/threshold changes used for the momentum term.
 */
public class Node implements java.io.Serializable, Cloneable {
  /** Current output of the node. */
  public double activation;
  /** Bias added to the weighted input sum. */
  public double threshold;
  /** Outgoing connection weights; {@code numOfweights} entries. */
  public double weights[];
  /** Previous weight changes (momentum history); same length as {@code weights}. */
  public double detweightslast[];
  /** Previous threshold change (momentum history). */
  public double detthresholdlast;
  /** Back-propagated error term of the node. */
  public double error;
  /** Number of outgoing weights. */
  public int numOfweights;
  /** Scale of random initial values: they are drawn uniformly from [-amultinum, amultinum). */
  public double amultinum;

  // NOTE(review): not read anywhere in this class; purpose unclear from here. The raw type is
  // kept so existing callers and serialized instances stay compatible.
  public java.util.ArrayList arWeight = new java.util.ArrayList();

  /** Creates a node without weight arrays, with zero activation/error/threshold. */
  public Node() {
    activation = 0;
    error = 0;
    threshold = 0;
    amultinum = 1;
  }

  /**
   * Creates a node with {@code numOfweights0} outgoing weights. Each weight and the threshold
   * are drawn uniformly from [-1, 1).
   *
   * @param numOfweights0 number of outgoing connections
   */
  public Node(int numOfweights0) {
    amultinum = 1;
    numOfweights = numOfweights0;
    weights = new double[numOfweights];
    detweightslast = new double[numOfweights]; // Java zero-initialises arrays
    detthresholdlast = 0;
    error = 0;
    for (int i = 0; i < numOfweights; i++) {
      weights[i] = randomInit();
    }
    threshold = randomInit();
  }

  /**
   * Creates a node with the given activation and threshold and {@code numOfweights0}
   * zero-valued outgoing weights.
   *
   * @param act initial activation
   * @param thr threshold (bias)
   * @param numOfweights0 number of outgoing connections
   */
  public Node(double act, double thr, int numOfweights0) {
    amultinum = 1;
    numOfweights = numOfweights0;
    activation = act;
    threshold = thr;
    weights = new double[numOfweights];
    detweightslast = new double[numOfweights];
  }

  /**
   * Replaces the outgoing weights with the values in {@code weight}.
   *
   * <p>{@code numOfweights} is updated to the new size. {@code detweightslast} is reallocated
   * (zeroed) when its length no longer matches, so the arrays stay consistent for code that
   * loops over {@code numOfweights}. If an element is not a {@link Number}, the node is left
   * unchanged.
   *
   * @param weight list of numeric weights (e.g. {@code Double})
   * @throws NullPointerException if {@code weight} is null
   * @throws ClassCastException if an element is not a {@code Number}
   */
  public void setWeight(java.util.ArrayList weight) {
    java.util.Objects.requireNonNull(weight, "weight");
    int n = weight.size();
    double[] newWeights = new double[n];
    for (int i = 0; i < n; i++) {
      newWeights[i] = ((Number) weight.get(i)).doubleValue();
    }
    weights = newWeights;
    numOfweights = n;
    if (detweightslast == null || detweightslast.length != n) {
      detweightslast = new double[n];
    }
  }

  /** Returns a uniform random value in [-amultinum, amultinum). */
  private double randomInit() {
    return (2 * Math.random() - 1) * amultinum;
  }
}

package ann;     
         
    import ann.*;     
         
    public class Nnetwork implements java.io.Serializable,Cloneable{     
      public int n_input;     
      public int n_output;     
      public int n_layer;     
      public Layer hidelayer[];     
      public Layer input, output;     
      boolean iftrain[];     
      public double LH=-10;      
      public double LO=-10;      
         
      public double desiredOutputs[];      
      public double a=0.2;     
      public int connecttype;     
      public double total_error, total_error_one_circle_all;     
      public double error_compared_to_tolerance;     
      double total_error_one_circle[];     
      public int trainnum;     
         
      public Nnetwork() {     
      }     
         
      public Nnetwork(int inp, int hide[], int outp, int hidenum, int connecttype0) {     
        connecttype = connecttype0;     
        int i, j;     
        n_input = inp;     
        n_output = outp;     
        total_error_one_circle = new double[outp];     
        desiredOutputs = new double[outp];     
        output = new Layer(n_output);     
        for (i = 0; i < n_output; i++) {     
          output.nodes[i] = new Node(0);     
        }     
        n_layer = hidenum;     
        hidelayer = new Layer[n_layer];     
        for (i = n_layer - 1; i >= 0; i--) {     
          hidelayer[i] = new Layer(hide[i]);     
          for (j = 0; j < hidelayer[i].N_NODES; j++) {     
            if (i == n_layer - 1) {     
              hidelayer[i].nodes[j] = new Node(outp);     
            }     
            else {     
              hidelayer[i].nodes[j] = new Node(hidelayer[i + 1].N_NODES);     
            }     
          }     
        }     
        input = new Layer(n_input);     
        for (i = 0; i < n_input; i++) {     
          input.nodes[i] = new Node(hidelayer[0].N_NODES);     
        }     
      }     
         
      void FirstTimeSettings() {     
        for (int i = 0; i < n_layer; i++) {     
          int j;     
          for (j = 0; j < hidelayer[i].N_NODES; j++) {     
            hidelayer[i].nodes[j].threshold = (2 * Math.random() - 1) *     
                hidelayer[i].nodes[j].amultinum;     
          }     
        }     
        for (int i = 0; i < n_output; i++) {     
          output.nodes[i].threshold = (2 * Math.random() - 1) *     
              output.nodes[i].amultinum;     
        }     
      }     
         
      void BeforeTraining(double inp[], double outp[]) {     
        int i;     
        for (i = 0; i < n_input; i++) {     
          input.nodes[i].activation = inp[i];     
        }     
        for (i = 0; i < n_output; i++) {     
          desiredOutputs[i] = outp[i];     
        }     
      }     
         
      public void Calc_Activation(double result[]) {     
        int i, j, ci;     
        for (i = 0; i < n_layer; i++) {     
          if (i == 0) {     
            for (j = 0; j < hidelayer[i].N_NODES; j++) {     
              hidelayer[i].nodes[j].activation = 0;     
              for (ci = 0; ci < n_input; ci++) {     
                hidelayer[i].nodes[j].activation += input.nodes[ci].activation *input.nodes[ci].weights[j];     
              }     
              hidelayer[i].nodes[j].activation += hidelayer[i].nodes[j].threshold;     
              hidelayer[i].nodes[j].activation = activefun(hidelayer[i].nodes[j].activation);     
            }     
          }     
          else {     
            for (j = 0; j < hidelayer[i].N_NODES; j++) {     
              hidelayer[i].nodes[j].activation = 0;     
              for (ci = 0; ci < hidelayer[i - 1].N_NODES; ci++) {     
                hidelayer[i].nodes[j].activation += hidelayer[i -1].nodes[ci].activation * hidelayer[i - 1].nodes[ci].weights[j];     
              }     
              hidelayer[i].nodes[j].activation += hidelayer[i].nodes[j].threshold;     
              hidelayer[i].nodes[j].activation = activefun(hidelayer[i].nodes[j].activation);     
            }     
          }     
        }     
        for (j = 0; j < output.N_NODES; j++) {     
          output.nodes[j].activation = 0;     
          for (ci = 0; ci < hidelayer[n_layer - 1].N_NODES; ci++) {     
            output.nodes[j].activation += hidelayer[n_layer -1].nodes[ci].activation * hidelayer[n_layer -1].nodes[ci].weights[j];     
          }     
          output.nodes[j].activation += output.nodes[j].threshold;     
          output.nodes[j].activation = activefun(output.nodes[j].activation);     
        }     
        for (i = 0; i < n_output; i++) {     
          result[i] = output.nodes[i].activation;     
        }     
      }     
         
      void Calc_error_output() {     
        for (int x = 0; x < n_output; x++)     
        //output.nodes[x].error = output.nodes[x].activation * (1 - output.nodes[x].activation ) * (desiredOutputs[x] - output.nodes[x].activation );     
        {     
          output.nodes[x].error += (output.nodes[x].activation - desiredOutputs[x]);     
          output.nodes[x].error *= difactivefun(output.nodes[x].activation);     
        }     
      }     
         
      void Calc_error_hidden() {     
        int j, i;     
        for (j = 0; j < hidelayer[n_layer - 1].N_NODES; j++) {     
          for (int x = 0; x < n_output; x++) {     
            hidelayer[n_layer - 1].nodes[j].error += hidelayer[n_layer -     
                1].nodes[j].weights[x] * output.nodes[x].error;     
          }     
          hidelayer[n_layer -     
              1].nodes[j].error *=     
              difactivefun(hidelayer[n_layer - 1].nodes[j].activation);     
        }     
         
        for (i = n_layer - 2; i >= 0; i--) {     
          for (j = 0; j < hidelayer[i].N_NODES; j++) {     
            for (int x = 0; x < hidelayer[i + 1].N_NODES; x++) {     
              hidelayer[i].nodes[j].error += hidelayer[i].nodes[j].weights[x] *     
                  hidelayer[i + 1].nodes[x].error;     
            }     
            hidelayer[i].nodes[j].error *=     
                difactivefun(hidelayer[i].nodes[j].activation);     
          }     
        }     
         
      }     
         
      void Calc_new_Thresholds() {     
        int i;     
        // computing the thresholds for next itration for hidden layer     
        for (i = 0; i < n_layer; i++) {     
          for (int x = 0; x < hidelayer[i].N_NODES; x++) {     
            double det = a * hidelayer[i].nodes[x].detthresholdlast +     
                hidelayer[i].nodes[x].error * LH;     
            hidelayer[i].nodes[x].detthresholdlast = det;     
            hidelayer[i].nodes[x].threshold += det;     
          }     
        }     
        for (int y = 0; y < output.N_NODES; y++) {     
          double det = a * output.nodes[y].detthresholdlast +     
              output.nodes[y].error * LO;     
          output.nodes[y].detthresholdlast = det;     
          output.nodes[y].threshold += det;     
        }     
         
      }     
         
      void Calc_new_weights_in_hidden() {     
         
        int i, j;     
        double temp = 0.0f;     
        for (j = 0; j < hidelayer[n_layer - 1].N_NODES; j++) {     
          temp = hidelayer[n_layer - 1].nodes[j].activation * LO;     
          for (int y = 0; y < n_output; y++) {     
            double det = a * hidelayer[n_layer - 1].nodes[j].detweightslast[y] +     
                temp * output.nodes[y].error;     
            hidelayer[n_layer - 1].nodes[j].detweightslast[y] = det;     
            hidelayer[n_layer - 1].nodes[j].weights[y] += det;     
          }     
         
        }     
         
        for (i = 0; i < n_layer - 1; i++) {     
          for (j = 0; j < hidelayer[i].N_NODES; j++) {     
            temp = hidelayer[i].nodes[j].activation * LH;     
            for (int y = 0; y < hidelayer[i + 1].N_NODES; y++) {     
              double det = a * hidelayer[i].nodes[j].detweightslast[y] +     
                  temp * hidelayer[i + 1].nodes[y].error;     
              hidelayer[i].nodes[j].detweightslast[y] = det;     
              hidelayer[i].nodes[j].weights[y] += det;     
            }     
         
          }     
        }     
         
      }     
         
      void Calc_new_weights_in_input() {     
        int j;     
        double temp = 0.0f;     
        for (j = 0; j < input.N_NODES; j++) {     
          temp = input.nodes[j].activation * LH;     
          for (int y = 0; y < hidelayer[0].N_NODES; y++) {     
            double det = a * input.nodes[j].detweightslast[y] +     
                temp * hidelayer[0].nodes[y].error;     
            input.nodes[j].detweightslast[y] = det;     
            input.nodes[j].weights[y] += det;     
          }     
        }     
      }     
         
      double Calc_total_error_in_pattern() {     
        double temp = 0.0;     
        for (int x = 0; x < n_output; x++) {     
    [x].activation - desiredOutputs[x]);     
           continue;     
         
         temp += Math.pow((output.nodes[x].activation - desiredOutputs[x]), 2);     
          total_error_one_circle[x] += Math.pow((output.nodes[x].activation - desiredOutputs[x]), 2);     
        }     
        total_error = temp;      
        total_error_one_circle_all += total_error;      
        return temp;     
      }     
         
      void reset_error() {     
        for (int i = 0; i < n_input; i++) {     
          input.nodes[i].error = 0;     
        }     
        for (int i = 0; i < n_output; i++) {     
          output.nodes[i].error = 0;     
        }     
        for (int i = 0; i < n_layer; i++) {     
          for (int j = 0; j < hidelayer[i].N_NODES; j++) {     
            hidelayer[i].nodes[j].error = 0;     
          }     
        }     
      }     
         
      void reset_total_error() {     
        total_error_one_circle_all = 0;     
        for (int x = 0; x < n_output; x++) {     
          total_error_one_circle[x] = 0;     
        }     
      }     
         
      void Training_for_one_pattern(double result[]) {     
        Calc_Activation(result);     
        Calc_error_output();     
        Calc_error_hidden();     
        Calc_new_Thresholds();     
        Calc_new_weights_in_hidden();     
        Calc_new_weights_in_input();     
      }     
         
      public void Training(double inputs[][], double outputs[][], int num, boolean ifresort) {     
         
        clearlastdetweight();     
        iftrain = new boolean[num];     
        setiftrain(inputs, outputs, num, iftrain);     
        int neworder[] = new int[num];     
        sortrandom(neworder, num, ifresort);     
        reset_total_error();     
        for (int k = 0; k < num; k++) {     
          int i = neworder[k];     
          if (iftrain[i]) {     
    //        System.out.println("k="+k+",i="+i+",iftrain[i]="+iftrain[i]);     
            reset_error();     
            BeforeTraining(inputs[i], outputs[i]);     
            double tmp[] = new double[output.N_NODES];     
            Calc_Activation(tmp);     
            Calc_error_output();     
            Calc_total_error_in_pattern();     
    //        System.out.println(k+"  "+tmp[0]+"  "+outputs[i][0]+"  "+output.nodes[0].activation +"  "+ desiredOutputs[0]+"  "+total_error_one_circle_all);     
            Calc_error_hidden();     
            Calc_new_Thresholds();     
            Calc_new_weights_in_hidden();     
            Calc_new_weights_in_input();     
          }     
        }     
      }     
         
      void TrainingAll(double inputs[][], double outputs[][], int num,boolean ifresort) {     
        clearlastdetweight();     
        iftrain = new boolean[num];     
        setiftrain(inputs, outputs, num, iftrain);     
        int neworder[] = new int[num];     
        sortrandom(neworder, num, ifresort);     
        reset_total_error();     
        reset_error();     
        for (int k = 0; k < num; k++) {     
          int i = neworder[k];     
          if (iftrain[i]) {     
            BeforeTraining(inputs[i], outputs[i]);     
            double tmp[] = new double[output.N_NODES];     
            Calc_Activation(tmp);     
            Calc_error_output();     
            Calc_total_error_in_pattern();     
          }     
        }     
        Calc_error_hidden();     
        Calc_new_Thresholds();     
        Calc_new_weights_in_hidden();     
        Calc_new_weights_in_input();     
      }     
         
      void clearlastdetweight() {     
        for (int i = 0; i < n_input; i++) {     
          input.nodes[i].detthresholdlast = 0;     
          for (int j = 0; j < input.nodes[i].numOfweights; j++) {     
            input.nodes[i].detweightslast[j] = 0;     
          }     
        }     
        for (int i = 0; i < n_output; i++) {     
          output.nodes[i].detthresholdlast = 0;     
          for (int j = 0; j < output.nodes[i].numOfweights; j++) {     
            output.nodes[i].detweightslast[j] = 0;     
          }     
        }     
        for (int k = 0; k < n_layer; k++) {     
          for (int i = 0; i < hidelayer[k].N_NODES; i++) {     
            hidelayer[k].nodes[i].detthresholdlast = 0;     
            for (int j = 0; j < hidelayer[k].nodes[i].numOfweights; j++) {     
              hidelayer[k].nodes[i].detweightslast[j] = 0;     
            }     
          }     
        }     
      }     
         
      void sortrandom(int neworder[], int num, boolean ifresort) {     
        for (int i = 0; i < num; i++) {     
          neworder[i] = i;     
        }     
        if (ifresort) {     
          for (int i = 0; i < num; i++) {     
            int pos = (int) (Math.random() * (num - i)) + i;     
            int tmp = neworder[pos];     
            neworder[pos] = neworder[i];     
            neworder[i] = tmp;     
          }     
        }     
      }     
         
      int setiftrain(double inputs[][], double outputs[][], int num,boolean iftrain[]) {     
        for (int i = 0; i < num; i++) {     
          iftrain[i] = true;     
          if (outputs[i][0] <= 0) {     
            iftrain[i] = false;     
          }     
        }     
        for (int i = 0; i < num; i++) {     
          for (int j = 0; j < num; j++) {     
            if (i != j) {     
              boolean ifsame = true;     
              for (int k = 0; k < n_input; k++) {     
                if (inputs[i][k] != inputs[j][k]) {     
                  ifsame = false;     
                  break;     
                }     
              }     
              if (ifsame) {     
                iftrain[i] = false;     
              }     
            }     
            if (iftrain[i] == false) {     
              break;     
            }     
          }     
        }     
        trainnum = 0;     
        for (int i = 0; i < num; i++) {     
          if (iftrain[i]) {     
            trainnum++;     
          }     
        }     
        return trainnum;     
      }     
         
      double sigmoid(double x) {     
        return 1 / (1 + Math.exp( -x));     
      }     
         
      double difsigmoid(double x) {     
        return x - x * x;     
         
      }     
         
      double tanh(double x) {     
         
        return (1 - Math.exp( -x)) / (1 + Math.exp( -x));     
      }     
         
      double diftanh(double x) {     
        return (1 - x * x) / 2;     
         
      }     
         
      double activefun(double x) {     
         
        if (connecttype == 0) {     
          return sigmoid(x);     
        }     
        else if (connecttype == 1) {     
          return tanh(x);     
        }     
        else {     
          return 0;     
        }     
         
      }     
         
      double difactivefun(double x) {     
        if (connecttype == 0) {     
          return difsigmoid(x);     
        }     
        else if (connecttype == 1) {     
          return diftanh(x);     
        }     
        else {     
          return 0;     
        }     
         
      }     
         
    }

package ann;     
         
         
    /**
     * One layer of the network: an array of {@link Node} slots together with its size.
     *
     * <p>Construction only allocates the array. Every slot starts out null and must be filled
     * by the caller.
     */
    public class Layer implements java.io.Serializable,Cloneable{
      /** Number of node slots, set to the array length at construction. */
      public int N_NODES;
      /** Node slots of this layer. */
      public Node[] nodes;

      /** Creates a layer with no node array (N_NODES stays 0, nodes stays null). */
      public Layer() {
      }

      /**
       * Creates a layer with {@code n} empty node slots.
       *
       * @param n number of nodes in the layer
       */
      public Layer(int n) {
        this.nodes = new Node[n];
        this.N_NODES = n;
      }
    }


分享到:
评论

相关推荐

Global site tag (gtag.js) - Google Analytics