BP神经网络实现(Java代码)
- 格式:doc
- 大小:104.00 KB
- 文档页数:8
BP神经网络实现(Java代码)
2012-01-04 16:16 282人阅读评论(1) 收藏举报
神经网络的原理虽然理解起来不难,但是要是想实现它,还是需要做一些工作的,并且有很多细节性的东西需要注意。通过参阅各种相关资料,以及参考网络上已有的资源,自己写了一个含有一个隐含层,且只能有一个输出单元的简单的BP网络,经过测试,达到了预期的效果。
需要说明的是,神经网络的每个输入都在[0,1]中,输出也在[0,1]中,在使用神经网络解决实际问题的时候,还需要对实际问题的输入输出进行归一化处理。另外,尽量不要使得神经网络的输入或输出接近于0或1,这样会影响拟合效果。
我用正弦函数进行了一次测试,效果如图所示:
以下是相关的代码:
1.神经网络代码
[java]view plaincopy
1.package pkg1;
2.
3.import java.util.Scanner;
4.
5./*
6. *
7. */
8.public class TestNeuro {
9.
10.private int INPUT_DIM=1;
11.private int HIDDEN_DIM=20;
12.private double LEARNING_RATE=0.05;
13.double [][] input_hidden_weights=new double[INPUT_DIM][HIDDEN_DIM];
14.double [] hidden_output_weights=new double[HIDDEN_DIM];
15.double[] hidden_thresholds=new double[HIDDEN_DIM];
16.double output_threshold;
17.
18.public static void main(String[]args){
19. Scanner in=new Scanner(System.in);
20. TestNeuro neuro=new TestNeuro(1,5);
21. neuro.initialize();
22.for(int i=0;i<10000;i++){
23.double[] input=new double[1];
24. input[0]=Math.random();
25.double expectedOutput=input[0]*input[0];
26.//System.out.println("input : "+input[0]+"\t\texpectedOutput :
"+expectedOutput);
27.//System.out.println("predict before training : "+neuro.predict
(input));
28. neuro.trainOnce(input, expectedOutput);
29.//System.out.println("predict after training : "+neuro.predict(
input));
30.//in.next();
31. }
32.while(true){
33.//neuro.printLinks();
34.double[] input=new double[1];
35. input[0]=in.nextDouble();
36.double expectedOutput=in.nextDouble();
37. System.out.println("predict before training : "+neuro.predict(i
nput));
38. neuro.trainOnce(input, expectedOutput);
39. System.out.println("predict after training : "+neuro.predict(in
put));
40.
41. }
42. }
43.
44.public TestNeuro(int input_dimension,int hidden_dimension){
45.this.INPUT_DIM=input_dimension;
46.this.HIDDEN_DIM=hidden_dimension;
47.this.initialize();
48. }
49.
50.
51./**
52. * 打印出本神经元网络各层之间的连接权重,以及各个神经元上的阈值的信息。
53. */
54.void print(){
55. System.out.println("隐含层阈值:");
56.for(int i=0;i 57. System.out.print(hidden_thresholds[i]+" "); 58. }System.out.println(); 59. System.out.println("输出层阈值:"); 60. System.out.println(output_threshold); 61. 62. System.out.println("连接权重:*********************"); 63. System.out.println("输入层与隐含层的连接"); 64.for(int i=0;i 65.for(int j=0;j 66. System.out.print(input_hidden_weights[i][j]+" "); 67. }System.out.println(); 68. } 69. System.out.println("隐含层到输出层的连接"); 70.for(int i=0;i 71. System.out.print(hidden_output_weights[i]+" "); 72. }System.out.println(); 73. System.out.println("*********************************"); 74. } 75. 76./** 77. * 初始化,对所有的权值产生一个(0,1)之间的随机double型值 78. */ 79.void initialize(){ 80. 81.//输入层到隐含层的连接权重 82.for(int i=0;i 83.for(int j=0;j 84. input_hidden_weights[i][j]=Math.random(); 85. } 86. } 87.//隐含层到输出层的连接权重 88.for(int i=0;i 89. hidden_output_weights[i]=Math.random(); 90. } 91.//隐含层的阈值