Implementing the FTRL Model in Java and Scala (Data and Engineering Implementation)
2016-12-23 11:57
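A colleague recently shared a document on optimization that covers the FTRL model, so this post walks through an implementation of FTRL. At its core the model is still logistic regression; what changes is the way it is solved.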
Main theory reference: 在线最优化求解 (Online Optimization), the optimization document shared by 冯扬 (Feng Yang)
Main engineering reference: https://github.com/datawlb/code/tree/master/all
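For orientation, the per-coordinate update that the code in this post performs can be written out as follows. This is a sketch read off the code's own variables (z, n, w, alpha, beta, L1, L2), not a derivation from the reference document. Here \lambda_1 and \lambda_2 stand for L1 and L2, I is the set of hashed feature indices active in the current example, y is the click label, and every active feature is assumed to have value 1, which is why the gradient for each active coordinate reduces to p - y.

```latex
\begin{align*}
w_i &= \begin{cases}
  0, & |z_i| \le \lambda_1 \\[4pt]
  \dfrac{\operatorname{sgn}(z_i)\,\lambda_1 - z_i}{(\beta + \sqrt{n_i})/\alpha + \lambda_2}, & \text{otherwise}
\end{cases} \\
p &= \frac{1}{1 + \exp\!\left(-\sum_{i \in I} w_i\right)} \\
g &= p - y \\
\sigma_i &= \frac{\sqrt{n_i + g^2} - \sqrt{n_i}}{\alpha} \\
z_i &\leftarrow z_i + g - \sigma_i w_i \\
n_i &\leftarrow n_i + g^2
\end{align*}
```

The first case is where the L1 term produces sparsity: a coordinate keeps a weight of exactly zero until its accumulated z exceeds \lambda_1 in magnitude.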
Java implementation:
package com.wanda.rocket_zlast;

import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.InputStreamReader;
import java.util.HashSet;

public class FTRLModel {
    private static final String tagid = "TAGID";
    private static final String hour = "HOUR";
    private static final String province = "PROVINCE";
    private static final String media = "MEDIA";
    private static final String[] FEATURE = { tagid, hour, province, media };

    private double alpha = 0.1; // learning rate
    private double beta = 1;    // smoothing rate
    private double L1 = 1;      // L1 regularization strength
    private double L2 = 1;      // L2 regularization strength
    private int D = 1000000;    // number of weights (hash buckets)

    private double[] N; // accumulated squared gradients
    private double[] Z; // accumulated adjusted gradients, used to derive the weights
    private double[] W; // final model weights

    public FTRLModel(int D) {
        this.D = D;
        this.N = new double[D];
        this.Z = new double[D];
        this.W = new double[D];
    }

    /** Hashes the four features of one row into the set (hash trick). */
    private void hashFeatures(String[] value, HashSet<Integer> set) {
        int[] columns = { 0, 1, 3, 6 }; // tagid, hour, province, media
        for (int k = 0; k < FEATURE.length; k++) {
            // floorMod stays non-negative even when hashCode() is Integer.MIN_VALUE
            set.add(Math.floorMod((FEATURE[k] + "_" + value[columns[k]]).hashCode(), D));
        }
    }

    /**
     * @param set   hashed feature indices (hash trick)
     * @param label click = 1, no click = 0
     */
    public void train(HashSet<Integer> set, int label) {
        double p = 0.0;
        for (Integer i : set) {
            int sign = Z[i] < 0 ? -1 : 1;
            if (Math.abs(Z[i]) <= L1) {
                W[i] = 0.0;
            } else {
                W[i] = (sign * L1 - Z[i]) / ((beta + Math.sqrt(N[i])) / alpha + L2);
            }
            p += W[i];
        }
        // predict
        p = 1 / (1 + Math.exp(-p));
        // update
        double g = p - label;
        for (Integer i : set) {
            double sigma = (Math.sqrt(N[i] + g * g) - Math.sqrt(N[i])) / alpha;
            Z[i] += g - sigma * W[i];
            N[i] += g * g;
        }
        set.clear();
    }

    public double predict(HashSet<Integer> set) {
        double p = 0.0;
        for (Integer i : set) {
            p += W[i];
        }
        return 1 / (1 + Math.exp(-p));
    }

    public double logloss(double p, int label) {
        return label == 1 ? -Math.log(p) : -Math.log(1.0 - p);
    }

    public static void main(String[] args) {
        int epoch = 1; // number of passes over the training data
        FTRLModel ftrl = new FTRLModel(100000);
        String trPath = "D:\\tmp\\ftrl\\train\\part-r-00298";
        String tePath = "D:\\tmp\\ftrl\\test\\part-r-00299";
        String submissionPath = "D:\\tmp\\ftrl\\result\\result.csv";
        HashSet<Integer> set = new HashSet<Integer>();
        String str;

        // train model; the file is reopened for every epoch
        for (int epo = 0; epo < epoch; epo++) {
            try (BufferedReader br = new BufferedReader(
                    new InputStreamReader(new FileInputStream(trPath), "UTF-8"))) {
                br.readLine(); // skip the first line (treated as a header row)
                while ((str = br.readLine()) != null) {
                    String[] value = str.split(",");
                    ftrl.hashFeatures(value, set); // features come from the current row
                    ftrl.train(set, Integer.parseInt(value[10]));
                    set.clear();
                }
            } catch (Exception e) {
                e.printStackTrace();
            }
        }

        // predict result
        byte[] newLine = "\r\n".getBytes();
        int count = 0;
        try (BufferedOutputStream bos = new BufferedOutputStream(new FileOutputStream(submissionPath));
             BufferedReader br = new BufferedReader(
                    new InputStreamReader(new FileInputStream(tePath), "UTF-8"))) {
            bos.write("true,ftrl,logisitc".getBytes());
            bos.write(newLine);
            br.readLine(); // skip the first line (treated as a header row)
            while ((str = br.readLine()) != null) {
                count++;
                String[] value = str.split(",");
                ftrl.hashFeatures(value, set);
                double p = ftrl.predict(set);
                // true label, FTRL prediction, existing logistic regression prediction
                String result = value[10] + "," + p + "," + value[value.length - 1];
                bos.write(result.getBytes());
                bos.write(newLine);
                set.clear();
            }
            bos.flush();
            System.out.println(count);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}
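To make the effect of the L1 threshold concrete, here is a minimal sketch that follows a single coordinate through a few updates. It assumes a model with exactly one active feature and reuses the default hyperparameters from the Java code above (alpha = 0.1, beta = 1, L1 = 1, L2 = 1). The class name `FtrlCoordinateSketch` and the all-clicks input are made up for illustration.

```java
final class FtrlCoordinateSketch {
    // Hyperparameters copied from the defaults in the post's Java code.
    static final double ALPHA = 0.1, BETA = 1.0, L1 = 1.0, L2 = 1.0;

    double z = 0.0; // accumulated adjusted gradient
    double n = 0.0; // accumulated squared gradient

    /** Closed-form weight for the current z and n; exactly 0 while |z| <= L1. */
    double weight() {
        if (Math.abs(z) <= L1) {
            return 0.0;
        }
        double sign = z < 0 ? -1.0 : 1.0;
        return (sign * L1 - z) / ((BETA + Math.sqrt(n)) / ALPHA + L2);
    }

    /** One online step on an example whose only active feature is this coordinate. */
    void step(int label) {
        double w = weight();
        double p = 1.0 / (1.0 + Math.exp(-w)); // predicted click probability
        double g = p - label;                  // gradient of the log loss
        double sigma = (Math.sqrt(n + g * g) - Math.sqrt(n)) / ALPHA;
        z += g - sigma * w;
        n += g * g;
    }

    public static void main(String[] args) {
        FtrlCoordinateSketch c = new FtrlCoordinateSketch();
        for (int t = 1; t <= 5; t++) {
            c.step(1); // hypothetical input: every example is a click
            System.out.printf("after %d updates: z=%.4f n=%.4f w=%.6f%n", t, c.z, c.n, c.weight());
        }
    }
}
```

With these settings each click moves z by -0.5 while the weight is still zero, so the weight stays at exactly 0 after the first two updates (|z| = 0.5, then 1.0) and only becomes positive, about 0.0254, after the third (z = -1.5). A feature seen once or twice therefore contributes nothing to the model.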
Scala implementation:
package com.buzzinate.bidding.model

import java.io.File
import java.io.PrintWriter

import scala.collection.mutable.ListBuffer

import com.buzzinate.bidding.model.util.Loggable
import com.buzzinate.bidding.model.util.PathUtil
import com.buzzinate.buzzads.bidding.util.Constants
import com.buzzinate.buzzads.bidding.util.DomainNames

object FTRL extends Loggable {
  import DateHandler._
  import com.buzzinate.bidding.model.log.AdClickLog._

  val featureList = DateHandler.prop.getIntList("feature.list")
  val maxDimensions = DateHandler.prop.getInt("max.dimensions", math.pow(2, 20).toInt)
  val trainDate = DateHandler.prop.getString("train.date")
  val predictDate = DateHandler.prop.getString("predict.date")
  val trainDay = DateHandler.prop.getInt("logistic.train.defaultkeepday", 0)
  val alpha = DateHandler.prop.getDouble("alpha", 0.1)
  val beta = DateHandler.prop.getDouble("beta", 1.0)
  val L1 = DateHandler.prop.getDouble("L1", 1.0)
  val L2 = DateHandler.prop.getDouble("L2", 1.0)

  // n: accumulated squared gradients
  // z: accumulated adjusted gradients, used to derive the weights
  // w: final weights
  val n, z, w = new Array[Double](maxDimensions)

  def predict(x: Array[Int]): Double = {
    var wTx = 0.0
    x foreach { i =>
      val sign = if (z(i) < 0) -1.0 else 1.0
      if (sign * z(i) <= L1) w(i) = 0.0
      else w(i) = (sign * L1 - z(i)) / ((beta + math.sqrt(n(i))) / alpha + L2)
      wTx = wTx + w(i)
    }
    // clamp wTx to [-35, 35] so that exp cannot overflow
    1.0 / (1.0 + math.exp(-math.max(math.min(wTx, 35.0), -35.0)))
  }

  def update(x: Array[Int], p: Double, y: Int): Unit = {
    val g = p - y
    x foreach { i =>
      val sigma = (math.sqrt(n(i) + g * g) - math.sqrt(n(i))) / alpha
      z(i) = z(i) + g - sigma * w(i)
      n(i) = n(i) + g * g
    }
  }

  def train(dateStr: String): Unit = {
    val trainPaths = PathUtil.getLastExistedAdClickLogPath(DateHandler.adClickLogParentPath, trainDay, trainDay, trainDate)
    val testPaths = PathUtil.getLastExistedAdClickLogPath(DateHandler.adClickLogParentPath, 1, 1, predictDate)

    info("begin training")
    for (i <- 1 to 1) {
      info("the " + i.toString() + " iteration")
      DateHandler.foreach(trainPaths, { line =>
        val log = parseClickLog(line)
        if (log.isDefined) {
          val x = generateInstance(log.get, featureList)
          val p = predict(x)
          update(x, p, if (log.get.isClick) 1 else 0)
        }
      })
    }

    info("begin predicting")
    val writer = new PrintWriter(new File("ctrout.dat"))
    var k = 0
    try {
      DateHandler.foreach(testPaths, { line =>
        val log = parseClickLog(line)
        if (log.isDefined) {
          val x = generateInstance(log.get, featureList)
          val p = predict(x)
          // write out one record in ten: FTRL prediction, existing prediction, click flag
          if (k % 10 == 1) {
            writer.println(p.toString() + "," + log.get.predictionCtr.toString() + "," + log.get.isClick)
          }
          // the model keeps learning from each test record after predicting on it
          update(x, p, if (log.get.isClick) 1 else 0)
          k = k + 1
        }
      })
    } finally {
      writer.close()
    }
    println("predict end")

    val notzerocnt = w.count(x => x != 0)
    info("not zero count:" + notzerocnt.toString())
  }

  def generateInstance(log: AdClickLog, featureList: List[Int]): Array[Int] = {
    val instance: ListBuffer[Int] = ListBuffer()
    for (feature <- featureList) {
      val featureValue = feature match {
        case Constants.AD_CATEGORY_FEATURE => "adCategory_" + log.adCategory
        case Constants.BROWSER_FEATURE => "browser_" + log.browser
        case Constants.DOMAIN_FEATURE => "url_" + DomainNames.safeGetHost(log.url)
        case Constants.MEDIA_SLOT_FEATURE => "slotTagId_" + log.slotTagId
        case Constants.OS_FEATURE => "os_" + log.os
        case Constants.PROVINCE_FEATURE => "province_" + log.province
        case Constants.SLOT_POSITION_FEATURE => "adPos_" + log.adPos
        case Constants.SLOT_SIZE_FEATURE => "adSize_" + log.adSize
        case Constants.SLOT_TYPE_FEATURE => "adType_" + log.adType
        case Constants.TIME_SLOT_FEATURE => "timeSlot_" + log.timeSlot
      }
      // floorMod stays non-negative even when hashCode() is Int.MinValue
      instance.append(java.lang.Math.floorMod(featureValue.hashCode(), maxDimensions))
    }
    instance.toArray
  }
}
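The data is real click-through-rate prediction data from the company, whose production CTR model is a logistic regression model. The data format is as follows: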
iqiyi_52_1pmSLBNz,4-11,?,HL,?,71518,iqiyi1000000000381,12,6,32,0,60812,www.iqiyi.com,0.016689087173704896
iqiyi_4613_CYYWXOsz,4-11,?,GD,?,71522,iqiyi1000000000381,12,6,32,0,60812,www.iqiyi.com,0.021587590637397437
iqiyi_11023_LBfNHlMl,4-9,?,HL,?,71518,iqiyi1000000000381,12,6,32,0,60812,www.iqiyi.com,0.016689087173704896
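As a rough illustration of how one such row turns into model input, the sketch below splits a row and hashes four of its columns into bucket indices. The column meanings are an assumption inferred from the Java code in this post: it builds its TAGID, HOUR, PROVINCE and MEDIA features from columns 0, 1, 3 and 6, reads the click label from column 10, and treats the last column as the existing logistic regression prediction. The class and method names are hypothetical.

```java
import java.util.HashSet;
import java.util.Set;

final class FeatureHashingSketch {
    // Column positions assumed from the Java code in this post (0-based).
    static final int TAGID_COL = 0, HOUR_COL = 1, PROVINCE_COL = 3, MEDIA_COL = 6, LABEL_COL = 10;

    /** Maps "PREFIX_value" to a bucket in [0, buckets) with the hash trick. */
    static int bucket(String prefix, String value, int buckets) {
        // floorMod never returns a negative index, unlike Math.abs(h) % buckets
        return Math.floorMod((prefix + "_" + value).hashCode(), buckets);
    }

    /** Hashed indices of the four features of one comma-separated row. */
    static Set<Integer> features(String line, int buckets) {
        String[] v = line.split(",");
        Set<Integer> set = new HashSet<>();
        set.add(bucket("TAGID", v[TAGID_COL], buckets));
        set.add(bucket("HOUR", v[HOUR_COL], buckets));
        set.add(bucket("PROVINCE", v[PROVINCE_COL], buckets));
        set.add(bucket("MEDIA", v[MEDIA_COL], buckets));
        return set;
    }

    /** Click label: 1 for a click, 0 otherwise. */
    static int label(String line) {
        return Integer.parseInt(line.split(",")[LABEL_COL]);
    }

    /** Last column, assumed to be the existing model's predicted CTR. */
    static double baselineCtr(String line) {
        String[] v = line.split(",");
        return Double.parseDouble(v[v.length - 1]);
    }

    public static void main(String[] args) {
        String line = "iqiyi_52_1pmSLBNz,4-11,?,HL,?,71518,iqiyi1000000000381,12,6,32,0,60812,"
                + "www.iqiyi.com,0.016689087173704896";
        System.out.println("features = " + features(line, 100000));
        System.out.println("label    = " + label(line));
        System.out.println("baseline = " + baselineCtr(line));
    }
}
```

Because the indices live in a set, two features that collide in the same bucket count once, so a row can yield fewer than four indices.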
Finally, plotting ROC curves in Python from the generated CSV file shows that the AUC improves considerably.
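The comparison itself was done in Python and that script is not shown. For readers who want to stay in Java, the following sketch shows one standard way to compute AUC from a column of labels and a column of scores, using the rank-sum formulation with average ranks for tied scores. It is not the author's script, and the scores in `main` are made-up numbers, not results from the data above.

```java
import java.util.Arrays;
import java.util.Comparator;

final class AucSketch {
    /**
     * Area under the ROC curve via the rank-sum statistic.
     * labels must contain 1 for positives and 0 for negatives.
     * Returns NaN when either class is missing.
     */
    static double auc(double[] scores, int[] labels) {
        int n = scores.length;
        Integer[] order = new Integer[n];
        for (int i = 0; i < n; i++) {
            order[i] = i;
        }
        // indices sorted by ascending score
        Arrays.sort(order, Comparator.comparingDouble(i -> scores[i]));

        double positiveRankSum = 0.0;
        long positives = 0;
        int start = 0;
        while (start < n) {
            // find the end of the group of tied scores
            int end = start;
            while (end + 1 < n && scores[order[end + 1]] == scores[order[start]]) {
                end++;
            }
            double averageRank = (start + end) / 2.0 + 1.0; // 1-based
            for (int k = start; k <= end; k++) {
                if (labels[order[k]] == 1) {
                    positiveRankSum += averageRank;
                    positives++;
                }
            }
            start = end + 1;
        }
        long negatives = n - positives;
        if (positives == 0 || negatives == 0) {
            return Double.NaN;
        }
        return (positiveRankSum - positives * (positives + 1) / 2.0)
                / ((double) positives * negatives);
    }

    public static void main(String[] args) {
        int[] labels = {0, 0, 1, 1};
        double[] modelA = {0.1, 0.4, 0.35, 0.8}; // hypothetical scores
        double[] modelB = {0.1, 0.2, 0.3, 0.4};  // hypothetical scores
        System.out.println("AUC A = " + auc(modelA, labels)); // 0.75
        System.out.println("AUC B = " + auc(modelB, labels)); // 1.0
    }
}
```

Applied to the result file, the first column would supply the labels, and the second and third columns would each be passed in turn as the scores.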