
Implementing the FTRL model in Java and Scala (data and engineering notes)

A colleague recently shared a document on optimization that covered the FTRL model. This post focuses on implementing FTRL: at heart it is still a logistic regression model, but the weights are solved for differently, with an online, per-coordinate update instead of batch optimization.

Theory reference: 在线最优化求解 (Online Optimization), the optimization notes shared by 冯扬.

Engineering reference: https://github.com/datawlb/code/tree/master/all
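For reference, both implementations below follow the standard per-coordinate FTRL-Proximal update. This summary is mine, with notation matching the code: $\alpha$ is the learning rate, $\beta$ the smoothing parameter, $\lambda_1$/$\lambda_2$ the regularization strengths, and $n_i$, $z_i$ the per-coordinate state kept in the N/Z arrays. The weight of each active coordinate is recomputed lazily:

$$
w_i =
\begin{cases}
0 & \text{if } |z_i| \le \lambda_1 \\
-\dfrac{z_i - \operatorname{sgn}(z_i)\,\lambda_1}{(\beta + \sqrt{n_i})/\alpha + \lambda_2} & \text{otherwise}
\end{cases}
$$

and after predicting $p = \sigma(\mathbf{w}\cdot\mathbf{x})$ on an instance with label $y \in \{0, 1\}$, each active coordinate is updated with

$$
g_i = (p - y)\,x_i,\qquad
\sigma_i = \frac{\sqrt{n_i + g_i^2} - \sqrt{n_i}}{\alpha},\qquad
z_i \leftarrow z_i + g_i - \sigma_i w_i,\qquad
n_i \leftarrow n_i + g_i^2.
$$

Because the hashed features here are binary, $x_i = 1$ and $g_i$ reduces to $p - y$, which is exactly what the code computes.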

Java implementation:

package com.wanda.rocket_zlast;

import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.InputStreamReader;
import java.util.HashSet;

public class FTRLModel {
    private static final String tagid = "TAGID";
    private static final String hour = "HOUR";
    private static final String province = "PROVINCE";
    private static final String media = "MEDIA";
    private static final String[] FEATURE = { tagid, hour, province, media };

    private double alpha = 0.1; // learning rate
    private double beta = 1;    // smoothing parameter
    private double L1 = 1;      // L1 regularization strength
    private double L2 = 1;      // L2 regularization strength
    private int D = 1000000;    // number of weights (size of the hash space)
    // private int epoch = 1;   // repeat train times

    private double[] N;
    private double[] Z;
    private double[] W;

    // N: per-coordinate sum of squared gradients; Z: the accumulated z values
    // that drive the lazy weight update; W: the resulting model weights
    public FTRLModel(int D) {
        this.D = D;
        this.N = new double[D]; // sum of the squared gradients
        this.Z = new double[D]; // z values for the FTRL weight update
        this.W = new double[D]; // final model weights
    }

    /**
     * Train on a single instance.
     *
     * @param set
     *            hashed feature indices (hash trick)
     * @param label
     *            click = 1, no click = 0
     */
    public void train(HashSet<Integer> set, int label) {
        double p = 0.0;
        for (Integer i : set) {
            // lazily recompute the weight of every active coordinate from Z and N
            int sign = Z[i] < 0 ? -1 : 1;
            if (Math.abs(Z[i]) <= L1) {
                W[i] = 0.0;
            } else {
                W[i] = (sign * L1 - Z[i])
                        / ((beta + Math.sqrt(N[i])) / alpha + L2);
            }
            p += W[i];
        }

        // predict
        p = 1 / (1 + Math.exp(-p));

        // update Z and N with the gradient of the log loss
        double g = p - label;
        for (Integer i : set) {
            double sigma = (Math.sqrt(N[i] + g * g) - Math.sqrt(N[i])) / alpha;
            Z[i] += g - sigma * W[i];
            N[i] += g * g;
        }
        set.clear();
    }

    public double predict(HashSet<Integer> set) {
        double p = 0.0;
        for (Integer i : set) {
            p += W[i];
        }
        // sigmoid of the linear score
        p = 1 / (1 + Math.exp(-p));
        return p;
    }

    public double logloss(double p, int label) {
        if (label == 1) {
            p = -Math.log(p);
        } else {
            p = -Math.log(1.0 - p);
        }
        return p;
    }

    public static void main(String args[]) {
        double p = 0.0;

        int epoch = 1; // number of passes (the reader would need to be reopened per pass)

        FTRLModel ftrl = new FTRLModel(100000);

        String trPath = "D:\\tmp\\ftrl\\train\\part-r-00298";
        String tePath = "D:\\tmp\\ftrl\\test\\part-r-00299";
        String submissionPath = "D:\\tmp\\ftrl\\result\\result.csv";

        BufferedReader br;
        String str = null;
        // train model
        try {
            br = new BufferedReader(new InputStreamReader(new FileInputStream(
                    trPath), "UTF-8"));
            str = br.readLine(); // first line is skipped (treated as a header)
            String value[] = null;
            HashSet<Integer> set = new HashSet<Integer>();
            for (int epo = 0; epo < epoch; epo++) {
                while ((str = br.readLine()) != null) {
                    value = str.split(",");
                    // hash trick: prefix each raw value with its feature name,
                    // then hash into one of D buckets
                    set.add(Math.abs((FEATURE[0] + "_" + value[0]).hashCode())
                            % ftrl.D);
                    set.add(Math.abs((FEATURE[1] + "_" + value[1]).hashCode())
                            % ftrl.D);
                    set.add(Math.abs((FEATURE[2] + "_" + value[3]).hashCode())
                            % ftrl.D);
                    set.add(Math.abs((FEATURE[3] + "_" + value[6]).hashCode())
                            % ftrl.D);
                    ftrl.train(set, Integer.parseInt(value[10]));
                    set.clear();
                }
            }
        } catch (Exception e) {
            e.printStackTrace();
        }

        // predict and write the result file
        BufferedOutputStream bos;
        String string = null;
        byte[] newLine = "\r\n".getBytes();
        int count = 0;
        try {
            bos = new BufferedOutputStream(new FileOutputStream(submissionPath));
            bos.write(("true,ftrl,logistic").getBytes());
            bos.write(newLine);

            br = new BufferedReader(new InputStreamReader(new FileInputStream(
                    tePath), "UTF-8"));
            string = br.readLine(); // first line is skipped (treated as a header)
            String value[] = null;
            HashSet<Integer> set = new HashSet<Integer>();

            while ((string = br.readLine()) != null) {
                count++;
                value = string.split(",");

                // same features as in training: FEATURE = {tagid, hour, province, media}
                set.add(Math.abs((FEATURE[0] + "_" + value[0]).hashCode())
                        % ftrl.D);
                set.add(Math.abs((FEATURE[1] + "_" + value[1]).hashCode())
                        % ftrl.D);
                set.add(Math.abs((FEATURE[2] + "_" + value[3]).hashCode())
                        % ftrl.D);
                set.add(Math.abs((FEATURE[3] + "_" + value[6]).hashCode())
                        % ftrl.D);

                p = ftrl.predict(set);
                // true label, FTRL prediction, prediction of the existing logistic model
                String result = value[10] + "," + p + "," + value[value.length - 1];
                bos.write(result.getBytes());
                bos.write(newLine);
                set.clear();
            }

            bos.flush();
            bos.close();
            System.out.println(count);

        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}


Scala implementation:

package com.buzzinate.bidding.model

import java.io.BufferedReader
import java.io.InputStreamReader
import java.util.Date
import java.util.HashMap
import org.apache.hadoop.fs.FileSystem
import com.alibaba.fastjson.JSON
import com.alibaba.fastjson.JSONObject
import com.buzzinate.bidding.model.util.Loggable
import com.buzzinate.bidding.model.util.MathUtil
import com.buzzinate.bidding.model.util.PathUtil
import com.buzzinate.buzzads.bidding.redis.json.JLogisticModel
import com.buzzinate.buzzads.bidding.util.Constants
import com.buzzinate.buzzads.bidding.util.DomainNames
import DateHandler.conf
import scala.collection.mutable.ListBuffer
import com.buzzinate.bidding.model.util.TimeUtil
import java.text.SimpleDateFormat
import org.apache.hadoop.fs.Path
import javax.activation.DataHandler
import com.amazonaws.util.json.JSONObject.Null
import scala.collection.mutable.ArrayBuffer
import Array._
import org.apache.hadoop.hdfs.protocol.FSLimitException.MaxDirectoryItemsExceededException
import java.io.PrintWriter
import java.io.File

object FTRL extends Loggable {
  import DateHandler._
  import com.buzzinate.bidding.model.log.AdClickLog._

  val featureList = DateHandler.prop.getIntList("feature.list")
  val maxDimensions = DateHandler.prop.getInt("max.dimensions", math.pow(2, 20).toInt)
  val trainDate = DateHandler.prop.getString("train.date")
  val predictDate = DateHandler.prop.getString("predict.date")
  val trainDay = DateHandler.prop.getInt("logistic.train.defaultkeepday", 0)
  val alpha = DateHandler.prop.getDouble("alpha", 0.1)
  val beta = DateHandler.prop.getDouble("beta", 1.0)
  val L1 = DateHandler.prop.getDouble("L1", 1.0)
  val L2 = DateHandler.prop.getDouble("L2", 1.0)

  // n: per-coordinate sum of squared gradients, z: accumulated values that
  // drive the weight update, w: the resulting model weights
  var n, z, w = new Array[Double](maxDimensions)

  def predict(x: Array[Int]): Double = {
    var wTx = 0.0

    x foreach { i =>
      // lazily recompute the weight of every active coordinate from z and n
      val sign = if (z(i) < 0) -1.0 else 1.0

      if (sign * z(i) <= L1)
        w(i) = 0.0
      else
        w(i) = (sign * L1 - z(i)) / ((beta + math.sqrt(n(i))) / alpha + L2)

      wTx = wTx + w(i)
    }
    // clip the linear score to avoid overflow in exp
    1.0 / (1.0 + math.exp(-math.max(math.min(wTx, 35.0), -35.0)))
  }

  def update(x: Array[Int], p: Double, y: Int): Unit = {
    val g = p - y

    x foreach { i =>
      val sigma = (math.sqrt(n(i) + g * g) - math.sqrt(n(i))) / alpha
      z(i) = z(i) + g - sigma * w(i)
      n(i) = n(i) + g * g
    }
  }

  def train(dateStr: String): Unit = {
    var trainPaths = PathUtil.getLastExistedAdClickLogPath(DateHandler.adClickLogParentPath, trainDay, trainDay, trainDate)
    var testPaths = PathUtil.getLastExistedAdClickLogPath(DateHandler.adClickLogParentPath, 1, 1, predictDate)

    info("begin training")
    for (i <- 1 to 1) { // a single pass; widen the range for more epochs
      info("the " + i.toString() + " iteration")
      DateHandler.foreach(trainPaths, { line =>
        val log = parseClickLog(line)
        if (log.isDefined) {
          val x = generateInstance(log.get, featureList)
          val p = predict(x)
          update(x, p, if (log.get.isClick) 1 else 0)
        }
      })
    }

    info("begin predicting")

    val writer = new PrintWriter(new File("ctrout.dat"))
    var k = 0
    DateHandler.foreach(testPaths, { line =>
      val log = parseClickLog(line)
      if (log.isDefined) {
        val x = generateInstance(log.get, featureList)
        val p = predict(x)
        // sample one in ten predictions for offline comparison against the
        // CTR predicted by the existing model
        if (k % 10 == 1) {
          writer.println(p.toString() + "," + log.get.predictionCtr.toString() + "," + log.get.isClick)
        }
        // keep learning online while evaluating (progressive validation)
        update(x, p, if (log.get.isClick) 1 else 0)
        k = k + 1
      }
    })

    writer.close()

    println("predict end")

    val notzerocnt = w.count(x => x != 0)
    info("not zero count:" + notzerocnt.toString())
  }

  def generateInstance(log: AdClickLog, featureList: List[Int]): Array[Int] = {
    val instance: ListBuffer[Int] = ListBuffer()

    for (feature <- featureList) {
      var featureValue = feature match {
        case Constants.AD_CATEGORY_FEATURE   => "adCategory_" + log.adCategory
        case Constants.BROWSER_FEATURE       => "browser_" + log.browser
        case Constants.DOMAIN_FEATURE        => "url_" + DomainNames.safeGetHost(log.url)
        case Constants.MEDIA_SLOT_FEATURE    => "slotTagId_" + log.slotTagId
        case Constants.OS_FEATURE            => "os_" + log.os
        case Constants.PROVINCE_FEATURE      => "province_" + log.province
        case Constants.SLOT_POSITION_FEATURE => "adPos_" + log.adPos
        case Constants.SLOT_SIZE_FEATURE     => "adSize_" + log.adSize
        case Constants.SLOT_TYPE_FEATURE     => "adType_" + log.adType
        case Constants.TIME_SLOT_FEATURE     => "timeSlot_" + log.timeSlot
      }
      // hash trick: "name_value" string -> bucket index
      instance.append(math.abs(featureValue.hashCode()) % maxDimensions)
    }
    instance.toList.toArray
  }

}



The data is real click-through-rate data from our company; the production CTR model is a logistic regression model. A few sample rows are shown below (a hedged parsing sketch follows the samples):


iqiyi_52_1pmSLBNz,4-11,?,HL,?,71518,iqiyi1000000000381,12,6,32,0,60812,www.iqiyi.com,0.016689087173704896

iqiyi_4613_CYYWXOsz,4-11,?,GD,?,71522,iqiyi1000000000381,12,6,32,0,60812,www.iqiyi.com,0.021587590637397437

iqiyi_11023_LBfNHlMl,4-9,?,HL,?,71518,iqiyi1000000000381,12,6,32,0,60812,www.iqiyi.com,0.016689087173704896
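As a rough illustration (not from the original post), one of these rows maps to the model input roughly like this, using the same column positions the Java main() above reads: columns 0, 1, 3 and 6 for the features and column 10 for the click label. Whether those indices match the real production schema is an assumption.

// Hedged sketch: turn one raw CSV row into hashed feature indices, with the
// same column positions as FTRLModel.main() above; the schema is an assumption.
import java.util.HashSet;

public class RowToFeatures {
    private static final String[] FEATURE = { "TAGID", "HOUR", "PROVINCE", "MEDIA" };
    private static final int[] COLUMN = { 0, 1, 3, 6 };

    public static HashSet<Integer> hash(String line, int D) {
        String[] value = line.split(",");
        HashSet<Integer> set = new HashSet<Integer>();
        for (int i = 0; i < FEATURE.length; i++) {
            // hash trick: "FEATURENAME_rawvalue" -> one of D buckets
            set.add(Math.abs((FEATURE[i] + "_" + value[COLUMN[i]]).hashCode()) % D);
        }
        return set;
    }

    public static void main(String[] args) {
        String row = "iqiyi_52_1pmSLBNz,4-11,?,HL,?,71518,iqiyi1000000000381,"
                + "12,6,32,0,60812,www.iqiyi.com,0.016689087173704896";
        System.out.println(hash(row, 100000));                // four hashed indices
        System.out.println("label = " + row.split(",")[10]);  // 0 = no click
    }
}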

Finally, the ROC curve was plotted in Python from the generated CSV file; the AUC improves noticeably over the existing logistic regression model.
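The original evaluation was done in Python. Purely as a sketch, the AUC of the FTRL column can also be computed straight from result.csv in Java, assuming the layout written above (a header line, then label,ftrl_prediction,old_prediction); ties between scores are not averaged here.

// Minimal AUC sketch (not part of the original post) using the rank-sum
// formulation: AUC = P(random positive is scored above a random negative).
import java.io.BufferedReader;
import java.io.FileReader;
import java.util.ArrayList;
import java.util.List;

public class AucFromCsv {
    public static void main(String[] args) throws Exception {
        List<double[]> rows = new ArrayList<double[]>(); // {label, ftrl prediction}
        BufferedReader br = new BufferedReader(new FileReader("D:\\tmp\\ftrl\\result\\result.csv"));
        br.readLine(); // skip the "true,ftrl,logistic" header
        String line;
        while ((line = br.readLine()) != null) {
            String[] f = line.split(",");
            rows.add(new double[] { Double.parseDouble(f[0]), Double.parseDouble(f[1]) });
        }
        br.close();

        // sort by predicted score, then sum the 1-based ranks of the positives
        rows.sort((a, b) -> Double.compare(a[1], b[1]));
        long positives = 0, negatives = 0;
        double rankSumOfPositives = 0;
        for (int i = 0; i < rows.size(); i++) {
            if (rows.get(i)[0] == 1.0) { positives++; rankSumOfPositives += i + 1; }
            else { negatives++; }
        }
        double auc = (rankSumOfPositives - positives * (positives + 1) / 2.0)
                / (double) (positives * negatives);
        System.out.println("FTRL AUC = " + auc);
    }
}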