多分类实现方式介绍和在 Spark 上实现多分类逻辑回归

栏目: 编程工具 · 发布时间: 1年前

来源: mp.weixin.qq.com

内容简介:01—

本文转载自:https://mp.weixin.qq.com/s/cMuift2-svr0_i0qtUsG6A,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有。

多分类实现方式介绍和在 Spark 上实现多分类逻辑回归

01

背景

在之前的文章中介绍了多分类逻辑回归算法的数据原理,参考文章链接

  • CSDN文章链接:

https://blog.csdn.net/Gamer_gyt/article/details/85209496

  • 公众号:

        多分类逻辑回归

该篇文章介绍一下Spark中多分类算法,主要包括的技术点如下:

  • 多分类实现方式

一对一 (One V One)

一对其余(One V Remaining)

多对多 (More V More)

  • Spark中的多分类实现

02

多分类实现方式

一对一

假设某个分类中有N个类别,将这N个类别两两配对(继而转化为二分类问题),这样可以得到 N(N-1)/ 2个二分类器,这样训练模型时需要训练 N(N-1)/ 2个模型,预测时将样本输送到这些模型中,最终统计出现次数较多的类别结果作为最终类别。

假设现在有三个类别:类别A,类别B,类别C,类别D。一对一实现多分类如下图所示:

多分类实现方式介绍和在 Spark 上实现多分类逻辑回归

一对多

一对多,即每次把一个类别看做是正类,其余看做负类,此时假设有N个类别,则对应N个分类器,预测时

  • 若只有一个分类器将样本预测为正类,则结果为正类

  • 若只有一个分类器将样本预测为负类,则结果为负类

  • 若预测结果有正类或者负类个数不唯一,则根据概率最大对应的结果作为最终结果

一对多实现多分类如下图所示:

多分类实现方式介绍和在 Spark 上实现多分类逻辑回归

多对多

将多个类别作为正类,将多个类别作为负类。显然正反类构造必须有特殊的设计,不能随意选取,在周志华老师的西瓜书中提到了“纠错输出码技术(EOOC)”

EOOC工作主要分为两步:

  • 编码:对N个类别分别做M次划分,每次划分将一部分类别划为正类,一部分划分为负类,从而形成一个二分类分类器,这样一共产生M个训练集,训练出M个分类器

  • 解码:M个分类器分别对测试样本进行预测,这些预测标记组成一个编码,将这个预测编码分别与每个类别各自的编码进行比较,返回其中距离最小的类别作为最终预测结果

多分类实现方式介绍和在 Spark 上实现多分类逻辑回归

上图(a)中,C1类别经过5个分类器后得到的编码为[-1,+1,-1,+1,+1],测试示例经过5个分类器后的编码为[-1,-1,+1,-1,+1],两个编码对比,有三个对应位置不一样,所以海明距离为3,同理可求得测试样例与其他类别的海明距离和欧式距离。

上图(b)中,比图(a)多了0类,即停用类。在计算海明距离时,停用类和测试示例的距离为0.5(笔者认为这里的参数可以进行动态的调整),欧式距离就是正常的当做0值操作。

03

OneVsRest介绍

OneVsRest将一个给定的二分类算法有效地扩展到多分类问题应用中,也叫做”One-vs-All”算法。OneVsRest是一个Estimator(评估器)。它采用一个基础的Classifier然后对于k个类别分别创建二分类问题。类别i的二分类分类器用来预测类别为i还是不为i,即将i类和其他类别区分开来。最后,通过依次对k个二分类分类器进行评估,取置信最高的分类器的标签作为i类别的标签。

对应多 “分类实现方式” 中的一对多

04

Spark中的多分类实现

基于ml包中的LogisticRegression实现

import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator

object MultiClassLR {
   def main(args: Array[String]): Unit = {
      val input_data = "data/sample_multiclass_classification_data.txt" // args(0)

      val spark = SparkSession.builder.master("local[5]").appName("MulticlassLRWithElasticNetExample")
         .getOrCreate()

         runBaseLR(spark,input_data)
      runBaseOneVsRest(spark,input_data)
      spark.stop()
   }

   def runBaseLR(spark: SparkSession, input_data: String): Unit = {
      // 加载训练数据集
      val split = spark.read.format("libsvm").load(input_data).randomSplit(Array(1,1))
      val train_data = split(0)
      val test_data = split(1)
      // 创建模型
      val lr = new LogisticRegression().setMaxIter(20).setRegParam(0.3).setElasticNetParam(0.8)
      // 训练模型
      val model = lr.fit(train_data)
      // 系数矩阵、截距向量
      println(s"coefficientMatrix is: \n  ${model.coefficientMatrix}")
      println(s"interceptVector is: \n  ${model.interceptVector}")

      // 测试集计算
      val predictions = model.transform(test_data)
      val test_count = test_data.count().toInt
      predictions.take(test_count).foreach(println)
      val evaluator = new MulticlassClassificationEvaluator()//.setLabelCol("label").setPredictionCol("prediction")
      val accuracy =evaluator.setMetricName("accuracy").evaluate(predictions);
      val weightedPrecision=evaluator.setMetricName("weightedPrecision").evaluate(predictions);
      val weightedRecall=evaluator.setMetricName("weightedRecall").evaluate(predictions);
      val f1=evaluator.setMetricName("f1").evaluate(predictions);

      println(s"accuracy is $accuracy")
      println(s"weightedPrecision is $weightedPrecision")
      println(s"weightedRecall is $weightedRecall")
      println(s"f1 is $f1")
   }
   def runBaseOneVsRest(spark: SparkSession, input_data: String): Unit = {
       // 加载训练数据集
      val split = spark.read.format("libsvm").load(input_data).randomSplit(Array(1,1))
      val train_data = split(0)
      val test_data = split(1)
      // 创建模型
      val lr = new LogisticRegression().setMaxIter(10).setRegParam(0.3).setElasticNetParam(0.8)
      // 训练模型
      val model = new OneVsRest().setClassifier(lr).fit(train_data)
         
      // 测试集计算
      val predictions = model.transform(test_data)
      val test_count = test_data.count().toInt
      predictions.take(test_count).foreach(println)
      val evaluator = new MulticlassClassificationEvaluator()//.setLabelCol("label").setPredictionCol("prediction")
      val accuracy =evaluator.setMetricName("accuracy").evaluate(predictions);
      val weightedPrecision=evaluator.setMetricName("weightedPrecision").evaluate(predictions);
      val weightedRecall=evaluator.setMetricName("weightedRecall").evaluate(predictions);
      val f1=evaluator.setMetricName("f1").evaluate(predictions);

      println(s"accuracy is $accuracy")
      println(s"weightedPrecision is $weightedPrecision")
      println(s"weightedRecall is $weightedRecall")
      println(s"f1 is $f1")
   }
}

运行输出信息(runBaseLR)

[0.0,(4,[0,1,2,3],[-0.666667,-0.583333,0.186441,0.333333]),[0.142419333934195,-0.3772619583435227,0.06140018515891296],[0.3973163522554911,0.23628803028122905,0.3663956174632798],0.0]
[0.0,(4,[0,1,2,3],[-0.277778,-0.333333,0.322034,0.583333]),[0.21570018327413792,-0.5776524462730686,0.06140018515891296],[0.433024032357093,0.19586792869116354,0.3711080389517435],0.0]
[0.0,(4,[0,1,2,3],[-0.222222,-0.583333,0.355932,0.583333]),[0.21570018327413792,-0.6014274236583748,0.06140018515891296],[0.43502594980545906,0.19215033820182928,0.3728237119927117],0.0]
[0.0,(4,[0,1,2,3],[-0.166667,-0.416667,0.38983,0.5]),[0.19127333120195608,-0.5901059157567916,0.06140018515891296],[0.42808566999671693,0.19596657170058288,0.37594775830270016],0.0]
...
accuracy is 0.8615384615384616
weightedPrecision is 0.9017369727047146
weightedRecall is 0.8615384615384616
f1 is 0.8554924320962056
...

运行输出信息(runBaseOneVsRest)

[0.0,(4,[0,1,2,3],[-0.666667,-0.583333,0.186441,0.333333]),0.0]
[0.0,(4,[0,1,2,3],[-0.222222,-0.583333,0.355932,0.583333]),0.0]
[0.0,(4,[0,1,2,3],[-0.111111,-0.166667,0.38983,0.416667]),0.0]
...
accuracy is 0.8051948051948052
weightedPrecision is 0.8539693389317449
weightedRecall is 0.8051948051948052
f1 is 0.7797931797931799
...

以上所述就是小编给大家介绍的《多分类实现方式介绍和在 Spark 上实现多分类逻辑回归》,希望对大家有所帮助,如果大家有任何疑问请给我留言,小编会及时回复大家的。在此也非常感谢大家对 码农网 的支持!

关注码农网公众号

关注我们,获取更多IT资讯^_^


为你推荐:

相关软件推荐:

查看所有标签

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

数值方法和MATLAB实现与应用

数值方法和MATLAB实现与应用

拉克唐瓦尔德 / 机械工业出版社 / 2004-9 / 59.00元

本书是关于数值方法和MATLAB的介绍,是针对高等院校理工科专业学生编写的教材。数值方法可以用来生成其他方法无法求解的问题的近似解。本书的主要目的是为应用计算打下坚实的基础,由简单到复杂讲述了标准数值方法在实际问题中的实现和应用。本书通篇使用良好的编程习惯向读者展示了如何清楚地表达计算思想及编制文档。书中通过给读者提供大量的可直接运行的代码库以及讲解MARLAB工具箱中内置函数使用的数量方法,帮助......一起来看看 《数值方法和MATLAB实现与应用》 这本书的介绍吧!

XML 在线格式化
XML 在线格式化

在线 XML 格式化压缩工具

HEX CMYK 转换工具
HEX CMYK 转换工具

HEX CMYK 互转工具