`
zhc0822
  • 浏览: 228732 次
  • 性别: Icon_minigender_1
  • 来自: 宝仔的奇幻城堡
社区版块
存档分类
最新评论

BP神经网络的Java实现

阅读更多

课程作业要求实现一个BPNN。这次尝试使用Java实现了一个。现共享之。版权属于大家。关于BPNN的原理,就不赘述了。

下面是BPNN的实现代码。类名为BP。

 

package ml;

import java.util.Random;

/**
 * BPNN.
 * 
 * @author RenaQiu
 * 
 */
public class BP {
	/**
	 * input vector.
	 */
	private final double[] input;
	/**
	 * hidden layer.
	 */
	private final double[] hidden;
	/**
	 * output layer.
	 */
	private final double[] output;
	/**
	 * target.
	 */
	private final double[] target;

	/**
	 * delta vector of the hidden layer .
	 */
	private final double[] hidDelta;
	/**
	 * output layer of the output layer.
	 */
	private final double[] optDelta;

	/**
	 * learning rate.
	 */
	private final double eta;
	/**
	 * momentum.
	 */
	private final double momentum;

	/**
	 * weight matrix from input layer to hidden layer.
	 */
	private final double[][] iptHidWeights;
	/**
	 * weight matrix from hidden layer to output layer.
	 */
	private final double[][] hidOptWeights;

	/**
	 * previous weight update.
	 */
	private final double[][] iptHidPrevUptWeights;
	/**
	 * previous weight update.
	 */
	private final double[][] hidOptPrevUptWeights;

	public double optErrSum = 0d;

	public double hidErrSum = 0d;

	private final Random random;

	/**
	 * Constructor.
	 * <p>
	 * <strong>Note:</strong> The capacity of each layer will be the parameter
	 * plus 1. The additional unit is used for smoothness.
	 * </p>
	 * 
	 * @param inputSize
	 * @param hiddenSize
	 * @param outputSize
	 * @param eta
	 * @param momentum
	 * @param epoch
	 */
	public BP(int inputSize, int hiddenSize, int outputSize, double eta,
			double momentum) {

		input = new double[inputSize + 1];
		hidden = new double[hiddenSize + 1];
		output = new double[outputSize + 1];
		target = new double[outputSize + 1];

		hidDelta = new double[hiddenSize + 1];
		optDelta = new double[outputSize + 1];

		iptHidWeights = new double[inputSize + 1][hiddenSize + 1];
		hidOptWeights = new double[hiddenSize + 1][outputSize + 1];

		random = new Random(19881211);
		randomizeWeights(iptHidWeights);
		randomizeWeights(hidOptWeights);

		iptHidPrevUptWeights = new double[inputSize + 1][hiddenSize + 1];
		hidOptPrevUptWeights = new double[hiddenSize + 1][outputSize + 1];

		this.eta = eta;
		this.momentum = momentum;
	}

	private void randomizeWeights(double[][] matrix) {
		for (int i = 0, len = matrix.length; i != len; i++)
			for (int j = 0, len2 = matrix[i].length; j != len2; j++) {
				double real = random.nextDouble();
				matrix[i][j] = random.nextDouble() > 0.5 ? real : -real;
			}
	}

	/**
	 * Constructor with default eta = 0.25 and momentum = 0.3.
	 * 
	 * @param inputSize
	 * @param hiddenSize
	 * @param outputSize
	 * @param epoch
	 */
	public BP(int inputSize, int hiddenSize, int outputSize) {
		this(inputSize, hiddenSize, outputSize, 0.25, 0.9);
	}

	/**
	 * Entry method. The train data should be a one-dim vector.
	 * 
	 * @param trainData
	 * @param target
	 */
	public void train(double[] trainData, double[] target) {
		loadInput(trainData);
		loadTarget(target);
		forward();
		calculateDelta();
		adjustWeight();
	}

	/**
	 * Test the BPNN.
	 * 
	 * @param inData
	 * @return
	 */
	public double[] test(double[] inData) {
		if (inData.length != input.length - 1) {
			throw new IllegalArgumentException("Size Do Not Match.");
		}
		System.arraycopy(inData, 0, input, 1, inData.length);
		forward();
		return getNetworkOutput();
	}

	/**
	 * Return the output layer.
	 * 
	 * @return
	 */
	private double[] getNetworkOutput() {
		int len = output.length;
		double[] temp = new double[len - 1];
		for (int i = 1; i != len; i++)
			temp[i - 1] = output[i];
		return temp;
	}

	/**
	 * Load the target data.
	 * 
	 * @param arg
	 */
	private void loadTarget(double[] arg) {
		if (arg.length != target.length - 1) {
			throw new IllegalArgumentException("Size Do Not Match.");
		}
		System.arraycopy(arg, 0, target, 1, arg.length);
	}

	/**
	 * Load the training data.
	 * 
	 * @param inData
	 */
	private void loadInput(double[] inData) {
		if (inData.length != input.length - 1) {
			throw new IllegalArgumentException("Size Do Not Match.");
		}
		System.arraycopy(inData, 0, input, 1, inData.length);
	}

	/**
	 * Forward.
	 * 
	 * @param layer0
	 * @param layer1
	 * @param weight
	 */
	private void forward(double[] layer0, double[] layer1, double[][] weight) {
		// threshold unit.
		layer0[0] = 1.0;
		for (int j = 1, len = layer1.length; j != len; ++j) {
			double sum = 0;
			for (int i = 0, len2 = layer0.length; i != len2; ++i)
				sum += weight[i][j] * layer0[i];
			layer1[j] = sigmoid(sum);
		}
	}

	/**
	 * Forward.
	 */
	private void forward() {
		forward(input, hidden, iptHidWeights);
		forward(hidden, output, hidOptWeights);
	}

	/**
	 * Calculate output error.
	 */
	private void outputErr() {
		double errSum = 0;
		for (int idx = 1, len = optDelta.length; idx != len; ++idx) {
			double o = output[idx];
			optDelta[idx] = o * (1d - o) * (target[idx] - o);
			errSum += Math.abs(optDelta[idx]);
		}
		optErrSum = errSum;
	}

	/**
	 * Calculate hidden errors.
	 */
	private void hiddenErr() {
		double errSum = 0;
		for (int j = 1, len = hidDelta.length; j != len; ++j) {
			double o = hidden[j];
			double sum = 0;
			for (int k = 1, len2 = optDelta.length; k != len2; ++k)
				sum += hidOptWeights[j][k] * optDelta[k];
			hidDelta[j] = o * (1d - o) * sum;
			errSum += Math.abs(hidDelta[j]);
		}
		hidErrSum = errSum;
	}

	/**
	 * Calculate errors of all layers.
	 */
	private void calculateDelta() {
		outputErr();
		hiddenErr();
	}

	/**
	 * Adjust the weight matrix.
	 * 
	 * @param delta
	 * @param layer
	 * @param weight
	 * @param prevWeight
	 */
	private void adjustWeight(double[] delta, double[] layer,
			double[][] weight, double[][] prevWeight) {

		layer[0] = 1;
		for (int i = 1, len = delta.length; i != len; ++i) {
			for (int j = 0, len2 = layer.length; j != len2; ++j) {
				double newVal = momentum * prevWeight[j][i] + eta * delta[i]
						* layer[j];
				weight[j][i] += newVal;
				prevWeight[j][i] = newVal;
			}
		}
	}

	/**
	 * Adjust all weight matrices.
	 */
	private void adjustWeight() {
		adjustWeight(optDelta, hidden, hidOptWeights, hidOptPrevUptWeights);
		adjustWeight(hidDelta, input, iptHidWeights, iptHidPrevUptWeights);
	}

	/**
	 * Sigmoid.
	 * 
	 * @param val
	 * @return
	 */
	private double sigmoid(double val) {
		return 1d / (1d + Math.exp(-val));
	}
}

 为了验证正确性,我写了一个测试用例,目的是对于任意的整数(int型),BPNN在经过训练之后,能够准确地判断出它是奇数还是偶数,正数还是负数。首先对于训练的样本(是随机生成的数字),将它转化为一个32位的向量,向量的每个分量就是其二进制形式对应的位上的0或1。将目标输出视作一个4维的向量,[1,0,0,0]代表正奇数,[0,1,0,0]代表正偶数,[0,0,1,0]代表负奇数,[0,0,0,1]代表负偶数。

训练样本为1000个,学习200次。

 

package ml;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

public class Test {

	/**
	 * @param args
	 * @throws IOException
	 */
	public static void main(String[] args) throws IOException {
		BP bp = new BP(32, 15, 4);

		Random random = new Random();
		List<Integer> list = new ArrayList<Integer>();
		for (int i = 0; i != 1000; i++) {
			int value = random.nextInt();
			list.add(value);
		}

		for (int i = 0; i != 200; i++) {
			for (int value : list) {
				double[] real = new double[4];
				if (value >= 0)
					if ((value & 1) == 1)
						real[0] = 1;
					else
						real[1] = 1;
				else if ((value & 1) == 1)
					real[2] = 1;
				else
					real[3] = 1;
				double[] binary = new double[32];
				int index = 31;
				do {
					binary[index--] = (value & 1);
					value >>>= 1;
				} while (value != 0);

				bp.train(binary, real);
			}
		}

		System.out.println("训练完毕,下面请输入一个任意数字,神经网络将自动判断它是正数还是复数,奇数还是偶数。");

		while (true) {
			byte[] input = new byte[10];
			System.in.read(input);
			Integer value = Integer.parseInt(new String(input).trim());
			int rawVal = value;
			double[] binary = new double[32];
			int index = 31;
			do {
				binary[index--] = (value & 1);
				value >>>= 1;
			} while (value != 0);

			double[] result = bp.test(binary);

			double max = -Integer.MIN_VALUE;
			int idx = -1;

			for (int i = 0; i != result.length; i++) {
				if (result[i] > max) {
					max = result[i];
					idx = i;
				}
			}

			switch (idx) {
			case 0:
				System.out.format("%d是一个正奇数\n", rawVal);
				break;
			case 1:
				System.out.format("%d是一个正偶数\n", rawVal);
				break;
			case 2:
				System.out.format("%d是一个负奇数\n", rawVal);
				break;
			case 3:
				System.out.format("%d是一个负偶数\n", rawVal);
				break;
			}
		}
	}

}

 运行结果截图如下:



 这个测试的例子非常简单。大家可以根据自己的需要去使用BP这个类。

  • 大小: 15.9 KB
4
0
分享到:
评论
17 楼 zhc0822 2012-10-28  
fantaosong 写道
fantaosong 写道
zhc0822 写道
wondery 写道
fantaosong 写道
对算法中的两个函数不解,outputErr();hiddenErr();他们是不是做一个错误反馈处理?o * (1d - o) * (target[idx] - o),为什么这个处理要这样做?对forward中的for循环最后一句不解,为什么还要做一个处理呢?新手,望赐教!


for里面最后那句可以删掉,这两个函数的最后那句也可以删掉,没用上。

forward中for最后的语句是不能删除的亲。outputErr和hiddenErr最后设置的值可以用于打印查看,最好保留。

我想做一个计算得分的程序:我有几个参数,但是这几个参数在得分中的权重不确定,然后我想得到这些参数,用神经网络可以实现吗?

貌似forward中for最后的语句是不能删,但是如果数比较大的情况下都变成1了,我把你的程序改了一下用到我要解决的问题上就不行了,哎,晕死了!!!

用决策树就行了。
16 楼 fantaosong 2012-10-26  
fantaosong 写道
zhc0822 写道
wondery 写道
fantaosong 写道
对算法中的两个函数不解,outputErr();hiddenErr();他们是不是做一个错误反馈处理?o * (1d - o) * (target[idx] - o),为什么这个处理要这样做?对forward中的for循环最后一句不解,为什么还要做一个处理呢?新手,望赐教!


for里面最后那句可以删掉,这两个函数的最后那句也可以删掉,没用上。

forward中for最后的语句是不能删除的亲。outputErr和hiddenErr最后设置的值可以用于打印查看,最好保留。

我想做一个计算得分的程序:我有几个参数,但是这几个参数在得分中的权重不确定,然后我想得到这些参数,用神经网络可以实现吗?

貌似forward中for最后的语句是不能删,但是如果数比较大的情况下都变成1了,我把你的程序改了一下用到我要解决的问题上就不行了,哎,晕死了!!!
15 楼 fantaosong 2012-10-26  
zhc0822 写道
wondery 写道
fantaosong 写道
对算法中的两个函数不解,outputErr();hiddenErr();他们是不是做一个错误反馈处理?o * (1d - o) * (target[idx] - o),为什么这个处理要这样做?对forward中的for循环最后一句不解,为什么还要做一个处理呢?新手,望赐教!


for里面最后那句可以删掉,这两个函数的最后那句也可以删掉,没用上。

forward中for最后的语句是不能删除的亲。outputErr和hiddenErr最后设置的值可以用于打印查看,最好保留。

我想做一个计算得分的程序:我有几个参数,但是这几个参数在得分中的权重不确定,然后我想得到这些参数,用神经网络可以实现吗?
14 楼 zhc0822 2012-10-23  
wondery 写道
fantaosong 写道
对算法中的两个函数不解,outputErr();hiddenErr();他们是不是做一个错误反馈处理?o * (1d - o) * (target[idx] - o),为什么这个处理要这样做?对forward中的for循环最后一句不解,为什么还要做一个处理呢?新手,望赐教!


for里面最后那句可以删掉,这两个函数的最后那句也可以删掉,没用上。

forward中for最后的语句是不能删除的亲。outputErr和hiddenErr最后设置的值可以用于打印查看,最好保留。
13 楼 wondery 2012-10-23  
fantaosong 写道
对算法中的两个函数不解,outputErr();hiddenErr();他们是不是做一个错误反馈处理?o * (1d - o) * (target[idx] - o),为什么这个处理要这样做?对forward中的for循环最后一句不解,为什么还要做一个处理呢?新手,望赐教!


for里面最后那句可以删掉,这两个函数的最后那句也可以删掉,没用上。
12 楼 zhc0822 2012-10-22  
fantaosong 写道
对算法中的两个函数不解,outputErr();hiddenErr();他们是不是做一个错误反馈处理?o * (1d - o) * (target[idx] - o),为什么这个处理要这样做?对forward中的for循环最后一句不解,为什么还要做一个处理呢?新手,望赐教!

建议看一下BP神经网络的原理。
11 楼 fantaosong 2012-10-22  
对算法中的两个函数不解,outputErr();hiddenErr();他们是不是做一个错误反馈处理?o * (1d - o) * (target[idx] - o),为什么这个处理要这样做?对forward中的for循环最后一句不解,为什么还要做一个处理呢?新手,望赐教!
10 楼 独爱Java 2012-05-08  
楼主有没有对Hopfield神经网络的java实现代码呢。嘿嘿
9 楼 zhc0822 2012-04-15  
yangliuy 写道
在不?你在帮女朋友做文本聚类那个作业不?特征向量矩阵的降维你用什么算法做的?选了多少个特征词?

他自己做的。
8 楼 yangliuy 2012-04-13  
在不?你在帮女朋友做文本聚类那个作业不?特征向量矩阵的降维你用什么算法做的?选了多少个特征词?
7 楼 zhc0822 2012-03-30  
yangliuy 写道
yangliuy 写道
zhc0822 写道
yangliuy 写道
zhc0822 写道
yangliuy 写道
是RenaQiu吗?这个bp程序写的不错啊,赞,我在itye看到两个你的博客,http://fantasticinblur.iteye.com/和http://renaqiu.iteye.com/都是你的博客?

renaqiu.iteye是我女朋友的

呵呵,我猜也是,帮女朋友做作业,你这BF当得太称职了。你现在也在软微吗?
BTW 我女朋友也是武大的

我在武大...她在北软...你和她是同学吗?

对啊,我也在北软,你这个BP程序做newsgroup文本分类好用吗?好像准确率有点低啊

在吗?用贝叶斯和KNN都可以做到80%以上,你这个程序的分类准确率只有40%多吧,是预处理的问题还是BP本身不适合文本分类?

BP已经是一种比较古老的模型了。神经网络这十几年来发展得很一般,应用领域有限。
试试KNN和SVM。Bayes的效果这么好倒是出乎我的意料。
刚才吃午饭去了。
6 楼 yangliuy 2012-03-30  
yangliuy 写道
zhc0822 写道
yangliuy 写道
zhc0822 写道
yangliuy 写道
是RenaQiu吗?这个bp程序写的不错啊,赞,我在itye看到两个你的博客,http://fantasticinblur.iteye.com/和http://renaqiu.iteye.com/都是你的博客?

renaqiu.iteye是我女朋友的

呵呵,我猜也是,帮女朋友做作业,你这BF当得太称职了。你现在也在软微吗?
BTW 我女朋友也是武大的

我在武大...她在北软...你和她是同学吗?

对啊,我也在北软,你这个BP程序做newsgroup文本分类好用吗?好像准确率有点低啊

在吗?用贝叶斯和KNN都可以做到80%以上,你这个程序的分类准确率只有40%多吧,是预处理的问题还是BP本身不适合文本分类?
5 楼 yangliuy 2012-03-30  
zhc0822 写道
yangliuy 写道
zhc0822 写道
yangliuy 写道
是RenaQiu吗?这个bp程序写的不错啊,赞,我在itye看到两个你的博客,http://fantasticinblur.iteye.com/和http://renaqiu.iteye.com/都是你的博客?

renaqiu.iteye是我女朋友的

呵呵,我猜也是,帮女朋友做作业,你这BF当得太称职了。你现在也在软微吗?
BTW 我女朋友也是武大的

我在武大...她在北软...你和她是同学吗?

对啊,我也在北软,你这个BP程序做newsgroup文本分类好用吗?好像准确率有点低啊
4 楼 zhc0822 2012-03-30  
yangliuy 写道
zhc0822 写道
yangliuy 写道
是RenaQiu吗?这个bp程序写的不错啊,赞,我在itye看到两个你的博客,http://fantasticinblur.iteye.com/和http://renaqiu.iteye.com/都是你的博客?

renaqiu.iteye是我女朋友的

呵呵,我猜也是,帮女朋友做作业,你这BF当得太称职了。你现在也在软微吗?
BTW 我女朋友也是武大的

我在武大...她在北软...你和她是同学吗?
3 楼 yangliuy 2012-03-30  
zhc0822 写道
yangliuy 写道
是RenaQiu吗?这个bp程序写的不错啊,赞,我在itye看到两个你的博客,http://fantasticinblur.iteye.com/和http://renaqiu.iteye.com/都是你的博客?

renaqiu.iteye是我女朋友的

呵呵,我猜也是,帮女朋友做作业,你这BF当得太称职了。你现在也在软微吗?
BTW 我女朋友也是武大的
2 楼 zhc0822 2012-03-30  
yangliuy 写道
是RenaQiu吗?这个bp程序写的不错啊,赞,我在itye看到两个你的博客,http://fantasticinblur.iteye.com/和http://renaqiu.iteye.com/都是你的博客?

renaqiu.iteye是我女朋友的
1 楼 yangliuy 2012-03-30  
是RenaQiu吗?这个bp程序写的不错啊,赞,我在itye看到两个你的博客,http://fantasticinblur.iteye.com/和http://renaqiu.iteye.com/都是你的博客?

相关推荐

Global site tag (gtag.js) - Google Analytics