Spark mllib 决策树

栏目: 服务器 · 发布时间: 5年前

package com.immooc.spark

import org.apache.log4j.{Level, Logger}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.util.MLUtils

object DecisionTreeTest {
  def main(args:Array[String]): Unit = {


    val conf = new SparkConf().setAppName("DecisionTreeTest").setMaster("local[2]")
    val sc = new SparkContext(conf)

    Logger.getRootLogger.setLevel(Level.WARN)

    // 读取样本数据1,格式为LIBSVM format
    val data = sc.textFile("file:///Users/walle/Documents/D3/sparkmlib/data.txt")
    val parsedData = data.map{ line =>
      val parts = line.split(',')
      LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble)))
    }


    //样本数据划分训练样本与测试样本
    val splits = parsedData.randomSplit(Array(0.7, 0.3), seed = 11L)
    val training = splits(0).cache()
    val test = splits(1)

     val numClasses = 2
     val categoricalFeaturesInfo = Map[Int, Int]()
     val impurity = "gini"
     val maxDepth = 5
     val maxBins = 32

     val model = DecisionTree.trainClassifier(training, numClasses, categoricalFeaturesInfo, impurity, maxDepth, maxBins)


    //模型预测
    val labelAndPreds = test.map { point =>
      val prediction = model.predict(point.features)
      (point.label, prediction)
    }

    //测试值与真实值对比
    val print_predict = labelAndPreds.take(15)
    println("label" + "\t" + "prediction")
    for (i <- 0 to print_predict.length - 1) {
      println(print_predict(i)._1 + "\t" + print_predict(i)._2)
    }

    //树的错误率
    val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / test.count()
    println("Test Error = " + testErr)
    //打印树的判断值
    println("Learned classification tree model:\n" + model.toDebugString)

  }
}

1. 数据

0,32 1 1 0
0,25 1 2 0
1,29 1 2 1
1,24 1 1 0
0,31 1 1 0
1,35 1 2 1
0,30 0 1 0
0,31 1 1 0
1,30 1 2 1
1,21 1 1 0
0,21 1 2 0
1,21 1 2 1
0,29 0 2 1
0,29 1 0 1
0,29 0 2 1
1,30 1 1 0

2. 结果

label	prediction
1.0	1.0
1.0	1.0
1.0	0.0
0.0	1.0
0.0	0.0
Test Error = 0.4
Learned classification tree model:
DecisionTreeModel classifier of depth 5 with 11 nodes
  If (feature 0 <= 33.5)
   If (feature 0 <= 30.5)
    If (feature 1 <= 0.5)
     Predict: 0.0
    Else (feature 1 > 0.5)
     If (feature 0 <= 27.0)
      If (feature 2 <= 1.5)
       Predict: 1.0
      Else (feature 2 > 1.5)
       Predict: 0.0
     Else (feature 0 > 27.0)
      Predict: 1.0
   Else (feature 0 > 30.5)
    Predict: 0.0
  Else (feature 0 > 33.5)
   Predict: 1.0
4691

以上所述就是小编给大家介绍的《Spark mllib 决策树》,希望对大家有所帮助,如果大家有任何疑问请给我留言,小编会及时回复大家的。在此也非常感谢大家对 码农网 的支持!

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

Maven实战

Maven实战

许晓斌 / 机械工业出版社 / 2010年12月 / 65.00元

你是否早已厌倦了日复一日的手工构建工作?你是否对各个项目风格迥异的构建系统感到恐惧?Maven——这一Java社区事实标准的项目管理工具,能帮你从琐碎的手工劳动中解脱出来,帮你规范整个组织的构建系统。不仅如此,它还有依赖管理、自动生成项目站点等超酷的特性,已经有无数的开源项目使用它来构建项目并促进团队交流,每天都有数以万计的开发者在访问中央仓库以获取他们需要的依赖。 本书内容全面而系统,Ma......一起来看看 《Maven实战》 这本书的介绍吧!

html转js在线工具
html转js在线工具

html转js在线工具

RGB CMYK 转换工具
RGB CMYK 转换工具

RGB CMYK 互转工具

HSV CMYK 转换工具
HSV CMYK 转换工具

HSV CMYK互换工具