tonglin0325的个人主页

Spark学习笔记——基于MLlib的机器学习

使用MLlib库中的机器学习算法对垃圾邮件进行分类

分类的垃圾邮件的如图中分成4个文件夹,两个文件夹是训练集合,两个文件夹是测试集合

build.sbt文件

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
name := "spark-first"

version := "1.0"

scalaVersion := "2.11.8"

libraryDependencies ++= Seq(
"org.apache.spark" % "spark-core_2.11" % "2.1.0",
"org.apache.hadoop" % "hadoop-common" % "2.7.2",
"mysql" % "mysql-connector-java" % "5.1.31",
"org.apache.spark" %% "spark-sql" % "2.1.0",
"org.apache.spark" %% "spark-streaming" % "2.1.0",
"org.apache.spark" % "spark-mllib_2.11" % "2.1.0"
)

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import org.apache.hadoop.io.{IntWritable, LongWritable, MapWritable, Text}
import org.apache.spark.SparkContext
import org.apache.spark.SparkConf
import org.apache.spark._
import org.apache.hadoop.mapreduce.Job
import org.apache.hadoop.mapreduce.lib.input.KeyValueTextInputFormat
import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat
import org.apache.spark.sql.SQLContext
import java.util.Properties

import org.apache.spark.streaming.StreamingContext
import org.apache.spark.streaming.StreamingContext._
import org.apache.spark.streaming.dstream.DStream
import org.apache.spark.streaming.Duration
import org.apache.spark.streaming.Seconds

import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.feature.HashingTF
import org.apache.spark.mllib.classification.LogisticRegressionWithSGD

/**
* Created by common on 17-4-6.
*/
object SparkRDD {

def main(args: Array[String]) {
val conf = new SparkConf().setAppName("WordCount").setMaster("local")
val sc = new SparkContext(conf)

val spam = sc.textFile("input/email/spam")
val normal = sc.textFile("input/email/ham")

// 创建一个HashingTF实例来把邮件文本映射为包含10000个特征的向量
val tf = new HashingTF(numFeatures = 10000)
// 各邮件都被切分为单词,每个单词被映射为一个特征
val spamFeatures = spam.map(email => tf.transform(email.split(" ")))
val normalFeatures = normal.map(email => tf.transform(email.split(" ")))
// 创建LabeledPoint数据集分别存放阳性(垃圾邮件)和阴性(正常邮件)的例子
val positiveExamples = spamFeatures.map(features => LabeledPoint(1, features))
val negativeExamples = normalFeatures.map(features => LabeledPoint(0, features))
val trainingData = positiveExamples.union(negativeExamples)
trainingData.cache() // 因为逻辑回归是迭代算法,所以缓存训练数据RDD
// 使用SGD算法运行逻辑回归
val model = new LogisticRegressionWithSGD().run(trainingData)
// 以阳性(垃圾邮件)和阴性(正常邮件)的例子分别进行测试
val posTest = tf.transform(
"Experience with BiggerPenis Today! Grow 3-inches more ...".split(" "))
val negTest = tf.transform(
"That is cold. Is there going to be a retirement party? ...".split(" "))
println("Prediction for positive test example: " + model.predict(posTest))
println("Prediction for negative test example: " + model.predict(negTest))

}
}

 

结果