您的位置:首页 > 其它

使用Spark MLlib 完成新闻自动分类

2017-06-19 17:54 495 查看

写在前面

最近学习了一点文本挖掘相关知道,刚刚接触到一点皮毛,刚好学了点Spark,所有就找个了小例子玩了一下,算法和实现都不太难,比较适合看公式一脸蒙逼,无聊想来点实际性Demo玩一下

基本流程



如图所示为新闻自己分类的基本流程,其中主要包含以下几点

语料

分类语料库用搜狗实验室http://www.sogou.com/labs/的数据,语料库中共10个分类,总计50多万条记录,每条记录由对应的分类编号加对应关键词组成,语料库中格式如下

0,苹果 官网 苹果 宣布 ...
1,苹果 梨 香蕉 ...

其中

0 汽车
1 财经
2 IT
3 健康
4 体育
5 旅游
6 教育
7 招聘
8 文化
9 军事


TF-IDF

TF-IDF这个特征算法是比较简单的,用来简单提取特征值学习一下还是可以的,具体算法可以百度一下,Spark 官网也有介绍:中文英文博客

朴素贝叶斯分类器

朴素贝叶斯分类器主要根据贝叶斯概率公式计算事件之间的概率,基本算法原理可以参考博客

Spark 教程 英文中文

新闻数据

这里的新闻数据是用来分类的,可以从互联网上爬取,我里我自己准备了点数据,数据以JSON格式存储,格式如下

{"topicurl":"http://zzhz.zjol.com.cn/system/2017/06/08/021530999.shtml","is_topic":"0","newsid":"021530999","sub_title":"http://xinpan.zzhz.zjol.com.cn/zhhq/20170604/","pub_time":"2017-06-08 14:53","source":"","title":"点评:6月4日,杭州主城区商品房共成交69套。截至4日22:00,主城区可售房源为40325套。"}
{"topicurl":"http://zzhz.zjol.com.cn/system/2017/06/08/021530997.shtml","is_topic":"0","newsid":"021530997","sub_title":"http://xinpan.zzhz.zjol.com.cn/zhhq/20170607/","pub_time":"2017-06-08 14:49","source":"","title":"7日:主城区成交200套 余杭萧山富阳315套"}
{"topicurl":"http://zzhz.zjol.com.cn/system/2017/06/08/021530996.shtml","is_topic":"0","newsid":"021530996","sub_title":"http://xinpan.zzhz.zjol.com.cn/zhhq/20170606/","pub_time":"2017-06-08 14:49","source":"","title":"6日:主城区成交208套 余杭萧山富阳243套"}


文章预处理

这里主要是针对从网上爬过来的新闻数据进行格式转换和分词操作,分词器使用ansj_seg GitHub地址 https://github.com/NLPchina/ansj_seg

经过预处理后,新闻数据就成了一个由关键词组成的文档

主要代码

主流程代码

def main(args: Array[String]): Unit = {

//创建sparkSession
val sparkSession = SparkSession.builder
.config("spark.sql.warehouse.dir", "D:\\WorkSpace\\spark\\spark-learning\\spark-warehouse")
.master("local")
.appName("spark session example")
.getOrCreate()

val trainRdd = sparkSession.sparkContext.textFile("E:\\file\\res\\allType.txt").map(x => {
val data = x.split(",")
(data(0), data(1))
})

//IT-IDF
val trainTFDF = toTFIDF(sparkSession, trainRdd)

//标示点
var trainPoint = trainTFDF.map {
x =>
LabeledPoint(x._1.toDouble, Vectors.dense(x._3.toArray))
}
//训练模型
val model = NaiveBayes.train(trainPoint)

//保存模型数据
// model.save(sparkSession.sparkContext,"E:\\model")
// val model=NaiveBayesModel.load(sparkSession.sparkContext,"E:\\model")

//加载新闻数据
val testData = loadTestData(sparkSession, "E:\\zjol\\21531000.json")
//TF-IDF
val testDataTFIDF = toTFIDF(sparkSession, testData)
//测试分类
val res = testDataTFIDF.map({
x => {
(x._1, model.predict(Vectors.dense(x._3.toArray)))
}
})

//新闻ID,分类

res.foreach(x => println(x._1 + " " + x._2))

}


特征提取

/**
* 对RDD新闻进行TF-IDF特征计算
* @param rdd
* @return
*/
def toTFIDF(sparkSession: SparkSession, rdd: RDD[Tuple2[String, String]]) = {

val df = rdd.map(x => {
Row(x._1, x._2)
})

val schema = StructType(
Seq(
StructField("category", StringType, true)
, StructField("text", StringType, true)
)
)

//将dataRdd转成DataFrame
val srcDF = sparkSession.createDataFrame(df, schema)
srcDF.createOrReplaceTempView("news")

srcDF.select("category", "text").take(2).foreach(println)

//将分好的词按空格拆分转换为DataFrame
var tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words")
var wordsData = tokenizer.transform(srcDF)

wordsData.select("category", "text", "words").take(2).foreach(println)

val hashingTF = new HashingTF(Math.pow(2, 18).toInt)

val tfDF1 = wordsData.rdd.map(row => {
val words = row.getSeq(2)
(row.getString(0), row.getString(1), hashingTF.transform(words))
})

val tfDF = wordsData.rdd.map(row => {
val words = row.getSeq(2)
hashingTF.transform(words)
})

val idf = new IDF().fit(tfDF)
val num_idf_pairs = tfDF1.map(x => {
(x._1, x._2, idf.transform(x._3))
})

num_idf_pairs.take(10).foreach(println)

num_idf_pairs
}


数据预处理

/**
* 加载测试json新闻数据
* @param sparkSession
* @param path
* @return
*/
def loadTestData(sparkSession: SparkSession, path: String) = {
val df = sparkSession.read.json(path)
df.printSchema()
df.createOrReplaceTempView("news")

val sql = "select author,body,is_topic,keywords,newsid,pub_time,source,sub_title,title,top_title,topicurl from news"

val rdd = sparkSession.sql(sql).rdd.map(row =>
(
row.getString(4).substring(1).toLong,
row.getString(8),
getTextFromTHML(row.getString(6))
)
).filter(x => (!x._2.equals("") && !x._3.equals("") && x._3.length>200 ))

val newsRdd = rdd.map(x => {
val words = ToAnalysis.parse(x._3).getTerms
var string = ""
val size = words.size()
for (i <- 0 until size) {
string += words.get(i.toInt).getName + " "
}
(x._1.toString, string)
})

newsRdd

}
/**
* 抽取HTML中文字
* @param htmlStr
* @return
*/
def getTextFromTHML(htmlStr: String): String = {
val doc = Jsoup.parse(htmlStr)
var text1 = doc.text()
// remove extra white space
val builder = new StringBuilder(text1)
var index = 0
while ( {
builder.length > index
}) {
val tmp = builder.charAt(index)
if (Character.isSpaceChar(tmp) || Character.isWhitespace(tmp)) builder.setCharAt(index, ' ')
index += 1
}
text1 = builder.toString.replaceAll(" +", " ").trim
text1
}


结果

结果数据以文章ID加分类编号组成

21530024 7.0
21530023 6.0
21530022 7.0
21530021 3.0
21530019 7.0
21530018 8.0
21530017 5.0
21530016 3.0
21530015 3.0


21530021 这篇新闻分类为3.0(健康),新闻如下 :



计算正确率

val testRdd = sparkSession.sparkContext.textFile("E:\\file\\res\\test.txt").map(x => {
val data = x.split(",")
(data(0), data(1))
})
//IT-IDF
val testrainTFDF = toTFIDF(sparkSession, testRdd)
//测试分类
val res = testrainTFDF.map({
x => {
(x._1, model.predict(Vectors.dense(x._3.toArray)))
}
})
//新闻ID,分类
res.foreach(x => println(x._1 + " " + x._2))
//新闻总数
val allCount=res.count()
//分类正确数量
val find=res.filter(x=>x._1.toDouble.equals(x._2));
find.foreach(x=>println(x._1+" "+x._2))
//8856 11533
println(find.count()+" "+allCount)


正确率为 76.9%

训练数据及测试数据 链接:http://pan.baidu.com/s/1skKR1GL 密码:qgn2
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息