`
Marshal_R
  • 浏览: 130185 次
  • 性别: Icon_minigender_1
  • 来自: 广州
社区版块
存档分类
最新评论

人工智能应用实例:数字识别(神经网络)

阅读更多

人工智能应用实例:数字识别(神经网络)

 

 

场景设置

    现提供一大堆8*8像素的图片,每张图片上面都是一个手写体的数字0-9,要求通过神经网络的方法,以其中一部分图片作为训练集生成一个数字识别的智能系统,以剩下的图片检验数字识别系统的准确性。

 

    数据集已经过预处理程序从手写体数字0-9提取位图,分为训练集datatra.txt和测试集datatest.txt,文件格式:每一行数据为一条记录,表示一张图片,共有65个属性,前64个属性为该图片的位图特征向量(每个属性取值范围都是0-16),最后一个属性为该图片上面的数字值0-9。

 

 

神经网络

    1、基本介绍

    神经网络是人工智能的一个重要分支,神经网络的模型是基于神经元的,每个神经元就是一个函数,把输入信号映射成特定的一个输出,然后这个输出信号又可以作为下一个神经元的输入信号,这样便形成一张错综复杂的神经网络。


图1  神经元

 

    举个简单的例子,假设只有一个神经元,它接受3个信号输入,实现的函数功能为对所有输入信号求和,而实际上神经元对每个输入信号有一个权重,意义在于说每个输入信号对神经元的刺激程度是不一样的,不能同等对待。假设神经元对这3个输入的权重为(1, 3, 2),那么当输入为(4, 8, 5)的时候,输出为1*4+3*8+2*5=38。

 

    一般而言,要从给定的知识库训练一个神经网络系统,神经元的函数功能是已知的,目标就是学习一组最佳的对每个输入信号的权重。

 

   2、 线性回归

    下面谈一下大家耳熟能详的线性回归,在平面上有一堆散列的点,目标是求一条最大程度拟合这些点的直线。很显然地,这也可以看作一个单神经元的神经网络模型,它有两个输入:x和1,两者权重为w1、w0,执行的函数功能为g(x)=x,因此模型表达式如下:

 

    其中W=(w0, w1)是各个输入信号的权重向量,我们的目标就是求得使直线拟合程度最大的W。按照以前数学教的方法,我们定义了平方差和为误差衡量标准,然后通过对W的各个分量求导就能算出最佳的W,回顾一下下面几条式子吧!



 

    以上是数学证明了最佳的W的求解方法,但对计算机而言,这样求解是不科学的,因为实际问题往往复杂得多,并不能通过数学演算直接求得最终解。不要忘了计算机最喜欢做大量的运算了,上面求导的思想是很有价值的,开始我们只需要任意选定一组参数作为权重向量W的值,然后通过某种算法不停更新W的值,最终W就能收敛到最佳的W,更新算法如下:


图2 权重更新算法

 

     因此,在线性回归这个场景下,我们得到

 

   3、 硬分类器(with a hard threshold)

    只要稍微修改一下线性回归的模型,我们就可以设计出一个线性分类器。


图3 线性分类器
 

    比如上图所示,可以通过一条直线把二维平面的散列点分成两类,假定左上部分为分类1,右下为分类0,我们还是采取线性回归得到的直线方程hw(x)=w1*x+w0,令g(x,y)=hw(x)-y=w1*x+w0-y,这样把分类1的点代进方程g(x,y)将得到一个>0的值,而分类0将得到<0的值,因此通过g(x0,y0)的正负便能界定点(x0,y0)属于分类1还是分类0。还记得线性回归神经元的输入为X=(x,1),权重向量为W=(w1,w0),函数功能为g(x)=x吧,在这里我们要对神经元做一下修改,输入为X=(x,1,y),权重向量为W=(w1,w0,-1),模型表达式如下


 

    显然,hw(X0)=1表示X0对应的点属于分类1,hw(X0)=0表示X0对应的点不属于分类1,因此,模型的函数功能相当于判断给定的一个输入是否属于某一个分类。这里尽管模型的hw(x)表达式发生了变化,但是图2的权重更新算法仍然是适用的,W的每个分量更新表达式如下

 

    注意:式子中的y并不是点的y坐标,这里y的取值只有1和0,对每条记录都更新W一次,1表示该条记录属于给定的分类,0则不属于。

 

    4、软分类器(with a soft threshold)

    上面的分类器在人工智能里面被认为是with a hard threshold的,直译是硬门槛,其实就是hw(X)要么是1要么是0,要么属于这个分类要么不属于这个分类的意思,这么绝对的判断是不好的,应该避免,人工智能里面最常见的做法便是把它转换为概率,即属于这个分类的概率。

 

    把一个无穷范围的实数转换为0-1之间的小数的做法是

 

    因此,在硬分类器的基础上,再做进一步修改,最终得到的软分类器模型如下

 

    与之对应的权重更新表达式调整为

 

数字识别

    介绍完神经网络大概的情况,终于要步入正题了,现在让我们来看一下数字识别与神经网络的联系。

 

    最容易而实际上也是不可行的想法是,我们的数字识别系统就是一个单神经元的神经网络模型,它有64个输入,输出为0-9之间的整数,表示一条记录属于哪一个数字分类。然而,正如我前边提到的,这么绝对的判断是不科学的,应该基于概率进行分类。

 

    因此,数字识别的神经网络应该由10个独立的神经元组成,它们具有相同的64个信号输入,而输出则各不相同,神经元分别编号0-9,编号i的神经元的输出表示该条记录属于数字分类i的概率,即图片上的数字为i的概率。于是,把同一张图片输入到这10个神经元,我们有理由相信,使得输出值最大的神经元的编号就是这张图片上面的数字值。

 

    我们仍然沿用上面的软分类器模型和权重更新表达式,对每条记录(即一张图片)更新一次W的值,其中权重更新表达式中的y表示该条记录属于给定分类的概率,因为每条记录的分类已经是确定的,要么属于该给定分类,要么不属于,因此y的取值只有1.0和0.0。另外值得注意的是,hw(x)表达式中,W*X的绝对值是一个不小的数,容易使hw(X)只取到接近0.0和1.0的值,为了避免这个问题,可以让W*X再除以一个固定的因子,比如360。

 

 

代码实现

/*
 * =========================================================================
 *
 *       Filename:  main.c
 *
 *    Description:  Neural Network
 *
 *        Version:  1.0
 *        Created:  2014年12月20日 21时28分12秒
 *       Revision:  none
 *       Compiler:  gcc
 *
 *         Author:  阮仕海
 *   Organization:  AI 选修班第8组
 *
 * =========================================================================
 */

#define ALPHA 1.0
#define FACTOR 360
#define ATTR_COUNT 64
#define TIMES 36
#define FILE_TRA "digitstra.txt"
#define FILE_TEST "digitstest.txt"

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>

int main() {
	FILE *fp_tra, *fp_test;
	int times;
	int digit, x[ATTR_COUNT]; // 每组数据对应的数字值以及位图属性数组
	int tra_size = 0, test_size = 0, hit_size = 0;
	double y[10]; // 数字类(0-9)数组,y[i]=1.0当且仅当i==digit,否则y[i]=0.0
	double weight[10][ATTR_COUNT] = {0.0}; // 权重数组(对应数字分类0-9)
	char line[256]; // 缓冲区读取文件一行
	const char delim[] = ","; // 文件数据分隔符

	// 打开训练数据文件
	if ((fp_tra = fopen(FILE_TRA, "r")) == NULL)
		exit(1);
	times = TIMES;
	// 重复读训练数据文件,对权重进行times次迭代更新
	while (times--) {
		// 文件指针指向文件开始处
		rewind(fp_tra);
		fgets(line, 255, fp_tra);
		while (!feof(fp_tra)) {
			int i;

			if (times == 0)
				tra_size++;
			// 将文件一行处理为一组训练数据
			x[0] = atoi(strtok(line, delim));
			for (i=1; i<ATTR_COUNT; i++) {
				x[i] = atoi(strtok(NULL, delim));
			}
			digit = atoi(strtok(NULL, delim));
			for (i=0; i<10; i++) {
				y[i] = 0.0;
			}
			y[digit] = 1.0;

			double hw_x, wx;
			int dig;
			// 分别对数字分组0-9更新权重
			for (dig=0; dig<10; dig++) {
				// 计算hw(x)=1/(1+e^(-W*X))
				wx = 0.0;
				for (i=0; i<ATTR_COUNT; i++) {
					wx += weight[dig][i]*x[i];
				}
				wx /= FACTOR;
				hw_x = (double)1/(1+exp(-wx));

				// 更新权重wi=wi+alpha*(y-hw(x))*hw(x)*(1-hw(x))*xi
				for (i=0; i<ATTR_COUNT; i++) {
					weight[dig][i] += ALPHA*(y[dig]-hw_x)*hw_x*(1-hw_x)*x[i];
				}
			}

			// 读取文件下一行
			fgets(line, 255, fp_tra);
		}
	}
	fclose(fp_tra);

	/*
	// 打印权重(代码测试用)
	int j, k;
	for (j=0; j<10; j++) {
		printf("For digit %d\n: ", j);
		for (k=0; k<63; k++) {
			printf("%.1lf, ", weight[j][k]);
		}
		printf("%.1lf\n\n", weight[j][k]);
	}
	*/

	// 打开测试数据文件
	if ((fp_test = fopen(FILE_TEST, "r")) == NULL)
		exit(1);
	fgets(line, 255, fp_test);
	while (!feof(fp_test)) {
		int i;

		test_size++;
		// 将文件一行处理为一组测试数据
		x[0] = atoi(strtok(line, delim));
		for (i=1; i<ATTR_COUNT; i++) {
			x[i] = atoi(strtok(NULL, delim));
		}
		digit = atoi(strtok(NULL, delim));

		double hw_x, wx;
		int dig, target;
		double max_posibility = -1.0;
		// 分别对数字分组0-9计算概率
		for (dig=0; dig<10; dig++) {
			// 计算概率hw(x)=1/(1+e^(-W*X))
			wx = 0.0;
			for (i=0; i<ATTR_COUNT; i++) {
				wx += weight[dig][i]*x[i];
			}
			wx /= FACTOR;
			hw_x = (double)1/(1+exp(-wx));

			// 更新目标数字值
			if (hw_x > max_posibility) {
				target = dig;
				max_posibility = hw_x;
			}
		}

		// 数字识别正确
		if (target == digit)
			hit_size++;

		// 读取文件下一行
		fgets(line, 255, fp_tra);
	}
	fclose(fp_test);

	// 打印结果
	printf("Training data size: %d\n", tra_size);
	printf("    Test data size: %d\n", test_size);
	printf("          Hit size: %d\n", hit_size);
	printf("         Hit ratio: %.2lf%%\n", (double)hit_size/test_size*100);

	return 0;
}

 

     输出结果如下:
   
 

 

 

  • 大小: 53.5 KB
  • 大小: 3 KB
  • 大小: 8.7 KB
  • 大小: 4.8 KB
  • 大小: 16.2 KB
  • 大小: 10.9 KB
  • 大小: 18.5 KB
  • 大小: 9 KB
  • 大小: 6.1 KB
  • 大小: 5.8 KB
  • 大小: 4.2 KB
  • 大小: 4.8 KB
  • 大小: 4.1 KB
  • 大小: 5.2 KB
分享到:
评论

相关推荐

Global site tag (gtag.js) - Google Analytics