`
wbj0110
  • 浏览: 1551685 次
  • 性别: Icon_minigender_1
  • 来自: 上海
文章分类
社区版块
存档分类
最新评论

Python 神经网络调教程序

阅读更多
01 import random
02 import math
03  
04 from pyneurgen.neuralnet import NeuralNet
05 from pyneurgen.nodes import BiasNode, Connection
06  
# Target curve: pop_len samples of sin(x), where the integer position is
# scaled into [0, 1) by `factor`.  Each sample is a (position, value) tuple.
pop_len = 360
factor = 1.0 / float(pop_len)
population = [(pos, math.sin(pos * factor)) for pos in range(pop_len)]

# Parallel lists of network input rows and target rows; both start empty.
all_inputs, all_targets = [], []
15  
def population_gen(population):
    """Yield the items of *population* in a random order.

    The items are copied into a new list before shuffling, so the caller's
    sequence is never reordered.  Shuffling uses the module-level ``random``
    generator and, because this is a generator function, happens lazily on
    the first ``next()`` call.

    :param population: any finite iterable of samples.
    :returns: a generator over a shuffled copy of those samples.
    """
    shuffled = list(population)  # private copy; leaves the caller's data intact
    random.shuffle(shuffled)
    for item in shuffled:
        yield item
21  
# Build the training rows, visiting the samples in shuffled order.
# Each input row is [uniform noise in [0, 1), position scaled by `factor`];
# the matching target row is the one-element list [sin value].
# NOTE(review): the first input is random noise unrelated to the target --
# presumably deliberate (the net should learn to ignore it); confirm.
for sample_pos, sample_val in population_gen(population):
    all_inputs.append([random.random(), float(sample_pos) * factor])
    all_targets.append([sample_val])
27  
# Network: 2 inputs, one hidden layer of 10 nodes, 1 output.
net = NeuralNet()
net.init_layers(2, [10], 1)
# Draw the initial weights once; calling randomize_network() again would
# presumably just redraw them, so a single call is enough.
net.randomize_network()
net.learnrate = .20

net.set_all_inputs(all_inputs)
net.set_all_targets(all_targets)
length = len(all_inputs)

# The first 80% of the (already shuffled) rows are used for learning and the
# remainder for testing.  The ranges are given as (first, last) positions;
# the +1 / -1 below imply both ends are inclusive -- TODO confirm in pyneurgen.
learn_end_point = int(length * .8)
net.set_learn_range(0, learn_end_point)
net.set_test_range(learn_end_point + 1, length - 1)

# tanh activation for layers[1] (presumably the hidden layer; layers[0]
# being the input layer -- confirm).
net.layers[1].set_activation_type('tanh')
net.learn(epochs=125, show_epoch_results=True, random_testing=False)
# Result of evaluating the test range; the name suggests mean squared error.
mse = net.test()
44  
45 import matplotlib
46 from pylab import plot, legend, subplot, grid
47 from pylab import xlabel, ylabel, show, title
48  
# x-axis positions of the held-out samples.  The second input of each row is
# position * factor with factor == 1 / pop_len, so multiplying by pop_len
# recovers the integer position.  This keeps the axis aligned with the sample
# index used in the "Population" panel (a fixed 1000.0 would stretch it by
# 1000 / pop_len).
# NOTE(review): assumes get_test_data() yields (inputs, targets) pairs, as the
# item[0][1] indexing implies -- confirm against pyneurgen.
test_positions = [item[0][1] * pop_len for item in net.get_test_data()]

# NOTE(review): the indexing below treats item[0] as the target row and
# item[1] as the network-output row, despite the attribute being named
# test_actuals_targets -- confirm the order against pyneurgen.
all_targets1 = [item[0][0] for item in net.test_actuals_targets]
allactuals = [item[1][0] for item in net.test_actuals_targets]

#   This is quick and dirty, but it will show the results
# Panel 1: the full sine curve the network is trying to learn.
subplot(3, 1, 1)
plot([sample[1] for sample in population])
title("Population")
grid(True)

# Panel 2: targets vs. network outputs over the test range.
subplot(3, 1, 2)
plot(test_positions, all_targets1, 'bo', label='targets')
plot(test_positions, allactuals, 'ro', label='actuals')
grid(True)
legend(loc='lower left', numpoints=1)
title("Test Target Points vs Actual Points")

# Panel 3: the accumulated MSE per learning epoch (epochs numbered from 1).
subplot(3, 1, 3)
plot(range(1, len(net.accum_mse) + 1), net.accum_mse)
xlabel('epochs')
ylabel('mean squared error')
grid(True)
title("Mean Squared Error by Epoch")

show()
分享到:
评论

相关推荐

Global site tag (gtag.js) - Google Analytics