机器学习之 神经网络的实现(二)-->手写识别
2017-04-11 11:11
656 查看
上一篇文章介绍了神经网络的一个简单的案例,帮助大家了解它的基本概念,神经网络的一个重要的优点就是它能够提取目标事物的特性,然后把这些事物全部用加权的方法来计算,并且得出结果,这种方式除了单纯模拟了人的脑神经工作方式以外,实际上还模拟了很多现实生活中的规则,(现实生活中几乎所有的指标都是受很多个因素影响,并且动态的来回波动的,神经网络很适合找出这其中的关系,)
这里我们来继续介绍神经网络的那些事。
这里我们来上手一个小的demo——手写识别,之前如果看过我的另一篇关于KNN算法下的手写识别演示,这里理解起来可能会容易一些:
话不多说,先上一段手写识别的说明,在KNN中我们借助二维的bool数组来充当模板(true表示笔划过的地方,false表示空白),这里我们也可以沿用这个方法,另外还有很多辅助的方法,
1.检查笔画断点(我们一次落笔,可以写多个笔画,因为连笔,但是一个笔画一定是一次落笔完成的,所以可以检查一个笔画中有没有断开的部分,比如说电脑判断“再”和“雨”)、
2.多次移动模板(按九宫格的排布移动比对9次)来确定精确度、
3.笔画数量的检查(这个可以借助鼠标监听器的按下和释放两个动作来实现,比如电脑判断“门”和“闩”)、
4.检查笔画的顺序(手写文字都有固定的笔画顺序,因此这也可以作为一种特点,被神经网络提取用来校准文字的识别效果和精确度,比如说让电脑来判断“木”和“水”)、
5.借助计算笔划的余弦值来记录笔画的走向和大概形状(这是一种区别于KNN的另一种实现方法,有兴趣的朋友可以自己编写代码验证一下效果啊),
![](https://img-blog.csdn.net/20170410220639463?watermark/2/text/aHR0cDovL2Jsb2cuY3Nkbi5uZXQvUGFpX2RheGluZw==/font/5a6L5L2T/fontsize/400/fill/I0JBQkFCMA==/dissolve/70/gravity/Center)
此外还有很多其他的辅助方法,这里就不再列举(其实很多精炼的方法,我也还不知道,有哪个朋友精通这个的话,欢迎留言讨论,,)
这里说完神经网络实现手写识别的基本思路,我们来看看具体实现,
方法太多,所以我就不一一列举实现代码了,这里仅仅以模板比对的方法为例来给出一个实现方法和相应的java代码。
上代码:(这里主要实现了对数字的识别,至于模板,大家可以自行设置,这里不做限制,)
以上代码以识别数字为例。(识别字母、或者汉字的方法同理。)
这里先解释一下,这里的各层节点的用处,
第二层节点:用来捕捉用户书写的图形的各种特点,当捕捉到相应的特点以后,会向第三层节点发出一个0~1之间的信号值(图形和相应的节点所示的特点吻合的越好,那么返回的这个值就越接近1)。比如说“0”这个图形,我们可把它分成很多个部分,(最简单的一种分法,从中间劈开,分成左半部分和右半部分,),其他字符也同理。
![](https://img-blog.csdn.net/20170411110751381?watermark/2/text/aHR0cDovL2Jsb2cuY3Nkbi5uZXQvUGFpX2RheGluZw==/font/5a6L5L2T/fontsize/400/fill/I0JBQkFCMA==/dissolve/70/gravity/Center)
第三层节点:表示一个固定的字符,用来把第二层节点传入的值进行加权累加(相应的意义在于把各个特定按照不同的重要度(即权值),加在一起,所加得到的结果越高,就说明用户的所写的这个图形符号和我们第三层相应节点的拟合度越高。),然后第三层节点会把这个加权累加的结果传到第四层节点处。
第四层节点:表示神经网络对图形的分析结果,第四层节点其实只需要定义一个节点就可以满足需要,这个节点将第三层节点传回的值进行排序(第n个节点传回的值越大,就说明这个图形和我们第三层的第n个节点对应的值越相近。所以从大到小排列这些结果,我们就得到了用户输入的图形最有可能表示的字符,第二有可能表示的字符,第三有可能表示的字符······)。
值得说明的是:因为我们这个程序的参数较多(各层自由参数一共有大约一千个左右。),所以我们在调试的时候可以尝试创建一个窗口,参数的值按照某一个顺序,排列出来,这样有助于我们找到出错的地方。
这里我们来继续介绍神经网络的那些事。
这里我们来上手一个小的demo——手写识别,之前如果看过我的另一篇关于KNN算法下的手写识别演示,这里理解起来可能会容易一些:
话不多说,先上一段手写识别的说明,在KNN中我们借助二维的bool数组来充当模板(true表示笔划过的地方,false表示空白),这里我们也可以沿用这个方法,另外还有很多辅助的方法,
1.检查笔画断点(我们一次落笔,可以写多个笔画,因为连笔,但是一个笔画一定是一次落笔完成的,所以可以检查一个笔画中有没有断开的部分,比如说电脑判断“再”和“雨”)、
2.多次移动模板(按九宫格的排布移动比对9次)来确定精确度、
3.笔画数量的检查(这个可以借助鼠标监听器的按下和释放两个动作来实现,比如电脑判断“门”和“闩”)、
4.检查笔画的顺序(手写文字都有固定的笔画顺序,因此这也可以作为一种特点,被神经网络提取用来校准文字的识别效果和精确度,比如说让电脑来判断“木”和“水”)、
5.借助计算笔划的余弦值来记录笔画的走向和大概形状(这是一种区别于KNN的另一种实现方法,有兴趣的朋友可以自己编写代码验证一下效果啊),
此外还有很多其他的辅助方法,这里就不再列举(其实很多精炼的方法,我也还不知道,有哪个朋友精通这个的话,欢迎留言讨论,,)
这里说完神经网络实现手写识别的基本思路,我们来看看具体实现,
方法太多,所以我就不一一列举实现代码了,这里仅仅以模板比对的方法为例来给出一个实现方法和相应的java代码。
上代码:(这里主要实现了对数字的识别,至于模板,大家可以自行设置,这里不做限制,)
package the_main; import java.awt.BorderLayout; import java.awt.Color; import java.awt.Dimension; import java.awt.Graphics; import java.awt.Label; import java.awt.event.ActionEvent; import java.awt.event.ActionListener; import java.io.BufferedReader; import java.io.FileReader; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import javax.swing.JButton; import javax.swing.JFrame; import javax.swing.JLabel; import javax.swing.JOptionPane; import javax.swing.JPanel; import javax.swing.JScrollPane; import javax.swing.JTextArea; import javax.swing.JTextField; import javax.swing.plaf.metal.MetalBorders.OptionDialogBorder; import 机器学习_手写识别4_原版.手写识别.test1; public class Begin { // 学习过程和测试过程要分开做: // 作者 李梦旭,如需转载,请注明, // 借助神经网络实现手写识别数字: // 对神经网络的新的理解:每一个节点都是一个表达式,而一个节点中的表达上可以有多个参数。 // 基本思路: 任何一个数字都可以被分解成几个固定的组成部分:比如说“3”可以被分成两个半圆,“6”可以被分成一个半圆和一个圆的组合, // 这里的神经网络设计思路是:第一层神经节点用来采集像素点,例如12*12的一张画板上,就是144个像素点,所以第一层有144个神经节点, // 第二层神经节点用来识别各种不同的笔画(包括笔画的位置),比如上文中所说的半圆,或者四分之一圆,或者一竖, // 第三层神经网络用来把第二层的识别结果拼合在一起,形成一个真正的字,这里注意:这里的字必须是事先定义出来,电脑已经知道的,否则电脑会认为这个字不存在, // 这里按识别圆润型数字为例子:中间层的节点对应关系: // 1.“8”的左上角(右倾),2.“8”的右上角(左倾),3.“8”的左下角,4.“8”的右下角(左倾),5.“1”的整个部分, // 6.“2”的下边的横线,7.“2”上面的部分,8.“4”中间的一横,9.“4”左上角的一斜杠,10.“5”上面的一横,11"5"左边的一竖, // 12.“7”下面的左斜线,13.“0”的左半部分,14.“0”的右半部分, // “1”----5 // “2”----6,7 // “3”----2,4 // “4”----8,9 或者8,11 // “5”----11,10,4 // “6”----13,4 // “7”----10,12 // “8”----1,2,3,4, // “9”----14,1 // FIXME 这里还可以加上图形的相连的检查节点,比如“2”的上半部分和下半部分之间有没有相连。没有的话,这个图形很可能就不是“2” // 定义输入层的144个节点: public static boolean[][] input = new boolean[32][32]; // public void set_input(boolean[][] input){ // this.input=input; // } // 中间层节点的参数队列 static ArrayList list_sec_node = new ArrayList<>(); // 第三层节点的参数队列; static ArrayList list_thr_node = new ArrayList<>(); static forth_node fo = new forth_node(); public Begin() { super(); // 定义中间层每一个节点最多有144个参数; // 0.“8”的左上角(右倾),1.“8”的右上角(左倾),2.“8”的左下角,3.“8”的右下角(左倾),4.“1”的整个部分, // 5.“2”的下边的横线,6.“2”上面的部分,7.“4”中间的一横,8.“4”左上角的一斜杠,9.“5”上面的一横,10"5"左边的一竖, // 11.“7”下面的左斜线,12.“0”的左半部分,13.“0”的右半部分, // 0 list_sec_node.add(new second_node(0, 16, 20, 31)); // 0>7,5>11 // 1 list_sec_node.add(new second_node(12, 16, 32, 31)); // // // 2 list_sec_node.add(new second_node(0, 0, 16, 16)); // // 3 list_sec_node.add(new second_node(16, 0, 32, 16)); // // 4 list_sec_node.add(new second_node(11, 3, 21, 30)); // 5>8,0>11 // // 5 list_sec_node.add(new second_node(2, 2, 30, 9)); // 1>10,0>3 // // 6 list_sec_node.add(new second_node(4, 4, 30, 30)); // 2>10,3>11 // 7 list_sec_node.add(new second_node(2, 12, 30, 20)); // 8 list_sec_node.add(new second_node(0, 16, 16, 31)); // 9 list_sec_node.add(new second_node(6, 25, 28, 31)); // 10 list_sec_node.add(new second_node(0, 16, 16, 31)); // 11 list_sec_node.add(new second_node(4, 4, 28, 25)); // 12 list_sec_node.add(new second_node(2, 2, 16, 30)); // 13 list_sec_node.add(new second_node(16, 2, 30, 30)); // 第三层节点 list_thr_node.add(new third_no 4000 de(new int[] { 12, 13 }, 0));// 表示数字“0” list_thr_node.add(new third_node(new int[] { 4 }, 1));// 表示数字“1” list_thr_node.add(new third_node(new int[] { 5, 6 }, 2)); list_thr_node.add(new third_node(new int[] { 2, 3 }, 3)); list_thr_node.add(new third_node(new int[] { 1, 4, 7, 8, 10 }, 4));// 表示数字“4”这里8和10任选一个, list_thr_node.add(new third_node(new int[] { 3, 9, 10 }, 5)); list_thr_node.add(new third_node(new int[] { 3, 12 }, 6));// 表示数字“6” list_thr_node.add(new third_node(new int[] { 9, 11 }, 7)); list_thr_node.add(new third_node(new int[] { 0, 1, 2, 3 }, 8));// 表示数字“8” list_thr_node.add(new third_node(new int[] { 0, 13 }, 9)); } double a = 0.01; // 这个函数是公用的误差函数,用来计算和真实结果的相差的值,(第一个参数表示第二层的第几个节点,m表示当前的图形是哪个数字) void lose(int m) { // 将y值和真实值比较,确定误差,并且反过来,调参数。 // 取出第三层节点的相关二层节点队列,只让这些节点学习这张图像; ArrayList seno = list_thr_node.get(m).list_se_tnode; for (int i = 0; i < seno.size(); i++) { seno.get(i).result_true = 1; double par[][] = seno.get(i).parameter; second_node se_nod = seno.get(i); // 计算当前的结果; se_nod.getresult(input); for (int k2 = se_nod.y2 - 1; k2 >= (se_nod.y1); k2--) { for (int k = se_nod.x1; k < (se_nod.x2); k++) { if (input[k][k2]) { // 如果识别成是一样的了(1),但本来是不一样的(0),说明有墨迹的像素点上的权值给的太高了,其余点的负权值给的太低了, par[k - se_nod.x1][k2 - se_nod.y1] -= a * (se_nod.result - se_nod.result_true); } else { par[k - se_nod.x1][k2 - se_nod.y1] += a * 0.2 * (se_nod.result - se_nod.result_true); }// 另外,如果识别是不一样的(0),但是实际上是一样的(1),说明有墨迹的像素点上的权值给的太低了,而其余点上给的权值太高了。 // 如果识别正确,则调整的量会很小 // 注意 任何一个点的权值都应在-1到+1之间, if (par[k - se_nod.x1][k2 - se_nod.y1] < -0.99) { par[k - se_nod.x1][k2 - se_nod.y1] = -0.95; } else if (par[k - se_nod.x1][k2 - se_nod.y1] > 0.99) { par[k - se_nod.x1][k2 - se_nod.y1] = 0.95; } } } } } void learning() { // 读出文件中的数字模型 for (int i1 = 0; i1 < 300; i1++) { for (int m = 0; m < 10; m++) { for (int n = 0; n < 10; n++) { try { // "m"表示当前的这个图形表示的数字 BufferedReader br = new BufferedReader(new FileReader( "model/" + m + "." + n + ".mode")); String s = ""; int i = 0; while ((s = br.readLine()) != null) { for (int j = 0; j < 32; j++) if (Integer.parseInt(s.charAt((j + 1) * 2 - 1) + "") == 1) { input[j][31 - i] = true; } else if (Integer.parseInt(s .charAt((j + 1) * 2 - 1) + "") == 0) { input[j][31 - i] = false; } else { System.err.println("文件识别出现错误"); } i++; } br.close(); } catch (IOException e1) { // TODO Auto-generated catch block e1.printStackTrace(); } // 开始训练: // 计算中间层的输出值y (y是一个-1到1之间的数), for (int j = 0; j < 2; j++) { lose(m); for (int k = 0; k < list_thr_node.size(); k++) { list_thr_node.get(k).learn(m); } } } } } } public static void check() { // 完成地二层节点的计算; for (int i = 0; i < list_sec_node.size(); i++) { list_sec_node.get(i).getresult(input); } // 完成第三层节点的计算 for (int i = 0; i < list_thr_node.size(); i++) { list_thr_node.get(i).getresult(); fo.third_price[i] = list_thr_node.get(i).price; System.out.println(fo.third_price[i]); } fo.getresult(); JOptionPane.showMessageDialog(null, Arrays.toString(fo.third_list)); } void showui() { final JFrame jf = Create_allframe_.create_jframe("绘图", 1000, 750, 0, 0, false, 3, true); JButton jb = new JButton("绘图"); jf.add(jb); JButton jb_small = new JButton("绘小图"); jf.add(jb_small); JButton jb_wei = new JButton("权重分布矩阵"); jf.add(jb_wei); JButton jb_wei_draw = new JButton("权重图像"); jf.add(jb_wei_draw); JButton jb_third_par = new JButton("第三层节点参数"); jf.add(jb_third_par); JPanel jp_wei_draw;// 这个面板用来绘制彩色图像。 final JTextField jtf = new JTextField(15); jf.add(jtf); JButton jb_show = new JButton("测试学习成果"); jf.add(jb_show); // 这个是保留的小数的位数 final int par_size = 5; final JTextArea jta = new JTextArea(32, 16 * par_size); jta.setBackground(Color.YELLOW); final JPanel jp = Create_allframe_.create_jpanel(jf, jf.getWidth() + 300, jf.getHeight() - 60);// 这个面板是公用的面板,用来绘制jb,jb_small,jb_wei, JScrollPane jsp = Create_allframe_.create_jscrollpane(jp, jf.getWidth() - 20, jf.getHeight() - 80);// 一个带滑轮的滚动面板; jf.add(jsp); jf.setVisible(false); jf.setVisible(true); jb.addActionListener(new ActionListener() { // 这个函数用于绘制图形模板里面的内容 public void actionPerformed(ActionEvent e) { // 绘制模板阵列: jp.removeAll(); jp.add(jta); jta.setText("这是一个“" + "0" + "”模板\n"); for (int i = 0; i < input.length; i++) { for (int j = 0; j < input[1].length; j++) { if (input[j][31 - i]) { // 先输出第三十一行 jta.append("1"); } else { jta.append("0"); } } jta.append("\n"); } } }); jb_small.addActionListener(new ActionListener() { // 这个函数用来绘制模板的一部分 public void actionPerformed(ActionEvent e) { jp.removeAll(); jp.add(jta); jp.setVisible(false); jp.setVisible(true); // 可修改参数 int x1 = 0, y1 = 16, x2 = 31, y2 = 31; jta.setText("(" + x1 + "," + y1 + "),(" + x2 + "," + y2 + ")\n"); for (int i = y2; i >= y1; i--) { for (int j = x1; j < x2; j++) { if (input[j][i]) { jta.append("1"); } else { jta.append("0"); } } jta.append("\n"); } } }); jb_wei.addActionListener(new ActionListener() { // 这个函数用来展示参数的情况 public void actionPerformed(ActionEvent e) { // 这里定义一个绘制权值图的方法 String str = jtf.getText(); // 获取输入栏中的数字 int sec_num = 0; jp.setVisible(false); jp.setVisible(true); try { sec_num = Integer.parseInt(str); } catch (Exception e2) { JOptionPane.showMessageDialog(jf, "请在输入栏中输入数字"); } jp.removeAll(); jp.add(jta); jp.setVisible(false); jp.setVisible(true); jta.setText("这是第" + sec_num + "个节点的权重\n 识别结果为result=" + list_sec_node.get(sec_num).result + "\n"); double[][] par = list_sec_node.get(sec_num).parameter; int x1 = list_sec_node.get(sec_num).x1; int y1 = list_sec_node.get(sec_num).y1; int x2 = list_sec_node.get(sec_num).x2; int y2 = list_sec_node.get(sec_num).y2; for (int j = input[1].length - 1; j >= 0; j--) { for (int i = 0; i < input.length; i++) { // i表示横坐标,j表示纵坐标 // 没有内容的地方补零; if (i < x1 || i >= x2 || j < y1 || j >= y2) { jta.append(((0 + ". ").substring(0, par_size) + " ")); } else { if (par[i - x1][j - y1] > 0) { jta.append(((par[i - x1][j - y1] + " ") .substring(0, par_size) + " ")); } else { jta.append(((par[i - x1][j - y1] + " ") .substring(0, par_size) + " ")); } } } jta.append("\n"); } } }); final JLabel jl_par = new JLabel(); jb_wei_draw.addActionListener(new ActionListener() { // 这个函数用来绘图; public void actionPerformed(ActionEvent e) { // 注意绘图是直接在面板上绘制 // 这里定义一个绘制权值图的方法 String str = jtf.getText(); // 获取输入栏中的数字 int sec_num = 0; try { sec_num = Integer.parseInt(str); } catch (Exception e2) { JOptionPane.showMessageDialog(jf, "请在输入栏中输入数字"); } Graphics g = jp.getGraphics(); jp.removeAll(); jp.update(g); jl_par.setText("这是第" + sec_num + "个节点的权重\n"); // 开始绘制 double[][] par = list_sec_node.get(sec_num).parameter; int x1 = list_sec_node.get(sec_num).x1; int y1 = list_sec_node.get(sec_num).y1; int x2 = list_sec_node.get(sec_num).x2; int y2 = list_sec_node.get(sec_num).y2; for (int j = input[1].length - 1; j >= 0; j--) { for (int i = 0; i < input.length; i++) { // i表示横坐标,j表示纵坐标 // 没有内容的地方补零; int co_base = 124; int co_red_blue; int jp_heigh = jp.getHeight() - 60; if (i < x1 || i >= x2 || j < y1 || j >= y2) { g.setColor(new Color(co_base, co_base + 80, co_base)); g.fillOval(i * 20 + 20, jp_heigh - j * 20, 15, 15); } else { if (par[i - x1][j - y1] > 0) { co_red_blue = (int) (par[i - x1][j - y1] * 120); g.setColor(new Color(co_base + co_red_blue, co_base, co_base)); g.fillOval(i * 20 + 20, jp_heigh - j * 20, 15, 15); } else { co_red_blue = Math.abs((int) (par[i - x1][j - y1] * 120)); g.setColor(new Color(co_base, co_base, co_base + co_red_blue)); g.fillOval(i * 20 + 20, jp_heigh - j * 20, 15, 15); } } } } } }); jb_third_par.addActionListener(new ActionListener() { // 这个函数用来展示第三层节点的参数内容; public void actionPerformed(ActionEvent e) { String str = jtf.getText(); // 获取输入栏中的数字 int sec_num = 0; try { sec_num = Integer.parseInt(str); } catch (Exception e2) { JOptionPane.showMessageDialog(jf, "请在输入栏中输入数字"); } Graphics g = jp.getGraphics(); jp.removeAll(); jp.add(jta); jp.update(g); jl_par.setText("这是第" + sec_num + "个节点的权重\n"); // 开始展示参数 jta.setText("这是第" + sec_num + "个节点的权重\n"); third_node node = list_thr_node.get(sec_num); for (int j = 0; j < node.parameter.length; j++) { jta.append(((node.parameter[j] + " ").substring(0, par_size) + " ")); } } }); jb_show.addActionListener(new ActionListener() { public void actionPerformed(ActionEvent e) { // 这里显示一个新窗口,用来给用户书写; test1 t = new test1(); t.showUI(); } }); } public static void main(String[] args) { Begin the = new Begin(); the.learning(); the.showui(); } } package the_main; import java.awt.BorderLayout; import java.awt.Color; import java.awt.Dimension; import java.awt.FlowLayout; import java.awt.Label; import java.awt.Toolkit; import javax.swing.JDialog; import javax.swing.JFrame; import javax.swing.JLabel; import javax.swing.JPanel; import javax.swing.JScrollPane; import javax.swing.WindowConstants; public class Create_allframe_ { // 注意添加函数是有两个参数 // 即frame.add(create_jscrollpane(), // BorderLayout.CENTER); static JScrollPane create_jscrollpane(JPanel jpanel, int jscrollpane_width, int jscrollpane_height) { // 注意,如果要向jscrollpane中添加组件,必须直接添加到jpanel中,不要用jscrollpane的.add()函数,没有用的 // 如果想同时返回多个属性,可以定义一个类,类中包含你所需要返回的所有属性,然后返回这个类的一个对象 Dimension screensize = Toolkit.getDefaultToolkit().getScreenSize();// 获取屏幕的宽度 if (jpanel == null) {// 如果不输入窗体大小,则显示默认大小 jpanel=new JPanel(); int jpanel_width = (int) (jscrollpane_width * 2); int jpanel_height = (int) (jscrollpane_height * 2); jpanel.setPreferredSize(new Dimension(jpanel_width, jpanel_height)); FlowLayout flowLayout = new FlowLayout(); jpanel.setLayout(flowLayout); } JScrollPane jScrollPane = new JScrollPane(jpanel); if (jscrollpane_width == -1) {// 如果不输入窗体大小,则显示默认大小 jscrollpane_width = (int) (screensize.getWidth() / 2 - 20); } if (jscrollpane_height == -1) { jscrollpane_height = (int) (screensize.getHeight() / 2 - 40); } jScrollPane.setPreferredSize(new Dimension(jscrollpane_width, jscrollpane_height)); return jScrollPane; } static JFrame create_jframe(String title, int width, int height, int setLocationx, int setLocationy, boolean resizeable, int exit_style, boolean setVisible) { JFrame jframe = new JFrame(); if (!title.equals("")) { jframe.setTitle(title); } jframe.setSize(width, height); Dimension screensize = Toolkit.getDefaultToolkit().getScreenSize();// 获取屏幕的宽度 if (width == -1) {// 如果不输入窗体大小,则显示默认大小 width = (int) (screensize.getWidth() / 2); } if (height == -1) { height = (int) (screensize.getHeight() / 2); } jframe.setSize(width, height); if (setLocationx == -1) {// 如果不输入位置,则显示默认位置 setLocationx = (int) (screensize.getWidth() / 4); } if (setLocationy == -1) { setLocationy = (int) (screensize.getHeight() / 4); } jframe.setLocation(setLocationx, setLocationy); FlowLayout flowlayout = new FlowLayout(); jframe.setLayout(flowlayout); jframe.setDefaultCloseOperation(exit_style); jframe.setResizable(resizeable); jframe.setVisible(setVisible); return jframe; } static JFrame create_jframe_simplify(String title, int exit_style, boolean setVisible) { JFrame jframe = new JFrame(); if (!title.equals("")) { jframe.setTitle(title); } Dimension screensize = Toolkit.getDefaultToolkit().getScreenSize();// 获取屏幕的长和宽 int height = (int) (screensize.getHeight() / 2); int width = (int) (screensize.getWidth() / 2); jframe.setSize(width, height); int setLocationx = (int) (screensize.getWidth() / 4); int setLocationy = (int) (screensize.getHeight() / 4); jframe.setLocation(setLocationx, setLocationy); FlowLayout flowlayout = new FlowLayout(); jframe.setLayout(flowlayout); jframe.setDefaultCloseOperation(exit_style); jframe.setResizable(false); jframe.setVisible(setVisible); return jframe; } static JPanel create_jpanel(JFrame jframe, int width, int height) { JPanel jpanel = new JPanel(); if (width == -1) { width = jframe.getWidth() - 20; } if (height == -1) { height = jframe.getHeight() - 20; } jpanel.setPreferredSize(new Dimension(width, height)); FlowLayout flowlayout = new FlowLayout(); jpanel.setLayout(flowlayout); jpanel.setVisible(true); jframe.add(jpanel); jpanel.setBackground(Color.WHITE); return jpanel; } static JDialog creat_jdialog(JFrame jframe, String title, int width, int height, String content, boolean setVisible) { jframe.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE); JDialog jdialog = new JDialog(jframe, title); jdialog.setSize(width, height); FlowLayout flowlayout = new FlowLayout(); jdialog.setLayout(flowlayout); jdialog.setDefaultCloseOperation(2); jdialog.setModal(true); if (width == -1) { width = 300; } if (height == -1) { height = 200; } jdialog.setSize(width, height); jdialog.setLocation(500, 150); Label l1 = new Label(content); jdialog.add(l1); jdialog.setVisible(setVisible); return jdialog; } static JDialog creat_jdialog_simpfily(JFrame jframe, String content) { jframe.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE); JDialog jdialog = new JDialog(jframe, "提示"); jdialog.setDefaultCloseOperation(2); jdialog.setModal(true); jdialog.setSize(300, 200); FlowLayo c34a ut flowlayout = new FlowLayout(); jdialog.setLayout(flowlayout); jdialog.setLocation(500, 150); Label l1 = new Label(content); jdialog.add(l1); jdialog.setVisible(true); return jdialog; } public static void main(String[] args) { JFrame jFrame = create_jframe_simplify("123", 3, true); JPanel jp = new JPanel(); jp.setPreferredSize(new Dimension(500,500)); JScrollPane jScro = create_jscrollpane(jp, -1, -1); jFrame.add(jScro); // JPanel jPanel = create_jpanel(jFrame, 300, 400); // jPanel.setBackground(Color.black); } } package the_main; import java.util.Arrays; public class forth_node { double[] third_price = new double[10]; // {9.0,8.0,7.0,6.0,5.0,4.0,3.0,2.0,1.0,0.0}; int[] third_list = new int[10]; int getmax_index() { int index = 0; for (int i = 0; i < third_price.length; i++) { if (third_price[i] > third_price[index]) { index = i; } } third_price[index] = -Double.MAX_VALUE; return index; } int[] getresult() { for (int i = 0; i < third_price.length; i++) { third_list[i] = getmax_index(); } return third_list; } // int n = 0; // // void set_Value(double d) { // third_price = d; // n++; // if (n == 10) { // n = 0; // } // } // public static void main(String[] args) { // forth_node fo = new forth_node(); // System.out.println((Arrays.toString(fo.getresult()))); // // } } package the_main; import org.junit.Test; public class second_node {// 数组下标方法:数组下标=(像素点坐标-x1,像素点坐标-y1) // 这个数组是第二层节点的参数 double[][] parameter; int x1, y1, x2, y2; public second_node(int x1, int y1, int x2, int y2) { this.x1 = x1; this.y1 = y1; this.x2 = x2; this.y2 = y2; parameter = new double[x2 - x1][y2 - y1]; } // 这里暂定为把节点的计算权重累加结果的值,都放在这个节点类中,把计算结果的值也放入这个对象中(对于每一个样本,这个值都会完全不同,), double price=0; // 这个函数用来计算输入的图形和第一种情况的相似度 double price(boolean[][] input) { int num=0; for (int i = x1; i < x2; i++) { for (int j = y1; j < y2; j++) { if (input[i][j]) {// 如果当前像素点上被笔画过,则加上权值 num++; price += parameter[i-x1][j-y1]; } else {// 否则就减去权值,(这里权值可以为负值) price -= parameter[i-x1][j-y1]; } } } price=price/num; // System.out.println(price); return price; } double result; // 这个阈值函数还可以用其他公式来实现,比如说Sigmoid函数 // 这个是阈值函数,用来在第三层节点处衡量累加前的每一个值是否达到标准,返回的是一个-1到1之间的数,表示相似度 double result(double price) { result = (((Math.atan(price)) / 1.58)) ; // System.out.println(result); return result; } double getresult(boolean[][] input) { price(input); result(price); return result; } double result_true; } package the_main; import java.lang.reflect.Array; import java.util.ArrayList; import java.util.Arrays; public class third_node { // 这是第三层节点,这层节点把第二层节点 // 的结果有序拼接在一起(方法是加权相加),但是这一层的每一个节点,都相应第二层所有的节点,但是每一个节点的权值都不同(有正有负), // 定义与此节点相关的第二层节点的队列; ArrayList list_se_tnode = new ArrayList<>(); double[] parameter;// 参数的个数和 与第二层相关节点的个数相同,(一个节点配一个参数) int m_expect;// 这里的m 指的是数字的预测值,是一个常数 double divide;// 这里规定一个阈值 // 这个函数用来构造第三层节点,注意参数中的整形数组里记录的是这个节点所响应的第二层每个节点的序号, public third_node(int[] index, int m_expect) { this.m_expect = m_expect; // 把和这个节点有关的第二层几点都添加进队列里来, for (int i = 0; i < index.length; i++) { list_se_tnode.add(Begin.list_sec_node.get(index[i])); } this.divide = 0.4 ;// 参数可改 this.parameter = new double[list_se_tnode.size()]; } double price = 0; // 这个函数用来计算第二层的加权相加的结果;(这里把上一层所有的节点乘以权值,全部相加) double price() { price = 0; for (int i = 0; i < list_se_tnode.size(); i++) { price += list_se_tnode.get(i).result * parameter[i]; // System.out.println("System.out.println(list_se_tnode.get(i).result);"+list_se_tnode.get(i).result); } // System.out.println(Arrays.toString(parameter)); price=price/list_se_tnode.size(); return price; } boolean result; boolean result() { if (price > divide) { return result = true; } else { return result = false; } } boolean getresult() { price(); result(); return result; } // 这里调整参数的过程中,注意有用的节点要调高参数,而没用的节点要调低参数(二者同时进行,因为我们的训练的时候,只告诉电脑哪些是“0”,而没有说过哪些不是“0”) void lose(int m) { // 确定第三层节点的期望值和真实值的误差,然后修正参数 // if (result) {// result为真 ,表示测试结果为 当前图形和期望值相同 if (m_expect == m) {// 说明预测正确,要把参数调高到使第二层节点的result和参数相乘刚好接近于0.95; for (int i = 0; i < parameter.length; i++) { // 这里的误差只有 “错误”和“正确”两种,这时候,就只能规定一个误差修正值 parameter[i] -= 0.01 * ((list_se_tnode.get(i).result * parameter[i]) - 0.95); //FIXME 这里注意 当result特别小的时候,第三层节点的参数就会特别大(因为它是在向0.95靠拢) } } else {// 则说明预测错误,要把参数调高到使第二层节点的result和参数相乘刚好接近于0.65; for (int i = 0; i < parameter.length; i++) { // 预测错误,则所有参数 parameter[i] -= 0.01 * ((list_se_tnode.get(i).result * parameter[i]) - 0.65); } } // } else {// result 为假,表示测试结果为当前图形和期望的数字差异较大; // if (m_expect == m) {// 如果相同,说明判断错误 // for (int i = 0; i < parameter.length; i++) { // // 这里的误差只有 “错误”和“正确”两种,这时候,就只能规定一个误差修正值 // parameter[i] -= 0.1 * (list_se_tnode.get(i).result * parameter[i]) - 0.95; // } // } // } } void learn(int m) { getresult(); lose(m); } }
以上代码以识别数字为例。(识别字母、或者汉字的方法同理。)
这里先解释一下,这里的各层节点的用处,
第二层节点:用来捕捉用户书写的图形的各种特点,当捕捉到相应的特点以后,会向第三层节点发出一个0~1之间的信号值(图形和相应的节点所示的特点吻合的越好,那么返回的这个值就越接近1)。比如说“0”这个图形,我们可把它分成很多个部分,(最简单的一种分法,从中间劈开,分成左半部分和右半部分,),其他字符也同理。
第三层节点:表示一个固定的字符,用来把第二层节点传入的值进行加权累加(相应的意义在于把各个特定按照不同的重要度(即权值),加在一起,所加得到的结果越高,就说明用户的所写的这个图形符号和我们第三层相应节点的拟合度越高。),然后第三层节点会把这个加权累加的结果传到第四层节点处。
第四层节点:表示神经网络对图形的分析结果,第四层节点其实只需要定义一个节点就可以满足需要,这个节点将第三层节点传回的值进行排序(第n个节点传回的值越大,就说明这个图形和我们第三层的第n个节点对应的值越相近。所以从大到小排列这些结果,我们就得到了用户输入的图形最有可能表示的字符,第二有可能表示的字符,第三有可能表示的字符······)。
值得说明的是:因为我们这个程序的参数较多(各层自由参数一共有大约一千个左右。),所以我们在调试的时候可以尝试创建一个窗口,参数的值按照某一个顺序,排列出来,这样有助于我们找到出错的地方。
相关文章推荐
- 神经网络学习(六)MNIST手写字识别 --- Matlab实现
- tensorflow 学习笔记12 循环神经网络RNN LSTM结构实现MNIST手写识别
- Python实现深度学习之-神经网络识别手写数字(更新中,更新日期:2017-07-12)
- 深度学习-传统神经网络使用TensorFlow框架实现MNIST手写数字识别
- 神经网络学习(七)MNIST手写字识别 --- Python实现
- tensorflow 学习笔记7 普通神经网络实现mnist手写识别
- 手把手入门神经网络系列(2)_74行代码实现手写数字识别
- 基于神经网络和遗传算法的【手写数字识别】机器人的实现
- 神经网络 手写识别例子 matlab实现
- NN:神经网络实现识别手写的10个数字—Jason niu
- 手把手入门神经网络系列(2)_74行代码实现手写数字识别
- 神经网络实现手写字符识别系统
- C++从零实现深度神经网络之六——实战手写数字识别(sigmoid和tanh)
- 手把手入门神经网络系列(2)_74行代码实现手写数字识别
- 机器深度学习笔记(1)——神经网络从一张图片中识别狗的过程
- 手把手入门神经网络系列(2)_74行代码实现手写数字识别
- 利用tensorflow一步一步实现基于MNIST 数据集进行手写数字识别的神经网络,逻辑回归
- python在线神经网络实现手写字符识别系统
- 神经网络与深度学习笔记——第1章 使用神经网络识别手写数字
- 神经网络与深度学习 1.6 使用Python实现基于梯度下降算法的神经网络和MNIST数据集的手写数字分类程序