多分类实现方式介绍和在 Spark 上实现多分类逻辑回归

栏目: 编程工具 · 发布时间: 5年前

内容简介:01—

多分类实现方式介绍和在 Spark 上实现多分类逻辑回归

01

背景

在之前的文章中介绍了多分类逻辑回归算法的数据原理,参考文章链接

  • CSDN文章链接:

https://blog.csdn.net/Gamer_gyt/article/details/85209496

  • 公众号:

        多分类逻辑回归

该篇文章介绍一下Spark中多分类算法,主要包括的技术点如下:

  • 多分类实现方式

一对一 (One V One)

一对其余(One V Remaining)

多对多 (More V More)

  • Spark中的多分类实现

02

多分类实现方式

一对一

假设某个分类中有N个类别,将这N个类别两两配对(继而转化为二分类问题),这样可以得到 N(N-1)/ 2个二分类器,这样训练模型时需要训练 N(N-1)/ 2个模型,预测时将样本输送到这些模型中,最终统计出现次数较多的类别结果作为最终类别。

假设现在有三个类别:类别A,类别B,类别C,类别D。一对一实现多分类如下图所示:

多分类实现方式介绍和在 Spark 上实现多分类逻辑回归

一对多

一对多,即每次把一个类别看做是正类,其余看做负类,此时假设有N个类别,则对应N个分类器,预测时

  • 若只有一个分类器将样本预测为正类,则结果为正类

  • 若只有一个分类器将样本预测为负类,则结果为负类

  • 若预测结果有正类或者负类个数不唯一,则根据概率最大对应的结果作为最终结果

一对多实现多分类如下图所示:

多分类实现方式介绍和在 Spark 上实现多分类逻辑回归

多对多

将多个类别作为正类,将多个类别作为负类。显然正反类构造必须有特殊的设计,不能随意选取,在周志华老师的西瓜书中提到了“纠错输出码技术(EOOC)”

EOOC工作主要分为两步:

  • 编码:对N个类别分别做M次划分,每次划分将一部分类别划为正类,一部分划分为负类,从而形成一个二分类分类器,这样一共产生M个训练集,训练出M个分类器

  • 解码:M个分类器分别对测试样本进行预测,这些预测标记组成一个编码,将这个预测编码分别与每个类别各自的编码进行比较,返回其中距离最小的类别作为最终预测结果

多分类实现方式介绍和在 Spark 上实现多分类逻辑回归

上图(a)中,C1类别经过5个分类器后得到的编码为[-1,+1,-1,+1,+1],测试示例经过5个分类器后的编码为[-1,-1,+1,-1,+1],两个编码对比,有三个对应位置不一样,所以海明距离为3,同理可求得测试样例与其他类别的海明距离和欧式距离。

上图(b)中,比图(a)多了0类,即停用类。在计算海明距离时,停用类和测试示例的距离为0.5(笔者认为这里的参数可以进行动态的调整),欧式距离就是正常的当做0值操作。

03

OneVsRest介绍

OneVsRest将一个给定的二分类算法有效地扩展到多分类问题应用中,也叫做”One-vs-All”算法。OneVsRest是一个Estimator(评估器)。它采用一个基础的Classifier然后对于k个类别分别创建二分类问题。类别i的二分类分类器用来预测类别为i还是不为i,即将i类和其他类别区分开来。最后,通过依次对k个二分类分类器进行评估,取置信最高的分类器的标签作为i类别的标签。

对应多 “分类实现方式” 中的一对多

04

Spark中的多分类实现

基于ml包中的LogisticRegression实现

import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator

object MultiClassLR {
   def main(args: Array[String]): Unit = {
      val input_data = "data/sample_multiclass_classification_data.txt" // args(0)

      val spark = SparkSession.builder.master("local[5]").appName("MulticlassLRWithElasticNetExample")
         .getOrCreate()

         runBaseLR(spark,input_data)
      runBaseOneVsRest(spark,input_data)
      spark.stop()
   }

   def runBaseLR(spark: SparkSession, input_data: String): Unit = {
      // 加载训练数据集
      val split = spark.read.format("libsvm").load(input_data).randomSplit(Array(1,1))
      val train_data = split(0)
      val test_data = split(1)
      // 创建模型
      val lr = new LogisticRegression().setMaxIter(20).setRegParam(0.3).setElasticNetParam(0.8)
      // 训练模型
      val model = lr.fit(train_data)
      // 系数矩阵、截距向量
      println(s"coefficientMatrix is: \n  ${model.coefficientMatrix}")
      println(s"interceptVector is: \n  ${model.interceptVector}")

      // 测试集计算
      val predictions = model.transform(test_data)
      val test_count = test_data.count().toInt
      predictions.take(test_count).foreach(println)
      val evaluator = new MulticlassClassificationEvaluator()//.setLabelCol("label").setPredictionCol("prediction")
      val accuracy =evaluator.setMetricName("accuracy").evaluate(predictions);
      val weightedPrecision=evaluator.setMetricName("weightedPrecision").evaluate(predictions);
      val weightedRecall=evaluator.setMetricName("weightedRecall").evaluate(predictions);
      val f1=evaluator.setMetricName("f1").evaluate(predictions);

      println(s"accuracy is $accuracy")
      println(s"weightedPrecision is $weightedPrecision")
      println(s"weightedRecall is $weightedRecall")
      println(s"f1 is $f1")
   }
   def runBaseOneVsRest(spark: SparkSession, input_data: String): Unit = {
       // 加载训练数据集
      val split = spark.read.format("libsvm").load(input_data).randomSplit(Array(1,1))
      val train_data = split(0)
      val test_data = split(1)
      // 创建模型
      val lr = new LogisticRegression().setMaxIter(10).setRegParam(0.3).setElasticNetParam(0.8)
      // 训练模型
      val model = new OneVsRest().setClassifier(lr).fit(train_data)
         
      // 测试集计算
      val predictions = model.transform(test_data)
      val test_count = test_data.count().toInt
      predictions.take(test_count).foreach(println)
      val evaluator = new MulticlassClassificationEvaluator()//.setLabelCol("label").setPredictionCol("prediction")
      val accuracy =evaluator.setMetricName("accuracy").evaluate(predictions);
      val weightedPrecision=evaluator.setMetricName("weightedPrecision").evaluate(predictions);
      val weightedRecall=evaluator.setMetricName("weightedRecall").evaluate(predictions);
      val f1=evaluator.setMetricName("f1").evaluate(predictions);

      println(s"accuracy is $accuracy")
      println(s"weightedPrecision is $weightedPrecision")
      println(s"weightedRecall is $weightedRecall")
      println(s"f1 is $f1")
   }
}

运行输出信息(runBaseLR)

[0.0,(4,[0,1,2,3],[-0.666667,-0.583333,0.186441,0.333333]),[0.142419333934195,-0.3772619583435227,0.06140018515891296],[0.3973163522554911,0.23628803028122905,0.3663956174632798],0.0]
[0.0,(4,[0,1,2,3],[-0.277778,-0.333333,0.322034,0.583333]),[0.21570018327413792,-0.5776524462730686,0.06140018515891296],[0.433024032357093,0.19586792869116354,0.3711080389517435],0.0]
[0.0,(4,[0,1,2,3],[-0.222222,-0.583333,0.355932,0.583333]),[0.21570018327413792,-0.6014274236583748,0.06140018515891296],[0.43502594980545906,0.19215033820182928,0.3728237119927117],0.0]
[0.0,(4,[0,1,2,3],[-0.166667,-0.416667,0.38983,0.5]),[0.19127333120195608,-0.5901059157567916,0.06140018515891296],[0.42808566999671693,0.19596657170058288,0.37594775830270016],0.0]
...
accuracy is 0.8615384615384616
weightedPrecision is 0.9017369727047146
weightedRecall is 0.8615384615384616
f1 is 0.8554924320962056
...

运行输出信息(runBaseOneVsRest)

[0.0,(4,[0,1,2,3],[-0.666667,-0.583333,0.186441,0.333333]),0.0]
[0.0,(4,[0,1,2,3],[-0.222222,-0.583333,0.355932,0.583333]),0.0]
[0.0,(4,[0,1,2,3],[-0.111111,-0.166667,0.38983,0.416667]),0.0]
...
accuracy is 0.8051948051948052
weightedPrecision is 0.8539693389317449
weightedRecall is 0.8051948051948052
f1 is 0.7797931797931799
...

以上所述就是小编给大家介绍的《多分类实现方式介绍和在 Spark 上实现多分类逻辑回归》,希望对大家有所帮助,如果大家有任何疑问请给我留言,小编会及时回复大家的。在此也非常感谢大家对 码农网 的支持!

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

软件测试的艺术

软件测试的艺术

梅尔斯 / 机械工业出版社 / 2006年01月 / 22.0

《软件测试的艺术》(原书第2版)成功、有效地进行软件测试的实用策略和技术:    基本的测试原理和策略      验收测试    程序检查和走查         安装测试    代码检查            模块(单元)测试    错误列表            测试规划与控制    同行评分            独立测试机构    黑盒、白盒测试    ......一起来看看 《软件测试的艺术》 这本书的介绍吧!

Base64 编码/解码
Base64 编码/解码

Base64 编码/解码

XML、JSON 在线转换
XML、JSON 在线转换

在线XML、JSON转换工具

HSV CMYK 转换工具
HSV CMYK 转换工具

HSV CMYK互换工具