Java实现LSTM和GRU做分类(以IRIS数据集为例)
2017-04-25 12:01
1021 查看
笔者想在JAVA项目中做机器学习的分类想使用循环神经网络的时候苦于没有找到开源的代码,最后终于找到lipiji所写的LSTM和GRU,项目GitHub链接在这:项目GitHub地址,但是这个项目的demo只是简单的做了一个文本序列的预测,无法达到自己做分类的目的,于是笔者新写了一个demo来实现分类的目的,这里所使用的数据集是Iris。Iris数据集是常用的分类实验数据集,由Fisher,
1936收集整理。Iris也称鸢尾花卉数据集,是一类多重变量分析的数据集。数据集包含150个数据集,分为3类,每类50个数据,每个数据包含4个属性。可通过花萼长度,花萼宽度,花瓣长度,花瓣宽度4个属性预测鸢尾花卉属于(Setosa,Versicolour,Virginica)三个种类中的哪一类。(来源:百度百科)点击下载Iris数据集 没有积分的也可以自己去找不需要积分的数据集。
数据预处理:首先将数据集里的花的类别修改成0,1,2三类,然后将每类中取15条数据共45条做测试集,余下105个做训练集分别存在两个文件中。新写一个类放在com.lipiji.mllib.rnn.gru包中,这里的输出层有三个节点,代表三个类别。笔者这里采用的GRU实验,要做LSTM的话将GRU类改成Cell类即可。测试代码如下:
package com.lipiji.mllib.rnn.gru;
import com.lipiji.mllib.layers.MatIniter;
import com.lipiji.mllib.rnn.lstm.Cell;
import com.lipiji.mllib.rnn.lstm.LSTM;
import com.lipiji.mllib.utils.LossFunction;
import org.jblas.DoubleMatrix;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.*;
public class gruTest {
public static double train_x[][] = new double[105][4];
public static double test_x[][] = new double[45][4];
public static double train_y[][] = new double[105][3];
public static double test_y[][] = new double[45][3];
private static GRU gru;
public static void main(String[] args) {
loadData();
int hiddenSize = 4;//隐含层数量
double lr = 0.1;
gru = new GRU(4, hiddenSize, new MatIniter(MatIniter.Type.Uniform, 0.1, 0, 0),3);//4是输入层,3是输出层
for (int i = 0; i < 2000; i++) {//迭代2000次
double error = 0;
double num = 0;
double start = System.currentTimeMillis();
Map<String, DoubleMatrix> acts = new HashMap<>();
for (int s = 0; s < train_x.length; s++) {
double newx[][] = new double[1][4];
newx[0] = train_x[s];
DoubleMatrix xt = new DoubleMatrix(newx);//获取字的矩阵
//System.out.println(xt.getColumns()+" "+xt.getRows());
acts.put("x" + s, xt);
gru.active(s, acts);
DoubleMatrix predcitYt = gru.decode(acts.get("h" + s));
acts.put("py" + s, predcitYt);
double newy[][] = new double[1][3];
newy[0] = train_y[s];
DoubleMatrix trueYt = new DoubleMatrix(newy);
acts.put("y" + s, trueYt);
if(predcitYt.argmax()!=trueYt.argmax())
error++;
// bptt
num ++;
}
gru.bptt(acts, train_x.length-1, lr);
System.out.println("Iter = " + i + ", error = " + error / num + ", time = " + (System.currentTimeMillis() - start) / 1000 + "s");
}//结束迭代
//开始测试
int num = 0,error = 0;
Map<String, DoubleMatrix> acts = new HashMap<>();
for(int s = 0; s<test_x.length;s++){
double newx[][] = new double[1][4];
newx[0] = test_x[s];
DoubleMatrix xt = new DoubleMatrix(newx);
acts.put("x" + s, xt);
gru.active(s, acts);
DoubleMatrix predcitYt = gru.decode(acts.get("h" + s));
acts.put("py" + s, predcitYt);
double newy[][] = new double[1][3];
newy[0] = test_y[s];
DoubleMatrix trueYt = new DoubleMatrix(newy);
acts.put("y" + s, trueYt);
if(predcitYt.argmax()!=trueYt.argmax())
error++;
// bptt
num ++;
}
System.out.println("错误数:"+error+"/"+num);
}
public static void loadData(){
List<String> list = readFileForList("data/train.txt");//训练集
for(int i = 0;i<list.size();i++){
String str[] = list.get(i).split(",");
for(int k = 0 ; k < 4;k++)
train_x[i][k]=Double.valueOf(str[k]);
train_y[i][Integer.valueOf(str[4])] = 1;//将所属类别设置为1
}
list = readFileForList("data/test.txt");//测试集
for(int i = 0;i<list.size();i++){
String str[] = list.get(i).split(",");
for(int k = 0 ; k < 4;k++)
test_x[i][k]=Double.valueOf(str[k]);
test_y[i][Integer.valueOf(str[4])] = 1;
}
}
public static List<String> readFileForList(String fileName) {//读取文件到list
File file = new File(fileName);
BufferedReader reader = null;
List<String> s = new ArrayList<String>();
try {
reader = new BufferedReader(new FileReader(file));
String tempString = null;
while ((tempString = reader.readLine()) != null) {
// 显示行号
s.add(tempString);
}
reader.close();
} catch (IOException e) {
e.printStackTrace();
} finally {
if (reader != null) {
try {
reader.close();
} catch (IOException e1) {
}
}
return s;
}
}
}
实验最终结果如下,可以看到45个测试集对了44个,笔者这里怎么调都无法达到完全的准确率,希望有做出来的可以告知一下,感谢。最后感谢lipiji提供的算法代码。
1936收集整理。Iris也称鸢尾花卉数据集,是一类多重变量分析的数据集。数据集包含150个数据集,分为3类,每类50个数据,每个数据包含4个属性。可通过花萼长度,花萼宽度,花瓣长度,花瓣宽度4个属性预测鸢尾花卉属于(Setosa,Versicolour,Virginica)三个种类中的哪一类。(来源:百度百科)点击下载Iris数据集 没有积分的也可以自己去找不需要积分的数据集。
数据预处理:首先将数据集里的花的类别修改成0,1,2三类,然后将每类中取15条数据共45条做测试集,余下105个做训练集分别存在两个文件中。新写一个类放在com.lipiji.mllib.rnn.gru包中,这里的输出层有三个节点,代表三个类别。笔者这里采用的GRU实验,要做LSTM的话将GRU类改成Cell类即可。测试代码如下:
package com.lipiji.mllib.rnn.gru;
import com.lipiji.mllib.layers.MatIniter;
import com.lipiji.mllib.rnn.lstm.Cell;
import com.lipiji.mllib.rnn.lstm.LSTM;
import com.lipiji.mllib.utils.LossFunction;
import org.jblas.DoubleMatrix;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.*;
public class gruTest {
public static double train_x[][] = new double[105][4];
public static double test_x[][] = new double[45][4];
public static double train_y[][] = new double[105][3];
public static double test_y[][] = new double[45][3];
private static GRU gru;
public static void main(String[] args) {
loadData();
int hiddenSize = 4;//隐含层数量
double lr = 0.1;
gru = new GRU(4, hiddenSize, new MatIniter(MatIniter.Type.Uniform, 0.1, 0, 0),3);//4是输入层,3是输出层
for (int i = 0; i < 2000; i++) {//迭代2000次
double error = 0;
double num = 0;
double start = System.currentTimeMillis();
Map<String, DoubleMatrix> acts = new HashMap<>();
for (int s = 0; s < train_x.length; s++) {
double newx[][] = new double[1][4];
newx[0] = train_x[s];
DoubleMatrix xt = new DoubleMatrix(newx);//获取字的矩阵
//System.out.println(xt.getColumns()+" "+xt.getRows());
acts.put("x" + s, xt);
gru.active(s, acts);
DoubleMatrix predcitYt = gru.decode(acts.get("h" + s));
acts.put("py" + s, predcitYt);
double newy[][] = new double[1][3];
newy[0] = train_y[s];
DoubleMatrix trueYt = new DoubleMatrix(newy);
acts.put("y" + s, trueYt);
if(predcitYt.argmax()!=trueYt.argmax())
error++;
// bptt
num ++;
}
gru.bptt(acts, train_x.length-1, lr);
System.out.println("Iter = " + i + ", error = " + error / num + ", time = " + (System.currentTimeMillis() - start) / 1000 + "s");
}//结束迭代
//开始测试
int num = 0,error = 0;
Map<String, DoubleMatrix> acts = new HashMap<>();
for(int s = 0; s<test_x.length;s++){
double newx[][] = new double[1][4];
newx[0] = test_x[s];
DoubleMatrix xt = new DoubleMatrix(newx);
acts.put("x" + s, xt);
gru.active(s, acts);
DoubleMatrix predcitYt = gru.decode(acts.get("h" + s));
acts.put("py" + s, predcitYt);
double newy[][] = new double[1][3];
newy[0] = test_y[s];
DoubleMatrix trueYt = new DoubleMatrix(newy);
acts.put("y" + s, trueYt);
if(predcitYt.argmax()!=trueYt.argmax())
error++;
// bptt
num ++;
}
System.out.println("错误数:"+error+"/"+num);
}
public static void loadData(){
List<String> list = readFileForList("data/train.txt");//训练集
for(int i = 0;i<list.size();i++){
String str[] = list.get(i).split(",");
for(int k = 0 ; k < 4;k++)
train_x[i][k]=Double.valueOf(str[k]);
train_y[i][Integer.valueOf(str[4])] = 1;//将所属类别设置为1
}
list = readFileForList("data/test.txt");//测试集
for(int i = 0;i<list.size();i++){
String str[] = list.get(i).split(",");
for(int k = 0 ; k < 4;k++)
test_x[i][k]=Double.valueOf(str[k]);
test_y[i][Integer.valueOf(str[4])] = 1;
}
}
public static List<String> readFileForList(String fileName) {//读取文件到list
File file = new File(fileName);
BufferedReader reader = null;
List<String> s = new ArrayList<String>();
try {
reader = new BufferedReader(new FileReader(file));
String tempString = null;
while ((tempString = reader.readLine()) != null) {
// 显示行号
s.add(tempString);
}
reader.close();
} catch (IOException e) {
e.printStackTrace();
} finally {
if (reader != null) {
try {
reader.close();
} catch (IOException e1) {
}
}
return s;
}
}
}
实验最终结果如下,可以看到45个测试集对了44个,笔者这里怎么调都无法达到完全的准确率,希望有做出来的可以告知一下,感谢。最后感谢lipiji提供的算法代码。
相关文章推荐
- 85、使用TFLearn实现iris数据集的分类
- iris数据集 决策树实现分类并画出决策树
- python 实现 knn分类算法 (Iris 数据集)
- java实现k-means算法(用的鸢尾花iris的数据集,从mysq数据库中读取数据)
- Python 3实现k-邻近算法以及 iris 数据集分类应用
- c#神经网络,实现对Iris数据集进行分类
- java实现Knn算法,用iris数据集进行验证
- 85、使用TFLearn实现iris数据集的分类
- [Java][机器学习]用决策树分类算法对Iris花数据集进行处理
- 鸢尾花分类算法实现 java
- 从实现技术给java报表工具分类
- RGB,CMY(K),YUV,YIQ,YCbCr颜色的转换算法(java实现) 分类: Android JAVA 2015-06-08 19:30 26人阅读 评论(0) 收藏
- Java生成二维码实现扫描次数统计并转发到某个地址 分类: 二维码 Java 2015-01-08 10:38 407人阅读 评论(0) 收藏
- 利用BP神经网络分类iris数据集
- Zookeeper实现服务上下线监控服务列表 分类: hadoop Java 2015-06-25 22:37 71人阅读 评论(0) 收藏
- JAVA实现CRC16算法 分类: Android JAVA 2015-03-30 18:58 48人阅读 评论(0) 收藏
- java代码实现商品类别的无限级分类显示
- 【JAVA实现】K-近邻(KNN)分类算法
- java实现FTP传输文件( 2008-05-23 15:45:09| 分类: java 技术)
- 贝叶斯文本分类 java实现