Scipy数据函数的Scala实现

栏目: Python · 发布时间: 5年前

内容简介:最近在推进项目的时候,遇到需要将线下的Python代码转化成线上的集群代码,由于机器代码环境是Scala,所以需要将代码翻译一遍,遇到的最大问题就是数学函数。由于Python环境下有很多强大的数学统计和分析的包,所以基本上很多数学函数可在Python包中找到,而Scala相对就弱一些。本次遇到的困难主要是以下三个数学函数:该函数返回的是gamma 函数的对数,即gammaln(A) = log(gamma(A)) 。输入 A 必须是非负数和实数。gammaln 命令可避免直接使用 log(gamma(A)

最近在推进项目的时候,遇到需要将线下的 Python 代码转化成线上的集群代码,由于机器代码环境是Scala,所以需要将代码翻译一遍,遇到的最大问题就是数学函数。由于Python环境下有很多强大的数学统计和分析的包,所以基本上很多数学函数可在Python包中找到,而Scala相对就弱一些。本次遇到的困难主要是以下三个数学函数:

from scipy.special import gammaln
from scipy.special import hyp2f1
from scipy.special import logsumexp

gamma 函数的对数

该函数返回的是gamma 函数的对数,即gammaln(A) = log(gamma(A)) 。输入 A 必须是非负数和实数。gammaln 命令可避免直接使用 log(gamma(A)) 计算时可能会出现的下溢和上溢。

在scipy中其适用的是 Cpython实现 ,基于此代码,可以有的解决方案为:

  • 抽离相关的 C语言 代码,将其编程成.so文件,在Scala中直接调用.so文件中的函数
  • 阅读相关的C语言代码,理解其逻辑,将其转化成Scala代码。

除了上述解决方式外,还可以寻找是否有 现成的Scala代码

import scala.math
import scala.annotation.tailrec
import java.lang.Integer
 
// Adapted from http://www.johndcook.com/stand_alone_code.html
// All bugs are however likely my fault
class Gamma {
  //Entry points
  def gamma(x:Double): Double = {
    val v = hoboTrampoline(x,false,((y: Double) => y))
    v
  }
  def logGamma(x:Double): Double = {
    val v = hoboTrampoline(x,true,((y: Double) => y))
    v
  }
 
  //Since scala doesn't support optimizing co-recursive tail-calls
  //we manually make a trampoline and make it tail recursive
  @tailrec
  private def hoboTrampoline(x: Double, log: Boolean,todo: Double => Double): Double = {
    if (!log) {
        if (x <= 0.0)
        {
            val msg = "Invalid input argument "+x+". Argument must be positive."
            throw new IllegalArgumentException(msg);
        }
 
        // Split the function domain into three intervals:
        // (0, 0.001), [0.001, 12), and (12, infinity)
 
        ///////////////////////////////////////////////////////////////////////////
        // First interval: (0, 0.001)
        //
        // For small x, 1/Gamma(x) has power series x + gamma x^2  - ...
        // So in this range, 1/Gamma(x) = x + gamma x^2 with error on the order of x^3.
        // The relative error over this interval is less than 6e-7.
 
        val gamma: Double = 0.577215664901532860606512090; // Euler's gamma constant
        if (x < 0.001) {
            todo(1.0/(x*(1.0 + gamma*x)));
        } else if (x < 12.0) {
          ///////////////////////////////////////////////////////////////////////////
          // Second interval: [0.001, 12)
          // The algorithm directly approximates gamma over (1,2) and uses
          // reduction identities to reduce other arguments to this interval.
          val arg_was_less_than_one: Boolean = (x < 1.0);
 
          // Add or subtract integers as necessary to bring y into (1,2)
          // Will correct for this below
          val (n: Integer,y: Double) =  if (arg_was_less_than_one)
            {
              (0,x + 1.0)
            } else {
              val n: Integer = x.floor.toInt - 1;
              (n,x-n)
            }
 
          // numerator coefficients for approximation over the interval (1,2)
          val p: Array[Double] =
            Array(
              -1.71618513886549492533811E+0,
              2.47656508055759199108314E+1,
              -3.79804256470945635097577E+2,
              6.29331155312818442661052E+2,
              8.66966202790413211295064E+2,
              -3.14512729688483675254357E+4,
              -3.61444134186911729807069E+4,
              6.64561438202405440627855E+4
            );
 
          // denominator coefficients for approximation over the interval (1,2)
          val q: Array[Double] =
            Array(
              -3.08402300119738975254353E+1,
              3.15350626979604161529144E+2,
              -1.01515636749021914166146E+3,
              -3.10777167157231109440444E+3,
              2.25381184209801510330112E+4,
              4.75584627752788110767815E+3,
              -1.34659959864969306392456E+5,
              -1.15132259675553483497211E+5
            );
 
          val z: Double = y - 1;
          val num = p.foldLeft(0: Double)({(a,b) => (b+a)*z})
          val den = q.foldLeft(1: Double)({(a,b) => a*z+b})
 
          val result = num/den + 1.0;
 
          // Apply correction if argument was not initially in (1,2)
          if (arg_was_less_than_one)
            {
              // Use identity gamma(z) = gamma(z+1)/z
              // The variable "result" now holds gamma of the original y + 1
              // Thus we use y-1 to get back the orginal y.
              todo(result / (y-1.0));
            }
          else
            {
              // Use the identity gamma(z+n) = z*(z+1)* ... *(z+n-1)*gamma(z)
              todo(List.range(0,n.toInt).map(_.toDouble).foldLeft(result)((a,b) => a*(y+b)))
            }
        } else if (x <= 171.624) {
          ///////////////////////////////////////////////////////////////////////////
          // Third interval: [12, 171.624)
          hoboTrampoline(x,true,((a: Double) => todo(math.exp(a))));
        } else {
          ///////////////////////////////////////////////////////////////////////////
          // Fourth interval: [171.624, INFINITY)
          // Correct answer too large to display.
          todo(scala.Double.PositiveInfinity)
        }
    } else {
      //log implementation
      if (x <= 0.0)
        {
          val msg = "Invalid input argument "+x+". Argument must be positive."
          throw new IllegalArgumentException(msg);
        }
 
      if (x < 12.0) {
        hoboTrampoline(x,false,((a: Double) => todo(math.log(math.abs(a)))));
      } else {
 
        // Abramowitz and Stegun 6.1.41
        // Asymptotic series should be good to at least 11 or 12 figures
        // For error analysis, see Whittiker and Watson
        // A Course in Modern Analysis (1927), page 252
 
        val c: Array[Double] =
          Array(
            1.0/12.0,
            -1.0/360.0,
            1.0/1260.0,
            -1.0/1680.0,
            1.0/1188.0,
            -691.0/360360.0,
            1.0/156.0
          );
        val z: Double = 1.0/(x*x);
        val sum: Double = c.foldRight(-3617.0/122400.0: Double)({(a,b) => b*z+a});
        val series: Double = sum/x;
 
        val halfLogTwoPi: Double = 0.91893853320467274178032973640562;
        val logGamma: Double = (x - 0.5)*math.log(x) - x + halfLogTwoPi + series;
        todo(logGamma);
      }
    }
  }
 
 
}
object Gamma extends Gamma

高斯超几何函数2F1(a,b,c,d)

hyp2f1 是Scipy中 高斯超几何函数 的实现。找到了 C语言出处 ,同样也找到了 Scala实现

def hyp2f1 (a: Double, b: Double, c: Double, z: Double): Double =
    {
        val MAX_ITER = 35           // for 9 sig-digits in t-dist with 10 dof
 
        (a, b, c) match {
        case (  _, 1.0, 1.0) => (1.0 / (1.0 - z))~^(-a)
        case (0.5, 0.5, 1.5) => asin (z) / z
        case (1.0, 1.0, 2.0) => log (1.0 - z) / -z
        case (1.0, 2.0, 1.0) => 1.0 / ((1.0 - z) * (1.0 - z))
        case (1.0, 2.0, 2.0) => 1.0 / (1.0 - z)
        case _               => var sum  = 0.0
                                var prod = 1.0
                                for (k <- 0 until MAX_ITER) {
                                    sum  += prod
                                    prod *= z * ((a + k) * (b + k)) / ((c + k) * (k + 1.0))
                                } // for
                                sum
        } // match
} // hyp2f1

上述代码编译时会出错,主要原因是~^运算符不正确,正常情况下~为按位取反,^为按位异或。在这里很难解释清楚。查询了该项目,发现~^是被重新定义的为了 求幂运算符 。类似python中的**。类似实现:

implicit class PowerOp[T: Numeric](value: T) {
    import Numeric.Implicits._
    import scala.math.pow
 
    def **(power: T): Double = pow(value.toDouble(), power.toDouble())
}

指数函数的和的对数logsumexp

在讲解logsumexp函数之前我们先要了解下这个函数是用来做什么的。假设我们有N个实数 Scipy数据函数的Scala实现 ,我们想要求如下公式:

  Scipy数据函数的Scala实现

如果很大或很小,直接计算可能会上溢出或下溢出,从而导致严重问题。举个例子,对于[0 1 0],直接计算是可行的,我们可以得到1.55。但对于[1000 1001 1000],却并不可行,我们会得到inf;对于[-1000,-999,-1000],还是不行,我们会得到-inf。导致此问题的原因是因为浮点数只有64位,在计算指数函数的环节exp{1000}会发生上溢出,计算exp(-1000)时会发生下溢出。即便在数学世界上式的值显然不是无穷大,但在计算机的浮点数世界里就是求不出来。解决方案很简单:

  Scipy数据函数的Scala实现

对任意a都成立,这意味着我们可以自由地调节指数函数的指数部分,一个典型的做法是取 Scipy数据函数的Scala实现 中的最大值。Python实现:

def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
    """Compute the log of the sum of exponentials of input elements.
 
    Parameters
    ----------
    a : array_like
        Input array.
    axis : None or int or tuple of ints, optional
        Axis or axes over which the sum is taken. By default `axis` is None,
        and all elements are summed.
 
        .. versionadded:: 0.11.0
    keepdims : bool, optional
        If this is set to True, the axes which are reduced are left in the
        result as dimensions with size one. With this option, the result
        will broadcast correctly against the original array.
 
        .. versionadded:: 0.15.0
    b : array-like, optional
        Scaling factor for exp(`a`) must be of the same shape as `a` or
        broadcastable to `a`. These values may be negative in order to
        implement subtraction.
 
        .. versionadded:: 0.12.0
    return_sign : bool, optional
        If this is set to True, the result will be a pair containing sign
        information; if False, results that are negative will be returned
        as NaN. Default is False (no sign information).
 
        .. versionadded:: 0.16.0
 
    Returns
    -------
    res : ndarray
        The result, ``np.log(np.sum(np.exp(a)))`` calculated in a numerically
        more stable way. If `b` is given then ``np.log(np.sum(b*np.exp(a)))``
        is returned.
    sgn : ndarray
        If return_sign is True, this will be an array of floating-point
        numbers matching res and +1, 0, or -1 depending on the sign
        of the result. If False, only one result is returned.
 
    See Also
    --------
    numpy.logaddexp, numpy.logaddexp2
 
    Notes
    -----
    Numpy has a logaddexp function which is very similar to `logsumexp`, but
    only handles two arguments. `logaddexp.reduce` is similar to this
    function, but may be less stable.
 
    Examples
    --------
    >>> from scipy.special import logsumexp
    >>> a = np.arange(10)
    >>> np.log(np.sum(np.exp(a)))
    9.4586297444267107
    >>> logsumexp(a)
    9.4586297444267107
 
    With weights
 
    >>> a = np.arange(10)
    >>> b = np.arange(10, 0, -1)
    >>> logsumexp(a, b=b)
    9.9170178533034665
    >>> np.log(np.sum(b*np.exp(a)))
    9.9170178533034647
 
    Returning a sign flag
 
    >>> logsumexp([1,2],b=[1,-1],return_sign=True)
    (1.5413248546129181, -1.0)
 
    Notice that `logsumexp` does not directly support masked arrays. To use it
    on a masked array, convert the mask into zero weights:
 
    >>> a = np.ma.array([np.log(2), 2, np.log(3)],
    ...                  mask=[False, True, False])
    >>> b = (~a.mask).astype(int)
    >>> logsumexp(a.data, b=b), np.log(5)
    1.6094379124341005, 1.6094379124341005
 
    """
    a = _asarray_validated(a, check_finite=False)
    if b is not None:
        a, b = np.broadcast_arrays(a, b)
        if np.any(b == 0):
            a = a + 0.  # promote to at least float
            a[b == 0] = -np.inf
 
    a_max = np.amax(a, axis=axis, keepdims=True)
 
    if a_max.ndim > 0:
        a_max[~np.isfinite(a_max)] = 0
    elif not np.isfinite(a_max):
        a_max = 0
 
    if b is not None:
        b = np.asarray(b)
        tmp = b * np.exp(a - a_max)
    else:
        tmp = np.exp(a - a_max)
 
    # suppress warnings about log of zero
    with np.errstate(divide='ignore'):
        s = np.sum(tmp, axis=axis, keepdims=keepdims)
        if return_sign:
            sgn = np.sign(s)
            s *= sgn  # /= makes more sense but we need zero -> zero
        out = np.log(s)
 
    if not keepdims:
        a_max = np.squeeze(a_max, axis=axis)
    out += a_max
 
    if return_sign:
        return out, sgn
    else:
        return out

由于逻辑比较简单,直接翻译为Scala即可,这里不做详细解释了。


以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持 码农网

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

SQL进阶教程

SQL进阶教程

[ 日] MICK / 吴炎昌 / 人民邮电出版社 / 2017-11 / 79.00元

本书是《SQL基础教程》作者MICK为志在向中级进阶的数据库工程师编写的一本SQL技能提升指南。全书可分为两部分,第一部分介绍了SQL语言不同寻常的使用技巧,带领读者从SQL常见技术,比如CASE表达式、自连接、HAVING子句、外连接、关联子查询、EXISTS……去探索新发现。这部分不仅穿插讲解了这些技巧背后的逻辑和相关知识,而且辅以丰富的示例程序,旨在帮助读者提升编程水平;第二部分着重介绍关系......一起来看看 《SQL进阶教程》 这本书的介绍吧!

HTML 压缩/解压工具
HTML 压缩/解压工具

在线压缩/解压 HTML 代码

Base64 编码/解码
Base64 编码/解码

Base64 编码/解码

HSV CMYK 转换工具
HSV CMYK 转换工具

HSV CMYK互换工具