PyTorch自动求导(Autograd)原理解析

栏目: Python · 发布时间: 4年前

内容简介:我们知道,深度学习最核心的其中一个步骤,就是求导:根据函数(linear + activation function)求weights相对于loss的导数(还是loss相对于weights的导数?)。然后根据得出的导数,相应的修改weights,让loss最小化。 各大深度学习框架Tensorflow,Keras,PyTorch都自带有自动求导功能,不需要我们手动算。 在初步学习PyTorch的时候,看到PyTorch的自动求导过程时,感觉非常的别扭和不直观。我下面举个例子,大家自己感受一下。这里让人感觉别

我们知道,深度学习最核心的其中一个步骤,就是求导:根据函数(linear + activation function)求weights相对于loss的导数(还是loss相对于weights的导数?)。然后根据得出的导数,相应的修改weights,让loss最小化。 各大深度学习框架Tensorflow,Keras,PyTorch都自带有自动求导功能,不需要我们手动算。 在初步学习PyTorch的时候,看到PyTorch的自动求导过程时,感觉非常的别扭和不直观。我下面举个例子,大家自己感受一下。

>>> import torch
>>>
>>> a = torch.tensor(2.0, requires_grad=True)
>>> b = torch.tensor(3.0, requires_grad=True)
>>> c = a + b
>>> d = torch.tensor(4.0, requires_grad=True)
>>> e = c * d
>>>
>>> e.backward() # 执行求导
>>> a.grad  # a.grad 即导数 d(e)/d(a) 的值
tensor(4.)

这里让人感觉别扭的是,调用 e.backward() 执行求导,为什么会更新 a 对象的状态 grad ?对于习惯了OOP的人来说,这是非常不直观的。因为,在OOP里面,你要改变一个对象的状态,一般的做法是,引用这个对象本身,给它的property显示的赋值(比如 user.age = 18 ),或者是调用这个对象的方法( user.setAge(18) ),让它状态得以改变。 而这里的做法是,调用了一个跟它( a )本身看起来没什么关系的对象( e )的方法,结果改变了它的状态。 每次写代码写到这个地方的时候,我都觉得心里一惊。因此,就一直想一探究竟,看看这内部的关联究竟是怎么样的。 根据上面的代码,我们知道的是, e 的结果,是由 cd 运算得到的,而 c ,又是根据 ab 相加得到的。现在,执行 e 的方法,最终改变了 a 的状态。因此,我们可以猜测 e 内部可能有某个东西,引用着 c ,然后呢, c 内部又有些东西,引用着 a 。因此,在运行 ebackward() 方法时,通过这些引用,先是改变 c ,在根据 c 内部的引用,最终改变了 a 。如果我们的猜测没错的话,那么这些引用关系到底是什么呢?在代码里是怎么提现的呢? 想要知道其中原理,最先想到的办法,自然是去看源代码。 遗憾的是, backward() 的实现主要是在C/Cpp层间做的,在 Python 层面做的事情很少,基本上就是对参数做了一下处理,然后调用native层面的实现。如下:

def backward(tensors, grad_tensors=None, retain_graph=None, create_graph=False, grad_variables=None):
    r"""Computes the sum of gradients of given tensors w.r.t. graph leaves.
    ...more comment
    """
    if grad_variables is not None:
        warnings.warn("'grad_variables' is deprecated. Use 'grad_tensors' instead.")
        if grad_tensors is None:
            grad_tensors = grad_variables
        else:
            raise RuntimeError("'grad_tensors' and 'grad_variables' (deprecated) "
                               "arguments both passed to backward(). Please only "
                               "use 'grad_tensors'.")

    tensors = (tensors,) if isinstance(tensors, torch.Tensor) else tuple(tensors)

    if grad_tensors is None:
        grad_tensors = [None] * len(tensors)
    elif isinstance(grad_tensors, torch.Tensor):
        grad_tensors = [grad_tensors]
    else:
        grad_tensors = list(grad_tensors)

    grad_tensors = _make_grads(tensors, grad_tensors)
    if retain_graph is None:
        retain_graph = create_graph

    Variable._execution_engine.run_backward(
        tensors, grad_tensors, retain_graph, create_graph,
        allow_unreachable=True)  # allow_unreachable flag

说到Cpp。。。 PyTorch自动求导(Autograd)原理解析 由于C/Cpp是我的知识盲区,只能通过一顿自行的探索操作,来了解这个执行过程了。

我们先看看 e 里面有什么。 由于 e 是一个 Tensor 变量,我们自然想到去看 Tensor 这个类的代码,看看里面有哪些成员变量。不幸的是,Python语言声明成员变量的方式跟 Java 这些静态语言不一样,他们是用到的时候直接用 self.xxx 随时声明的。不像Java这样,在某一个地方统一声明并做初始化。 当然,我们可以用正则表达式 self\.\w+\s+= 搜索所有类似于 self.xxx =  的地方,于是你会找到一些 data , requires_grad , _backward_hooks , retain_grad 等等。根据已有的知识,这些看起来都不像。看来相关的成员变量应该在其父类 TensorBase 里面。不幸的是, TensorBase 是用C/Cpp 实现的。这。。。这就又涉及到我的知识盲区了。。。

不过,Python其实还提供了其他的一些方式,来方便我们查看这个对象的属性和状态。那就是 vars() 方法和 dir() 方法。然而。。。

>>> vars(a)
{}
>>>
>>>
>>>
>>> dir(a)
['__abs__', '__add__', '__and__', '__array__', '__array_priority__', '__array_wrap__', '__bool__', '__class__', '__deepcopy__', '__delattr__', '__delitem__', '__dict__', '__dir__', '__div__', '__doc__', '__eq__', '__float__', '__floordiv__', '__format__', '__ge__', '__getattribute__', '__getitem__', '__gt__', '__hash__', '__iadd__', '__iand__', '__idiv__', '__ilshift__', '__imul__', '__index__', '__init__', '__init_subclass__', '__int__', '__invert__', '__ior__', '__ipow__', '__irshift__', '__isub__', '__iter__', '__itruediv__', '__ixor__', '__le__', '__len__', '__long__', '__lshift__', '__lt__', '__matmul__', '__mod__', '__module__', '__mul__', '__ne__', '__neg__', '__new__', '__nonzero__', '__or__', '__pow__', '__radd__', '__rdiv__', '__reduce__', '__reduce_ex__', '__repr__', '__reversed__', '__rfloordiv__', '__rmul__', '__rpow__', '__rshift__', '__rsub__', '__rtruediv__', '__setattr__', '__setitem__', '__setstate__', '__sizeof__', '__str__', '__sub__', '__subclasshook__', '__truediv__', '__weakref__', '__xor__', '_backward_hooks', '_base', '_cdata', '_coalesced_', '_dimI', '_dimV', '_grad', '_grad_fn', '_indices', '_make_subclass', '_nnz', '_values', '_version', 'abs', 'abs_', 'acos', 'acos_', 'add', 'add_', 'addbmm', 'addbmm_', 'addcdiv', 'addcdiv_', 'addcmul', 'addcmul_', 'addmm', 'addmm_', 'addmv', 'addmv_', 'addr', 'addr_', 'all', 'allclose', 'any', 'apply_', 'argmax', 'argmin', 'argsort', 'as_strided', 'as_strided_', 'asin', 'asin_', 'atan', 'atan2', 'atan2_', 'atan_', 'backward', 'baddbmm', 'baddbmm_', 'bernoulli', 'bernoulli_', 'bincount', 'bmm', 'btrifact', 'btrifact_with_info', 'btrisolve', 'byte', 'cauchy_', 'ceil', 'ceil_', 'char', 'cholesky', 'chunk', 'clamp', 'clamp_', 'clamp_max', 'clamp_max_', 'clamp_min', 'clamp_min_', 'clone', 'coalesce', 'contiguous', 'copy_', 'cos', 'cos_', 'cosh', 'cosh_', 'cpu', 'cross', 'cuda', 'cumprod', 'cumsum', 'data', 'data_ptr', 'dense_dim', 'det', 'detach', 'detach_', 'device', 'diag', 'diag_embed', 'diagflat', 'diagonal', 'digamma', 'digamma_', 'dim', 'dist', 'div', 'div_', 'dot', 'double', 'dtype', 'eig', 'element_size', 'eq', 'eq_', 'equal', 'erf', 'erf_', 'erfc', 'erfc_', 'erfinv', 'erfinv_', 'exp', 'exp_', 'expand', 'expand_as', 'expm1', 'expm1_', 'exponential_', 'fft', 'fill_', 'flatten', 'flip', 'float', 'floor', 'floor_', 'fmod', 'fmod_', 'frac', 'frac_', 'gather', 'ge', 'ge_', 'gels', 'geometric_', 'geqrf', 'ger', 'gesv', 'get_device', 'grad', 'grad_fn', 'gt', 'gt_', 'half', 'hardshrink', 'histc', 'ifft', 'index_add', 'index_add_', 'index_copy', 'index_copy_', 'index_fill', 'index_fill_', 'index_put', 'index_put_', 'index_select', 'indices', 'int', 'inverse', 'irfft', 'is_coalesced', 'is_complex', 'is_contiguous', 'is_cuda', 'is_distributed', 'is_floating_point', 'is_leaf', 'is_nonzero', 'is_pinned', 'is_same_size', 'is_set_to', 'is_shared', 'is_signed', 'is_sparse', 'isclose', 'item', 'kthvalue', 'layout', 'le', 'le_', 'lerp', 'lerp_', 'lgamma', 'lgamma_', 'log', 'log10', 'log10_', 'log1p', 'log1p_', 'log2', 'log2_', 'log_', 'log_normal_', 'log_softmax', 'logdet', 'logsumexp', 'long', 'lt', 'lt_', 'map2_', 'map_', 'masked_fill', 'masked_fill_', 'masked_scatter', 'masked_scatter_', 'masked_select', 'matmul', 'matrix_power', 'max', 'mean', 'median', 'min', 'mm', 'mode', 'mul', 'mul_', 'multinomial', 'mv', 'mvlgamma', 'mvlgamma_', 'name', 'narrow', 'narrow_copy', 'ndimension', 'ne', 'ne_', 'neg', 'neg_', 'nelement', 'new', 'new_empty', 'new_full', 'new_ones', 'new_tensor', 'new_zeros', 'nonzero', 'norm', 'normal_', 'numel', 'numpy', 'orgqr', 'ormqr', 'output_nr', 'permute', 'pin_memory', 'pinverse', 'polygamma', 'polygamma_', 'potrf', 'potri', 'potrs', 'pow', 'pow_', 'prelu', 'prod', 'pstrf', 'put_', 'qr', 'random_', 'reciprocal', 'reciprocal_', 'record_stream', 'register_hook', 'reinforce', 'relu', 'relu_', 'remainder', 'remainder_', 'renorm', 'renorm_', 'repeat', 'requires_grad', 'requires_grad_', 'reshape', 'reshape_as', 'resize', 'resize_', 'resize_as', 'resize_as_', 'retain_grad', 'rfft', 'roll', 'rot90', 'round', 'round_', 'rsqrt', 'rsqrt_', 'scatter', 'scatter_', 'scatter_add', 'scatter_add_', 'select', 'set_', 'shape', 'share_memory_', 'short', 'sigmoid', 'sigmoid_', 'sign', 'sign_', 'sin', 'sin_', 'sinh', 'sinh_', 'size', 'slogdet', 'smm', 'softmax', 'sort', 'sparse_dim', 'sparse_mask', 'sparse_resize_', 'sparse_resize_and_clear_', 'split', 'split_with_sizes', 'sqrt', 'sqrt_', 'squeeze', 'squeeze_', 'sspaddmm', 'std', 'stft', 'storage', 'storage_offset', 'storage_type', 'stride', 'sub', 'sub_', 'sum', 'svd', 'symeig', 't', 't_', 'take', 'tan', 'tan_', 'tanh', 'tanh_', 'to', 'to_dense', 'to_sparse', 'tolist', 'topk', 'trace', 'transpose', 'transpose_', 'tril', 'tril_', 'triu', 'triu_', 'trtrs', 'trunc', 'trunc_', 'type', 'type_as', 'unbind', 'unfold', 'uniform_', 'unique', 'unsqueeze', 'unsqueeze_', 'values', 'var', 'view', 'view_as', 'where', 'zero_']
>>>

可以看到,使用 vars() 方法,返回的集合是空的。而使用 dir() ,返回的却又太多了,你都不知道哪些是有用的哪些是没用的,哪些又是我们真正关心的。 怎么办呢? 看来只能Google了。经过一顿调查和连猜带蒙,我得出了一些结论。也不知道是否正确(准确),如果有错误或不准确的地方,还希望有大神不吝指出。

为了解释他们之间的关系,我们先从一个最简单的例子开始。

>>> a = torch.tensor(2.0, requires_grad=True)
>>> b = torch.tensor(3.0, requires_grad=True)
>>> c = a + b
>>>
>>> c.backward()
>>> a.grad
tensor(1.)
>>> b.grad
tensor(1.)
>>>

我们的问题是, ca 是怎么串联起来的?为什么执行 c.backward() ,会更新 a 的状态( a.grad 的值)? 其实,我们要找的东西,远在天边,近在眼前。

>>> c
tensor(5., grad_fn=<AddBackward0>)
>>>

可以看到,c里面有一个 gran_fn 变量。这个东西是什么呢?

>>> c.grad_fn
<AddBackward0 object at 0x10e56d160>
>>> type(c.grad_fn)
<class 'AddBackward0'>
>>>

可见,这是一个 AddBackward0 这个类的对象。遗憾的是,这个类也是用Cpp来写的。不过,这不代表我们不能在Python层做一些简单的探索,看看里面有些什么东西。

>>> dir(c.grad_fn)
['__call__', '__class__', '__delattr__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '_register_hook_dict', 'metadata', 'name', 'next_functions', 'register_hook', 'requires_grad']

除去那些特殊方法(以 __ 开头和结束的)和私有方法(以 _ 开头的),范围缩小到 ['metadata', 'name', 'next_functions', 'register_hook', 'requires_grad’] 这其中,看名字,最可疑的是这个 next_functions 。我们看看是什么:

>>> c.grad_fn.next_functions
((<AccumulateGrad object at 0x10e56d160>, 0), (<AccumulateGrad object at 0x1205b29b0>, 0))
>>>

看起来,这个 next_functions 是一个tuple of tuple of AccumulateGrad and int 。 继续探索这个 AccumulateGrad

>>> ag = c.grad_fn.next_functions[0][0]
>>> dir(ag)
['__call__', '__class__', '__delattr__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '_register_hook_dict', 'metadata', 'name', 'next_functions', 'register_hook', 'requires_grad', 'variable']

同样的,去掉那些特殊函数。我们感兴趣的范围缩小到 ['metadata', 'name', 'next_functions', 'register_hook', 'requires_grad', 'variable'] 这其中,除了前面提高过的 'next_functions' 之外,我们惊讶的发现,还有一个 叫 variable 的属性。我们分别都看一下:

>>> ag.next_functions
()
>>> ag.variable
tensor(2., requires_grad=True)
>>>

可见, ag1variable 这个属性是一个

tensor(2., requires_grad=True)

这个看起来似乎跟我们前面定义的a是同一个啊。是吗?我们确认一下:

>>> id(a)
4842774104
>>> id(ag.variable)
4842774104
>>>

果然是! 到这里,谜底基本上就呼之欲出了。

当我们执行 c.backward() 的时候。这个操作将调用c里面的 grad_fn 这个属性,执行求导的操作。这个操作将遍历 grad_fnnext_functions ,然后分别取出里面的function( AccumulateGrad ),执行求导操作。计算出结果以后,将结果保存到他们对应的 variable 这个变量所引用的对象( ab )的 grad 这个属性里面。

于是,当我们执行完 c.backward() 之后, ab 里面的 grad 值就得到了更新。

再回到我们开篇提到的稍微复杂点的例子:

>>> import torch
>>>
>>> a = torch.tensor(2.0, requires_grad=True)
>>> b = torch.tensor(3.0, requires_grad=True)
>>> c = a + b
>>> d = torch.tensor(4.0, requires_grad=True)
>>> e = c * d
>>>
>>> e.backward()
>>> a.grad
tensor(4.)
>>> b.grad
tensor(4.)
>>> c.grad
>>> d.grad
tensor(5.)

以此类推, e 到各个节点 abcd 的关联也就很容易理解了。

>>> e
tensor(20., grad_fn=<MulBackward0>)
>>> e.grad_fn
<MulBackward0 object at 0x111cb5470>
>>> e.grad_fn.next_functions
((<AddBackward0 object at 0x110501438>, 0), (<AccumulateGrad object at 0x111cb5fd0>, 0))

分别把 next_functions 中的function取出来看看

>>> ((f1, _), (f2, _)) = e.grad_fn.next_functions

>>> f1
<AddBackward0 object at 0x111cb5fd0>
>>> f1.variable
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
AttributeError: 'AddBackward0' object has no attribute 'variable'

>>> c
tensor(5., grad_fn=<AddBackward0>)
>>> c.grad_fn
<AddBackward0 object at 0x111cb5fd0>

>>> f2
<AccumulateGrad object at 0x1103ee4e0>
>>> f2.variable
tensor(4., requires_grad=True)

可见, e.grad_fn.next_functions 中的第一个function f1 ,就是 c.grad_fn

不过,如果跟着刚刚的思路,你会觉得意外的是, f1 是没有 variable 变量的。这是因为, c 的结果,是由 ab 相加的出来的,这样的变量是非“叶变量”。 如果我们把 abcde 和他们之间的运算过程理解为一棵树。那么, abd 都是我们自己“new”出来的,这样的节点叫叶节点。这些叶节点分别有一个 AccumulateGrad 类型的function跟它们对应起来。则像c、e这些,不是我们自己直接创建的,而是通过一些运算得出的,就是非叶节点。对于非叶节点来说,默认情况下他们不需要存储导数值(当然,如果需要,也是有办法做到的)。因此,他们的 grad_fn ,不需要有一个变量 variable 引用着他们。

e.backward() 执行求导时,系统遍历 e.grad_fn.next_functions ,分别执行求导。如果 e.grad_fn.next_functions 中有哪个是 AccumulateGrad ,则把结果保存到 AccumulateGrad 的variable引用的变量中。否则,递归遍历这个function的 next_functions ,执行求导过程。最终到达所有的叶节点,求导结束。同时,所有的叶节点的 grad 变量都得到了相应的更新。 他们之间的关系如下图所示: PyTorch自动求导(Autograd)原理解析

那么,还有两个问题没有解决: 1. 这些各种function,像 AccumulateGradAddBackward0MulBackward0 ,是怎么产生的? 2. 这些function,比如上面出现过的 AddBackward0MulBackward0 ,具体是怎么求导的呢?

对于第一个问题,很自然的猜测,是PyTorch重写了一些操作符,像 +* 等。在这个过程中,创建了这些function,并建立起了引用关系。 对于第二个问题,简单的说,就是在每个函数定义的时候,都需要自己定义好 forward()backward() 函数。在 forward() 里面实现这个运算的执行过程。比如,相加、相乘,在 backward() 则实现这个运算的求导过程。

以上就是我对PyTorch的自动求导原理的理解。只是一个大概的,比较浅显的理解。对于一些更加细节的,包括一些特殊情况的处理,推荐大家看这个视频。讲得非常清楚。

https://www.youtube.com/watch?v=MswxJw-8PvE

参考: https://pytorch.org/docs/stable/autograd.html#in-place-operations-on-tensors https://pytorch.org/docs/stable/notes/extending.html https://www.youtube.com/watch?v=MswxJw-8PvE


以上所述就是小编给大家介绍的《PyTorch自动求导(Autograd)原理解析》,希望对大家有所帮助,如果大家有任何疑问请给我留言,小编会及时回复大家的。在此也非常感谢大家对 码农网 的支持!

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

安全测试指南(第4版)

安全测试指南(第4版)

OWASP基金会 / 电子工业出版社 / 2016-7-1 / CNY 89.00

软件安全问题也许是这个时代面临的*为重要的技术挑战。Web应用程序让业务、社交等网络活动飞速发展,这同时也加剧了它们对软件安全的要求。我们急需建立一个强大的方法来编写和保护我们的互联网、Web应用程序和数据,并基于工程和科学的原则,用一致的、可重复的和定义的方法来测试软件安全问题。本书正是实现这个目标的重要一步,作为一本安全测试指南,详细讲解了Web应用测试的“4W1H”,即“什么是测试”、“为什......一起来看看 《安全测试指南(第4版)》 这本书的介绍吧!

JSON 在线解析
JSON 在线解析

在线 JSON 格式化工具

html转js在线工具
html转js在线工具

html转js在线工具

HEX CMYK 转换工具
HEX CMYK 转换工具

HEX CMYK 互转工具