您的位置:首页 > 编程语言 > Java开发

Spark机器学习之协同过滤算法使用-Java篇

2017-08-30 17:38 417 查看
协同过滤通常用于推荐系统,这些技术旨在填补用户和项目关联矩阵里面缺少的值。Spark目前实现基于模型的协同过滤,其中模型的用户和项目由一组小的潜在因素所描述,可用于预测缺少的值。Spark使用交替最小二乘法alternating least squares(ALS)算法来学习这些潜在因素。

1. ALS的参数
numBlocks:用户和项目将会被分区的块数,以便并行化计算(默认值为10)
rank:模型中潜在因素的数值(默认值为10)
maxIter:要运行的最大迭代次数(默认值为10)
regParam:指定的正则化参数(默认值为1.0)
implicitPrefs:是否使用隐式反馈(默认为false,使用显式反馈)
alpha:当使用隐式反馈时,用于控制偏好观察的基线置信度(默认值为1.0)
nonnegative:是否对最小二乘法使用非负约束 (默认值为false)
2. 冷启动(Cold-start)策略

当使用ALSModel进行预测时,在训练模型期间,普遍会在测试数据集中遇到用户和/或项目不存在的情况。这一般出现在以下两种情型:
在生产环境中,对于没有评级历史的新用户或项目,和未经过训练的模型(这是“冷启动问题”)
在交叉验证期间,数据被拆分成训练集和评估集。当使用Spark的CrossValidator或TrainValidationSplit中的简单随机拆分时,评估集里面的用户和/或项目不在训练集里面是非常常见的

默认地,当模型中不存在的用户和/或项目因素时,Spark在调用ALSModel.transform方法时,预测的值会是NaN。这在生产系统中可以是有用的,因为NaN表示一个新的用户或项目,因此系统可以预测作出一些回退的决定。

然而,在交叉验证期间这是不可取的,因为任何NaN预测值将导致评估指标的NaN结果(例如当使用RegressionEvaluator的时候)。这使得模型的选择变得不可能。

Spark允许用户将coldStartStrategy参数设置为”drop”,以便删除DataFrame中包含预测NaN值的任何行,然后会根据非NaN的数据计算评估指标。

注意:目前支持的冷启动策略是“nan”(默认)和“drop”,未来可能会支持其它的策略。

3. Java代码例子

本文使用Spark 2.2.0、Java 1.8版本,测试数据可以在以下链接下载:
http://files.grouplens.org/datasets/movielens/ml-100k.zip
import java.io.Serializable;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.recommendation.ALS;
import org.apache.spark.ml.recommendation.ALSModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class JavaALSExample {

public static class Rating implements Serializable {

private static final long serialVersionUID = 1L;
private int userId;
private int movieId;
private float rating;
private long timestamp;

public Rating() {
}

public Rating(int userId, int movieId, float rating, long timestamp) {
this.userId = userId;
this.movieId = movieId;
this.rating = rating;
this.timestamp = timestamp;
}

public int getUserId() {
return userId;
}

public int getMovieId() {
return movieId;
}

public float getRating() {
return rating;
}

public long getTimestamp() {
return timestamp;
}

public static Rating parseRating(String str) {
String[] fields = str.split("\\t");
if (fields.length != 4) {
throw new IllegalArgumentException("Each line must contain 4 fields");
}
int userId = Integer.parseInt(fields[0]);
int movieId = Integer.parseInt(fields[1]);
float rating = Float.parseFloat(fields[2]);
long timestamp = Long.parseLong(fields[3]);
return new Rating(userId, movieId, rating, timestamp);
}
}

public static void main(String[] args) {
// 测试数据文件路径
String path = "ml-100k/u.data";
// 使用本地所有可用线程local[*]
SparkSession spark = SparkSession.builder().master("local[*]").appName("JavaALSExample").getOrCreate();
JavaRDD<Rating> ratingsRDD = spark.read().textFile(path).javaRDD().map(Rating::parseRating);
Dataset<Row> ratings = spark.createDataFrame(ratingsRDD, Rating.class);
// 按比例随机拆分数据
Dataset<Row>[] splits = ratings.randomSplit(new double[] { 0.8, 0.2 });
Dataset<Row> training = splits[0];
Dataset<Row> test = splits[1];

// 对训练数据集使用ALS算法构建建议模型
ALS als = new ALS().setMaxIter(5).setRegParam(0.01).setUserCol("userId").setItemCol("movieId")
.setRatingCol("rating");
ALSModel model = als.fit(training);

// Evaluate the model by computing the RMSE on the test data
// 通过计算均方根误差RMSE(Root Mean Squared Error)对测试数据集评估模型
// 注意下面使用冷启动策略drop,确保不会有NaN评估指标
model.setColdStartStrategy("drop");
Dataset<Row> predictions = model.transform(test);

// 打印predictions的schema
predictions.printSchema();

// predictions的schema输出
// root
// |-- movieId: integer (nullable = false)
// |-- rating: float (nullable = false)
// |-- timestamp: long (nullable = false)
// |-- userId: integer (nullable = false)
// |-- prediction: float (nullable = true)

RegressionEvaluator evaluator = new RegressionEvaluator().setMetricName("rmse").setLabelCol("rating")
.setPredictionCol("prediction");
double rmse = evaluator.evaluate(predictions);
// 打印均方根误差
System.out.println("Root-mean-square error = " + rmse);
}

}


打印均方根误差结果为:Root-mean-square error = 1.0645093959897054,这个值是越小越好,如果得出的值不符合预期,可以调整ALS的参数重新计算直到符合预期为止。然后可以分别对所有用户和项目进行建议:

// Generate top 10 movie recommendations for each user
Dataset<Row> userRecs = model.recommendForAllUsers(10);

// Generate top 10 user recommendations for each movie
Dataset<Row> movieRecs = model.recommendForAllItems(10);


* 参考Spark Collaborative Filtering官方链接:http://spark.apache.org/docs/latest/ml-collaborative-filtering.html

END O(∩_∩)O
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息