您的位置:首页 > 编程语言 > Python开发

python训练模型上线问题总结

2017-10-24 15:01 555 查看

java调用python模型

PMML格式

使用java自带的Runtime.getRuntime().exec(args)方法直接调用python脚本

PMML格式

1、首先将python代码训练的模型保存为pmml格式,代码如下

# Train an XGBoost classifier inside a PMMLPipeline and export it as PMML so that
# Java (JPMML) can load it.
# NOTE(review): `xgb` (presumably `import xgboost as xgb`), X_train and y_train are
# defined elsewhere in the author's script — not shown here.
model = xgb.XGBClassifier()
from sklearn2pmml import PMMLPipeline
# sklearn2pmml can only export models wrapped in a PMMLPipeline.
pipeline = PMMLPipeline([("classifier", model)])
pipeline.fit(X_train,y_train)
from sklearn2pmml import sklearn2pmml
# Write the fitted pipeline to xgb.pmml. Per the sklearn2pmml docs, with_repr=True
# also embeds the pipeline's Python repr in the PMML document.
sklearn2pmml(pipeline, "xgb.pmml", with_repr = True)


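然后使用java读取pmml文件对数据进行预测,代码如下(ModelInvoker负责加载pmml模型,ModelCalc负责逐行读取数据并调用模型预测)。

后来改为另一种方案:由java启动子进程、调用python解释器直接运行python脚本(见下文"使用Runtime.getRuntime().exec(args)"一节)。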

import java.io.IOException;
import java.io.InputStream;
import java.util.Map;
import javax.xml.bind.JAXBException;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.ModelEvaluator;
import org.jpmml.evaluator.ModelEvaluatorFactory;
import org.jpmml.model.PMMLUtil;
import org.xml.sax.SAXException;

/**
 * Loads a PMML model (for example one exported from Python with sklearn2pmml) and
 * evaluates it on input records.
 *
 * @author liaotuo
 */
public class ModelInvoker {
    private final ModelEvaluator<?> modelEvaluator;

    /**
     * Loads the model from a classpath resource.
     *
     * <p>The name is resolved with {@link ClassLoader#getResourceAsStream(String)}, so it is a
     * classpath resource name, not a filesystem path. Prints a success message once the model
     * has been loaded and verified.
     *
     * @param pmmlFileName classpath resource name of the PMML document
     * @throws IllegalArgumentException if {@code pmmlFileName} is null or no such resource exists
     * @throws IllegalStateException if the PMML document cannot be parsed
     */
    public ModelInvoker(String pmmlFileName) {
        this(openResource(pmmlFileName));
        System.out.println("模型读取成功");
    }

    /**
     * Loads the model from a stream containing a PMML document.
     *
     * <p>The stream is always closed before this constructor returns or throws.
     *
     * @param is stream containing the PMML XML document
     * @throws IllegalArgumentException if {@code is} is null
     * @throws IllegalStateException if the PMML document cannot be parsed
     */
    public ModelInvoker(InputStream is) {
        if (is == null) {
            throw new IllegalArgumentException("PMML input stream must not be null");
        }
        PMML pmml;
        try {
            pmml = PMMLUtil.unmarshal(is);
        } catch (SAXException | JAXBException e) {
            // Fail loudly with the real cause instead of leaving the evaluator null
            // (which previously surfaced later as an unexplained NullPointerException).
            throw new IllegalStateException("Failed to parse PMML document", e);
        } finally {
            closeQuietly(is);
        }
        ModelEvaluator<?> evaluator = ModelEvaluatorFactory.newInstance().newModelEvaluator(pmml);
        evaluator.verify();
        this.modelEvaluator = evaluator;
    }

    /**
     * Evaluates the model on one record.
     *
     * @param paramsMap input field name to raw value
     * @return the evaluator's output, keyed by field name
     */
    public Map<FieldName, ?> invoke(Map<FieldName, Object> paramsMap) {
        return this.modelEvaluator.evaluate(paramsMap);
    }

    /** Opens a classpath resource, rejecting a null name or a missing resource with a clear message. */
    private static InputStream openResource(String pmmlFileName) {
        if (pmmlFileName == null) {
            throw new IllegalArgumentException("pmmlFileName must not be null");
        }
        InputStream is = ModelInvoker.class.getClassLoader().getResourceAsStream(pmmlFileName);
        if (is == null) {
            throw new IllegalArgumentException("PMML resource not found on classpath: " + pmmlFileName);
        }
        return is;
    }

    /** Closes the stream, deliberately ignoring close failures (best effort, as before). */
    private static void closeQuietly(InputStream is) {
        try {
            is.close();
        } catch (IOException ignored) {
            // Nothing useful can be done if closing the input fails; the model is already parsed.
        }
    }
}


import java.io.BufferedInputStream;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.dmg.pmml.FieldName;

/**
 * Runs the PMML model over a whitespace-separated data file and prints the results.
 *
 * @author gs
 */
public class ModelCalc {

    /** Filesystem path of the PMML model file. */
    static String pmmlPath = "E:\\workspace\\python\\tydic\\model\\xgb.pmml";

    /**
     * Scores the data file given as the first command-line argument, or the hard-coded
     * default file when no argument is supplied.
     */
    public static void main(String[] args) throws IOException {
        String modelArgsFilePath = (args != null && args.length > 0)
                ? args[0]
                : "E:\\workspace\\python\\tydic\\model\\test\\X_val";

        predictFromFile(modelArgsFilePath);
    }

    /**
     * Reads input records from a file and runs the model on each one, printing every output value.
     *
     * @param modelArgsFilePath path of a file whose first line holds space-separated field names
     *     and whose remaining lines hold space-separated values
     * @return the collected predictions
     * @throws FileNotFoundException if the model or the data file does not exist
     * @throws IOException if reading either file fails
     */
    public static List<String> predictFromFile(String modelArgsFilePath) throws FileNotFoundException, IOException {
        ModelInvoker invoker;
        // Close the model stream even if model loading fails.
        try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(pmmlPath))) {
            invoker = new ModelInvoker(bis);
        }
        List<Map<FieldName, Object>> paramList = getDataFromFile(modelArgsFilePath);

        // NOTE(review): nothing is ever added to this list, so callers always receive an empty
        // list. Earlier commented-out code kept every third printed value; confirm which output
        // field holds the positive-class probability before populating it.
        List<String> predictResult = new ArrayList<>();

        int lineNum = 0; // 1-based number of the record being processed
        for (Map<FieldName, Object> param : paramList) {
            lineNum++;
            System.out.println("======当前行: " + lineNum + "=======");
            Map<FieldName, ?> result = invoker.invoke(param);
            // Print every output field of the evaluation (target, probabilities, ...).
            for (Object value : result.values()) {
                System.out.println(String.valueOf(value));
            }
        }
        return predictResult;
    }

    /**
     * Reads the input file into one map per data line.
     *
     * <p>The first line is a header of space-separated field names. Each following line holds
     * space-separated values that are paired with the header names by position.
     *
     * @param filePath path of the data file
     * @return one map per data line; empty if the file is empty
     * @throws IOException if reading fails, or if a line has more values than the header has names
     */
    private static List<Map<FieldName, Object>> getDataFromFile(String filePath) throws IOException {
        List<Map<FieldName, Object>> list = new ArrayList<>();
        try (BufferedReader br = new BufferedReader(new FileReader(filePath))) {
            String headerLine = br.readLine();
            if (headerLine == null) {
                return list; // empty file: no header and no records
            }
            String[] nameArr = headerLine.split(" ");

            String paramLine;
            int fileLineNum = 1; // the header is line 1
            while ((paramLine = br.readLine()) != null) {
                fileLineNum++;
                String[] paramLineArr = paramLine.split(" ");
                if (paramLineArr.length > nameArr.length) {
                    throw new IOException("Line " + fileLineNum + " of " + filePath + " has "
                            + paramLineArr.length + " values but the header has only "
                            + nameArr.length + " names");
                }
                Map<FieldName, Object> map = new HashMap<>();
                for (int i = 0; i < paramLineArr.length; i++) {
                    map.put(new FieldName(nameArr[i]), paramLineArr[i]); // pair header name with value
                }
                list.add(map);
            }
        }
        return list;
    }
}


使用Runtime.getRuntime().exec(args)

这种方式的关键在于java端代码的编写,示例如下:

/**
 * Runs a Python training script as a child process and echoes its stdout and stderr.
 */
public class PythonDemo {

    /**
     * Launches {@code python C:\model_train.py} with the arguments below, pumps its output on
     * background threads, waits for it to finish and reports the exit code.
     */
    public static void main(String[] args) {
        try {
            // Arguments forwarded to the Python script
            String host = "localhost";
            String port = "3306";
            String user = "root";
            // NOTE(review): command-line arguments can be seen by other local users (e.g. in the
            // process list); consider passing the password via an environment variable instead.
            String passwd = "123456";
            String path = "C:/";
            String database = "dic_coll_consume";
            String start_date = "2017-08-01";
            String end_date = "2017-09-01";

            String[] command = new String[] { "python", "C:\\model_train.py", host, port, user,
                    passwd, path, database, start_date, end_date };
            Process pr = Runtime.getRuntime().exec(command);

            // Drain stdout and stderr concurrently. The child's pipe buffers are small; if one
            // fills up while nobody reads it, the child blocks and waitFor() never returns.
            Thread stdoutPump = print(pr.getInputStream());
            Thread stderrPump = print(pr.getErrorStream());

            int exitCode = pr.waitFor();
            stdoutPump.join();
            stderrPump.join();
            System.out.println("python exited with code " + exitCode);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            e.printStackTrace();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Starts a thread that copies {@code stream} line by line to {@code System.out} and prints
     * "end" once the stream is exhausted.
     *
     * @param stream one of the child process's output streams
     * @return the started thread, so the caller can join it
     */
    private static Thread print(final InputStream stream) {
        Thread pump = new Thread(new Runnable() {
            @Override
            public void run() {
                // Decode as UTF-8 directly instead of reading with the platform charset and then
                // re-decoding via line.getBytes(): that round trip is lossy on non-UTF-8
                // platforms (e.g. GBK on Chinese Windows), because bytes that are invalid in the
                // platform charset are replaced during the first decode.
                // NOTE(review): assumes the Python script writes UTF-8 — confirm.
                try (BufferedReader in = new BufferedReader(new InputStreamReader(stream, "UTF-8"))) {
                    String line;
                    while ((line = in.readLine()) != null) {
                        System.out.println(line);
                    }
                    System.out.println("end");
                } catch (Exception e) {
                    e.printStackTrace();
                }
            }
        });
        pump.start();
        return pump;
    }

}


这里遇到的主要坑是:python脚本运行时会向标准输出和标准错误打印大量信息,而子进程的输出管道缓冲区很小。如果java端直接调用pr.waitFor()等待,而不及时读取这些输出,缓冲区一旦写满,python进程就会阻塞在写操作上,waitFor()也永远不会返回,程序因此卡死。后来改为用单独的线程分别读取并打印标准输出和标准错误,问题立刻就解决了。
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: