
Spark MLlib Programming

2016-04-22 11:57

Constructing the dataset

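import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

// Each line is a comma-separated row of numbers; the last column is the label
val rawData = sc.textFile("...")
val data = rawData.map { line =>
  val row = line.split(',').map(_.toDouble)
  val featVec = Vectors.dense(row.init)  // all columns except the last
  val label = row.last                   // the last column
  LabeledPoint(label, featVec)
}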
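To see what row.init and row.last do without a Spark cluster, here is a minimal plain-Scala sketch of the same per-line parsing. Sample and parseLine are hypothetical names. Sample only stands in for MLlib's LabeledPoint so that the snippet runs without Spark.

```scala
// Sketch only: Sample is a hypothetical stand-in for MLlib's LabeledPoint,
// so this runs without Spark.
object ParseSketch {
  final case class Sample(label: Double, features: Vector[Double])

  // Same logic as the RDD map above: every column is numeric,
  // the last column is the label and the rest are the features.
  def parseLine(line: String): Sample = {
    val row = line.split(',').map(_.toDouble)
    Sample(row.last, row.init.toVector)
  }

  def main(args: Array[String]): Unit = {
    println(parseLine("2.5,3.0,1.0")) // Sample(1.0,Vector(2.5, 3.0))
  }
}
```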


Building the training, cross-validation (CV), and test sets

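// Randomly split the data into roughly 80% / 10% / 10%
val Array(trainData, cvData, testData) =
  data.randomSplit(Array(0.8, 0.1, 0.1))
// Cache the splits, since each one is reused during training and evaluation
trainData.cache()
cvData.cache()
testData.cache()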
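randomSplit assigns each element to exactly one split, with probabilities proportional to the weights. The split sizes are therefore only approximately 80/10/10. The plain-Scala sketch below illustrates that idea on a local collection. weightedSplit is a hypothetical helper and is not Spark's actual randomSplit implementation.

```scala
import scala.util.Random

// Sketch only: a local illustration of a weighted random split.
// weightedSplit is a hypothetical helper, not Spark's randomSplit.
object SplitSketch {
  def weightedSplit[T](data: Seq[T], weights: Seq[Double], seed: Long): Seq[Seq[T]] = {
    require(weights.nonEmpty && weights.forall(_ >= 0) && weights.sum > 0)
    // Normalized cumulative bounds, e.g. (0.8, 0.1, 0.1) -> (0.8, 0.9, 1.0)
    val total = weights.sum
    val bounds = weights.scanLeft(0.0)(_ + _).tail.map(_ / total)
    val rng = new Random(seed)
    // Each element draws one uniform number and lands in exactly one split;
    // the fallback to the last split guards against rounding in the bounds.
    val tagged = data.map { x =>
      val u = rng.nextDouble()
      val i = bounds.indexWhere(u < _)
      (if (i >= 0) i else bounds.length - 1, x)
    }
    bounds.indices.map(i => tagged.filter(_._1 == i).map(_._2))
  }

  def main(args: Array[String]): Unit = {
    val parts = weightedSplit(1 to 1000, Seq(0.8, 0.1, 0.1), seed = 42L)
    println(parts.map(_.size)) // sizes close to 800 / 100 / 100, not exact
  }
}
```

Because the same seed reproduces the same split, fixing it makes experiments repeatable.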


Model training and evaluation (metrics)

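MulticlassMetrics (org.apache.spark.mllib.evaluation): metrics for multiclass classifiers, such as the confusion matrix and per-class precision and recall

BinaryClassificationMetrics (org.apache.spark.mllib.evaluation): metrics for binary classifiers, such as the ROC curve and the area under it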
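BinaryClassificationMetrics is built from (score, label) pairs and reports, among other things, the area under the ROC curve (areaUnderROC). The sketch below illustrates what that number means by computing it locally with the equivalent pairwise definition. It takes the fraction of (positive, negative) pairs in which the positive example receives the higher score, and counts ties as one half. AucSketch is a hypothetical name, and labels are assumed to be 0.0 or 1.0.

```scala
// Sketch only: area under the ROC curve from (score, label) pairs,
// computed with the pairwise definition; labels are assumed to be 0.0 or 1.0.
object AucSketch {
  def areaUnderROC(scoreAndLabels: Seq[(Double, Double)]): Double = {
    val pos = scoreAndLabels.collect { case (s, l) if l == 1.0 => s }
    val neg = scoreAndLabels.collect { case (s, l) if l == 0.0 => s }
    require(pos.nonEmpty && neg.nonEmpty, "need at least one example of each class")
    // 1 if the positive is ranked above the negative, 0.5 on a tie, else 0
    val pairScores = for (p <- pos; n <- neg) yield {
      if (p > n) 1.0 else if (p == n) 0.5 else 0.0
    }
    pairScores.sum / (pos.size.toDouble * neg.size)
  }

  def main(args: Array[String]): Unit = {
    val scored = Seq((0.9, 1.0), (0.4, 1.0), (0.5, 0.0), (0.1, 0.0))
    println(areaUnderROC(scored)) // 0.75: 3 of the 4 pairs are ordered correctly
  }
}
```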

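import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.model.DecisionTreeModel
import org.apache.spark.rdd.RDD

def getMetrics(model: DecisionTreeModel, data: RDD[LabeledPoint]): MulticlassMetrics = {
  val predsAndLabels = data.map(sample =>
    (model.predict(sample.features), sample.label))
  new MulticlassMetrics(predsAndLabels)
}

// numClasses: Int, labels 0 .. numClasses - 1; maxDepth = 4 and maxBins = 100 are example values
val model = DecisionTree.trainClassifier(trainData, numClasses, Map[Int, Int](), "gini", 4, 100)
val metrics = getMetrics(model, cvData)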
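MulticlassMetrics works on (prediction, label) pairs like the ones built in getMetrics. The plain-Scala sketch below computes two of the quantities it derives on a local Seq. The first is the confusion matrix, with rows for the true label and columns for the predicted label. The second is overall accuracy. MetricsSketch and its methods are hypothetical names, and labels are assumed to be 0, 1, ..., numClasses - 1.

```scala
// Sketch only: two quantities a multiclass metrics object derives from
// (prediction, label) pairs, computed locally. Labels are assumed to be
// 0, 1, ..., numClasses - 1.
object MetricsSketch {
  // confusion(i)(j) = number of samples whose true label is i and prediction is j
  def confusionMatrix(predsAndLabels: Seq[(Double, Double)], numClasses: Int): Array[Array[Long]] = {
    val m = Array.fill(numClasses, numClasses)(0L)
    predsAndLabels.foreach { case (pred, label) => m(label.toInt)(pred.toInt) += 1 }
    m
  }

  // Fraction of samples predicted correctly (the diagonal of the matrix)
  def accuracy(predsAndLabels: Seq[(Double, Double)]): Double =
    predsAndLabels.count { case (p, l) => p == l }.toDouble / predsAndLabels.size

  def main(args: Array[String]): Unit = {
    val pl = Seq((0.0, 0.0), (1.0, 1.0), (1.0, 0.0), (2.0, 2.0))
    println(accuracy(pl)) // 0.75
    confusionMatrix(pl, 3).foreach(row => println(row.mkString(" ")))
  }
}
```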


Computing the class distribution of a dataset

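import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD

// Fraction of samples in each class, ordered by label value
def classProb(data: RDD[LabeledPoint]): Array[Double] = {
  val countsByCategory = data.map(_.label).countByValue()
  val counts = countsByCategory.toArray.sortBy(_._1).map(_._2)
  val total = counts.sum.toDouble
  counts.map(_.toDouble / total)
}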
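The original does not spell this out, but one common use of these class proportions is a random-guessing baseline. Suppose a classifier guesses class i with its training-set frequency p_train(i). Its expected accuracy on the CV set is then the sum over classes of p_train(i) * p_cv(i). A local sketch of both steps follows. The object and method names are hypothetical, and the zip assumes both sets contain the same classes.

```scala
// Sketch only: local versions of classProb and a random-guessing baseline.
object ClassProbSketch {
  // Proportion of each label, ordered by label value (like classProb above)
  def classProb(labels: Seq[Double]): Seq[Double] = {
    val counts = labels.groupBy(identity).toSeq.sortBy(_._1).map(_._2.size)
    val total = counts.sum.toDouble
    counts.map(_ / total)
  }

  // Expected accuracy of guessing class i with probability trainProbs(i)
  // when the true class distribution is cvProbs. Assumes both sets contain
  // the same classes, so the two sequences line up index by index.
  def randomGuessAccuracy(trainProbs: Seq[Double], cvProbs: Seq[Double]): Double =
    trainProbs.zip(cvProbs).map { case (a, b) => a * b }.sum

  def main(args: Array[String]): Unit = {
    val train = classProb(Seq(0.0, 1.0, 1.0, 1.0)) // 0.25, 0.75
    val cv = classProb(Seq(0.0, 0.0, 1.0, 1.0))    // 0.5, 0.5
    println(randomGuessAccuracy(train, cv))         // 0.5
  }
}
```

A trained model is only useful if its CV accuracy clearly beats this baseline.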


Choosing hyperparameters (evaluated on the CV set)

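import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.tree.DecisionTree

// Grid search: train on trainData and measure accuracy on cvData for each combination
val evaluations =
  for (impurity <- Array("gini", "entropy");
       depth    <- Array(1, 20);
       bins     <- Array(10, 300)) yield {
    val model = DecisionTree.trainClassifier(
      trainData, numClasses, Map[Int, Int](), impurity, depth, bins)
    val predsAndLabels = cvData.map(sample =>
      (model.predict(sample.features), sample.label))
    // accuracy requires Spark 2.0+; on Spark 1.x use .precision instead
    val accuracy = new MulticlassMetrics(predsAndLabels).accuracy
    ((impurity, depth, bins), accuracy)
  }

// Print the combinations from highest to lowest CV accuracy
evaluations.sortBy(_._2).reverse.foreach(println)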
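The impurity parameter searched above selects the split criterion for the tree. For class proportions p_i at a node, Gini impurity is 1 - sum p_i^2 and entropy is -sum p_i log p_i. Both are 0 for a pure node. The sketch below computes them from per-class counts. The helper names are hypothetical, and a base-2 logarithm is assumed for entropy.

```scala
// Sketch only: the two impurity measures offered by the "impurity" parameter,
// computed from per-class sample counts at one tree node.
object ImpuritySketch {
  // Gini impurity: 1 - sum_i p_i^2
  def gini(counts: Seq[Double]): Double = {
    val total = counts.sum
    1.0 - counts.map { c => val p = c / total; p * p }.sum
  }

  // Entropy: -sum_i p_i * log2(p_i); empty classes are skipped (0 * log 0 = 0)
  def entropy(counts: Seq[Double]): Double = {
    val total = counts.sum
    counts.filter(_ > 0).map { c =>
      val p = c / total
      -p * math.log(p) / math.log(2)
    }.sum
  }

  def main(args: Array[String]): Unit = {
    println(gini(Seq(5.0, 5.0)))     // 0.5: an evenly mixed two-class node
    println(gini(Seq(10.0, 0.0)))    // 0.0: a pure node
    println(entropy(Seq(5.0, 5.0)))  // about 1.0 for an evenly mixed two-class node
  }
}
```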