Spark机器学习之协同过滤算法使用-Java篇
2017-08-30 17:38
417 查看
协同过滤通常用于推荐系统,这些技术旨在填补用户和项目关联矩阵里面缺少的值。Spark目前实现基于模型的协同过滤,其中模型的用户和项目由一组小的潜在因素所描述,可用于预测缺少的值。Spark使用交替最小二乘法alternating least squares(ALS)算法来学习这些潜在因素。
1. ALS的参数
numBlocks:用户和项目将会被分区的块数,以便并行化计算(默认值为10)
rank:模型中潜在因素的数值(默认值为10)
maxIter:要运行的最大迭代次数(默认值为10)
regParam:指定的正则化参数(默认值为1.0)
implicitPrefs:是否使用隐式反馈(默认为false,使用显式反馈)
alpha:当使用隐式反馈时,用于控制偏好观察的基线置信度(默认值为1.0)
nonnegative:是否对最小二乘法使用非负约束 (默认值为false)
2. 冷启动(Cold-start)策略
当使用ALSModel进行预测时,在训练模型期间,普遍会在测试数据集中遇到用户和/或项目不存在的情况。这一般出现在以下两种情型:
在生产环境中,对于没有评级历史的新用户或项目,和未经过训练的模型(这是“冷启动问题”)
在交叉验证期间,数据被拆分成训练集和评估集。当使用Spark的CrossValidator或TrainValidationSplit中的简单随机拆分时,评估集里面的用户和/或项目不在训练集里面是非常常见的
默认地,当模型中不存在的用户和/或项目因素时,Spark在调用ALSModel.transform方法时,预测的值会是NaN。这在生产系统中可以是有用的,因为NaN表示一个新的用户或项目,因此系统可以预测作出一些回退的决定。
然而,在交叉验证期间这是不可取的,因为任何NaN预测值将导致评估指标的NaN结果(例如当使用RegressionEvaluator的时候)。这使得模型的选择变得不可能。
Spark允许用户将coldStartStrategy参数设置为”drop”,以便删除DataFrame中包含预测NaN值的任何行,然后会根据非NaN的数据计算评估指标。
注意:目前支持的冷启动策略是“nan”(默认)和“drop”,未来可能会支持其它的策略。
3. Java代码例子
本文使用Spark 2.2.0、Java 1.8版本,测试数据可以在以下链接下载:
http://files.grouplens.org/datasets/movielens/ml-100k.zip
打印均方根误差结果为:Root-mean-square error = 1.0645093959897054,这个值是越小越好,如果得出的值不符合预期,可以调整ALS的参数重新计算直到符合预期为止。然后可以分别对所有用户和项目进行建议:
* 参考Spark Collaborative Filtering官方链接:http://spark.apache.org/docs/latest/ml-collaborative-filtering.html
END O(∩_∩)O
1. ALS的参数
numBlocks:用户和项目将会被分区的块数,以便并行化计算(默认值为10)
rank:模型中潜在因素的数值(默认值为10)
maxIter:要运行的最大迭代次数(默认值为10)
regParam:指定的正则化参数(默认值为1.0)
implicitPrefs:是否使用隐式反馈(默认为false,使用显式反馈)
alpha:当使用隐式反馈时,用于控制偏好观察的基线置信度(默认值为1.0)
nonnegative:是否对最小二乘法使用非负约束 (默认值为false)
2. 冷启动(Cold-start)策略
当使用ALSModel进行预测时,在训练模型期间,普遍会在测试数据集中遇到用户和/或项目不存在的情况。这一般出现在以下两种情型:
在生产环境中,对于没有评级历史的新用户或项目,和未经过训练的模型(这是“冷启动问题”)
在交叉验证期间,数据被拆分成训练集和评估集。当使用Spark的CrossValidator或TrainValidationSplit中的简单随机拆分时,评估集里面的用户和/或项目不在训练集里面是非常常见的
默认地,当模型中不存在的用户和/或项目因素时,Spark在调用ALSModel.transform方法时,预测的值会是NaN。这在生产系统中可以是有用的,因为NaN表示一个新的用户或项目,因此系统可以预测作出一些回退的决定。
然而,在交叉验证期间这是不可取的,因为任何NaN预测值将导致评估指标的NaN结果(例如当使用RegressionEvaluator的时候)。这使得模型的选择变得不可能。
Spark允许用户将coldStartStrategy参数设置为”drop”,以便删除DataFrame中包含预测NaN值的任何行,然后会根据非NaN的数据计算评估指标。
注意:目前支持的冷启动策略是“nan”(默认)和“drop”,未来可能会支持其它的策略。
3. Java代码例子
本文使用Spark 2.2.0、Java 1.8版本,测试数据可以在以下链接下载:
http://files.grouplens.org/datasets/movielens/ml-100k.zip
import java.io.Serializable; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.ml.evaluation.RegressionEvaluator; import org.apache.spark.ml.recommendation.ALS; import org.apache.spark.ml.recommendation.ALSModel; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SparkSession; public class JavaALSExample { public static class Rating implements Serializable { private static final long serialVersionUID = 1L; private int userId; private int movieId; private float rating; private long timestamp; public Rating() { } public Rating(int userId, int movieId, float rating, long timestamp) { this.userId = userId; this.movieId = movieId; this.rating = rating; this.timestamp = timestamp; } public int getUserId() { return userId; } public int getMovieId() { return movieId; } public float getRating() { return rating; } public long getTimestamp() { return timestamp; } public static Rating parseRating(String str) { String[] fields = str.split("\\t"); if (fields.length != 4) { throw new IllegalArgumentException("Each line must contain 4 fields"); } int userId = Integer.parseInt(fields[0]); int movieId = Integer.parseInt(fields[1]); float rating = Float.parseFloat(fields[2]); long timestamp = Long.parseLong(fields[3]); return new Rating(userId, movieId, rating, timestamp); } } public static void main(String[] args) { // 测试数据文件路径 String path = "ml-100k/u.data"; // 使用本地所有可用线程local[*] SparkSession spark = SparkSession.builder().master("local[*]").appName("JavaALSExample").getOrCreate(); JavaRDD<Rating> ratingsRDD = spark.read().textFile(path).javaRDD().map(Rating::parseRating); Dataset<Row> ratings = spark.createDataFrame(ratingsRDD, Rating.class); // 按比例随机拆分数据 Dataset<Row>[] splits = ratings.randomSplit(new double[] { 0.8, 0.2 }); Dataset<Row> training = splits[0]; Dataset<Row> test = splits[1]; // 对训练数据集使用ALS算法构建建议模型 ALS als = new ALS().setMaxIter(5).setRegParam(0.01).setUserCol("userId").setItemCol("movieId") .setRatingCol("rating"); ALSModel model = als.fit(training); // Evaluate the model by computing the RMSE on the test data // 通过计算均方根误差RMSE(Root Mean Squared Error)对测试数据集评估模型 // 注意下面使用冷启动策略drop,确保不会有NaN评估指标 model.setColdStartStrategy("drop"); Dataset<Row> predictions = model.transform(test); // 打印predictions的schema predictions.printSchema(); // predictions的schema输出 // root // |-- movieId: integer (nullable = false) // |-- rating: float (nullable = false) // |-- timestamp: long (nullable = false) // |-- userId: integer (nullable = false) // |-- prediction: float (nullable = true) RegressionEvaluator evaluator = new RegressionEvaluator().setMetricName("rmse").setLabelCol("rating") .setPredictionCol("prediction"); double rmse = evaluator.evaluate(predictions); // 打印均方根误差 System.out.println("Root-mean-square error = " + rmse); } }
打印均方根误差结果为:Root-mean-square error = 1.0645093959897054,这个值是越小越好,如果得出的值不符合预期,可以调整ALS的参数重新计算直到符合预期为止。然后可以分别对所有用户和项目进行建议:
// Generate top 10 movie recommendations for each user Dataset<Row> userRecs = model.recommendForAllUsers(10); // Generate top 10 user recommendations for each movie Dataset<Row> movieRecs = model.recommendForAllItems(10);
* 参考Spark Collaborative Filtering官方链接:http://spark.apache.org/docs/latest/ml-collaborative-filtering.html
END O(∩_∩)O
相关文章推荐
- JAVA 使用哈希表操作数据库的例子 Using Hashtables to Store & Extract results from a Database.
- 基于Java字符编码的使用详解
- Java Q&A: 使用Observer模式
- sun在线教材之-java 2d 文本指南-第一课 使用字体
- 使用 .NET实现JavaTM Pet Store J2EETM 蓝图应用程序
- 使用JSP + JAVABEAN + XML 开发的一个例子
- 在xmlspy中使用java的xslt转换
- 使用Java读取Excel文件内容
- 如何使用Java POI生成Excel表文件 !
- JML起步---使用JML 改进你的Java程序(2)
- JML起步---使用JML 改进你的Java程序(3)
- 在ASP中使用简单Java类
- 我学习使用java的一点体会(2)
- 我学习使用java的一点体会(4)
- Java中使用DirectDraw
- 企业内部网中使用Policy文件来设置Java的安全策略
- 我学习使用java的一点体会
- VisualAge for Java使用技巧
- 如何使用Java编写NT服务
- 使用Java实现数据报通讯过程