跳转至

自动微分autograd原理-简单demo

需求

实现一个类似 PyTorch 的简易自动微分 mini 框架,支持加、减、乘、除四则运算,API 设计参考 PyTorch 即可。这实际上是某节一道面试题的精简版,不知道今年还会不会考。

分析

  1. 需要构造计算图。以 z = func(x, y) 为例,执行 z.backward() 时需要找到 x 和 y,因此 z 所属的类需要保存对 x 和 y 对象的引用。
  2. 函数指针。每个运算都需要一个前向传播函数和一个反向传播函数;反向传播函数要作为函数引用保存在该运算的结果变量上,供执行 backward 时调用。
  3. 接口设计参考torch.autograd.Function

代码

import abc
from copy import copy
class Function(metaclass=abc.ABCMeta):
    """Base class for a differentiable operation, modelled on torch.autograd.Function.

    Subclasses implement ``forward(ctx, *inputs)`` and ``backward(ctx, grad)``
    as *instance* methods: ``apply`` creates a fresh instance per call, and that
    instance plays the role of ``ctx`` (it carries the values saved for backward).
    ``backward`` returns one gradient per input passed to ``forward``.
    """

    def __init__(self) -> None:
        # Tuples of values stashed by forward() for use in backward().
        self._buffer = list()

    @abc.abstractmethod
    def forward(self, *args, **kwargs):
        """Compute and return the output variable from the inputs."""

    @abc.abstractmethod
    def backward(self, *args, **kwargs):
        """Given the output's gradient, return one gradient per forward input."""

    def saved_for_backward(self, *args):
        """Remember ``args`` (as one tuple) for the matching backward call."""
        self._buffer.append(args)

    def get_saved_buffer(self):
        """Return the most recently saved tuple of values.

        The buffer is peeked, not popped: with ``retain_graph=True`` the same
        backward function may run more than once, and popping would leave the
        buffer empty (IndexError) on the second call.
        """
        return self._buffer[-1]

    @classmethod
    def apply(cls, *args, **kwargs):
        """Run ``forward`` on a new instance and record the graph edge.

        If the result requires grad, its ``saved_ctx`` is set to
        ``(args, backward)`` so a later backward pass can reach the inputs.
        """
        obj = cls()
        result = obj.forward(*args, **kwargs)
        if result.required_grad:
            result.saved_ctx = (args, obj.backward)
        return result
def as_var(data):
    """Wrap ``data`` as a Varible; existing Varible objects pass through unchanged.

    Plain numbers are converted with ``float`` and become constants that do
    not require grad.
    """
    if not isinstance(data, Varible):
        data = Varible(float(data), required_grad=False)
    return data
class Varible(object):
    """Scalar node of the autograd graph (the misspelled name is kept for API compatibility).

    Attributes:
        data: the wrapped scalar value.
        required_grad: whether gradients should be propagated into this node.
        grad: accumulated gradient; ``None`` until a backward pass reaches it.
        saved_ctx: ``(inputs, backward_func)`` recorded by ``Function.apply``
            for nodes produced by an operation; ``None`` for leaves, and
            cleared after a backward pass unless ``retain_graph=True``.
    """

    def __init__(self, value, required_grad=False) -> None:
        self.required_grad = required_grad
        self.grad = None
        self.data = value
        self.saved_ctx = None

    def backward(self, grad=None, retain_graph=False):
        """Back-propagate ``grad`` from this node to every reachable input.

        Nodes are processed in reverse topological order, so each operation's
        backward function runs exactly once with the *total* gradient flowing
        into it (correct for shared sub-expressions such as ``t + t``). Every
        reachable node with ``required_grad`` accumulates its gradient in
        ``.grad``; the root's own ``.grad`` is left untouched.

        Args:
            grad: upstream gradient of this node; ``None`` means 1.0.
            retain_graph: keep ``saved_ctx`` so backward can be run again.
        """
        # Only None means "seed with 1.0"; a genuine 0.0 gradient must stay 0.0.
        if grad is None:
            grad = 1.0
        if not self.required_grad:
            return
        # Gradient flowing into each node, keyed by identity (id), not value.
        pending = {id(self): grad}
        for node in reversed(self._topological_order()):
            ctx = node.saved_ctx
            if ctx is None:
                continue  # leaf, or graph already freed: nothing further upstream
            inputs, backward_func = ctx
            node_grad = pending.pop(id(node), 0.0)
            for inp, inp_grad in zip(inputs, backward_func(node_grad)):
                if not inp.required_grad:
                    continue
                inp.grad = (0.0 if inp.grad is None else inp.grad) + inp_grad
                pending[id(inp)] = pending.get(id(inp), 0.0) + inp_grad
            if not retain_graph:
                node.saved_ctx = None

    def _topological_order(self):
        """Return nodes reachable from ``self`` that require grad, inputs before outputs."""
        order = []
        seen = set()
        # Iterative post-order DFS (avoids recursion limits on long chains).
        stack = [(self, False)]
        while stack:
            node, finished = stack.pop()
            if finished:
                order.append(node)
                continue
            if id(node) in seen:
                continue
            seen.add(id(node))
            stack.append((node, True))
            if node.saved_ctx is not None:
                for inp in node.saved_ctx[0]:
                    if inp.required_grad and id(inp) not in seen:
                        stack.append((inp, False))
        return order

    def copy_(self, other):
        """Overwrite this node's data, flags, grad and graph context with ``other``'s."""
        self.data = copy(other.data)
        self.required_grad = copy(other.required_grad)
        self.grad = copy(other.grad)
        self.saved_ctx = copy(other.saved_ctx)

    def _inplace(self, op, other):
        """Compute ``op(self, other)``, store the result in ``self`` and return ``self``.

        The result's graph refers to a snapshot of the pre-update node instead
        of ``self``: copying the result's ``saved_ctx`` straight onto ``self``
        would list ``self`` as one of its own inputs (a cycle).
        NOTE(review): graphs built earlier that reference ``self`` will see the
        new history after the update; avoid in-place ops on values still in use.
        """
        before = Varible(None)
        before.copy_(self)
        self.copy_(op(before, other))
        # Augmented assignment rebinds the name to this return value; returning
        # None would turn ``a += 1`` into ``a = None``.
        return self

    def __iadd__(self, other):
        return self._inplace(Varible.__add__, other)

    def __isub__(self, other):
        return self._inplace(Varible.__sub__, other)

    def __imul__(self, other):
        return self._inplace(Varible.__mul__, other)

    def __itruediv__(self, other):
        return self._inplace(Varible.__truediv__, other)

    # Python 2 spelling, kept for callers that invoke it directly.
    __idiv__ = __itruediv__

    def __radd__(self, other):
        return self.__add__(other)

    def __rsub__(self, other):
        # Computes ``other - self``; operand order matters for subtraction.
        return SubFunction.apply(as_var(other), self)

    def __rmul__(self, other):
        return self.__mul__(other)

    def __rtruediv__(self, other):
        # Computes ``other / self``; operand order matters for division.
        return DivFunction.apply(as_var(other), self)

    def __add__(self, other):
        return AddFunction.apply(self, as_var(other))

    def __sub__(self, other):
        return SubFunction.apply(self, as_var(other))

    def __mul__(self, other):
        return MulFunction.apply(self, as_var(other))

    def __truediv__(self, other):
        return DivFunction.apply(self, as_var(other))

    def __str__(self) -> str:
        return f"data={self.data}, required_grad={self.required_grad}, grad={self.grad}"

    def __repr__(self) -> str:
        return self.__str__()
class AddFunction(Function):
    """z = x + y, with dz/dx = dz/dy = 1."""

    def forward(ctx, x, y):
        needs_grad = x.required_grad or y.required_grad
        return Varible(x.data + y.data, required_grad=needs_grad)

    def backward(ctx, grad):
        # Addition hands the upstream gradient unchanged to both operands.
        return grad, grad
class SubFunction(Function):
    """z = x - y, with dz/dx = 1 and dz/dy = -1."""

    def forward(ctx, x, y):
        needs_grad = x.required_grad or y.required_grad
        return Varible(x.data - y.data, required_grad=needs_grad)

    def backward(ctx, grad):
        # The subtrahend receives the negated upstream gradient.
        return grad, -grad
class MulFunction(Function):
    """z = x * y, with dz/dx = y and dz/dy = x."""

    def forward(ctx, x, y):
        needs_grad = x.required_grad or y.required_grad
        product = x.data * y.data
        # Each partial derivative needs the *other* operand's value.
        ctx.saved_for_backward(x.data, y.data)
        return Varible(product, required_grad=needs_grad)

    def backward(ctx, grad):
        x_val, y_val = ctx.get_saved_buffer()
        return grad * y_val, grad * x_val
class DivFunction(Function):
    """z = x / y, with dz/dx = 1/y and dz/dy = -x / y**2."""

    def forward(ctx, x, y):
        needs_grad = x.required_grad or y.required_grad
        quotient = x.data / y.data
        # Both raw operand values are needed by the partial derivatives.
        ctx.saved_for_backward(x.data, y.data)
        return Varible(quotient, required_grad=needs_grad)

    def backward(ctx, grad):
        x_val, y_val = ctx.get_saved_buffer()
        d_x = grad / y_val
        d_y = - grad * x_val * (y_val ** -2)
        return d_x, d_y

if __name__ == "__main__":
    # Demo: y = 2a/b + 3a^2 at a=1, b=2, so dy/da = 2/b + 6a = 7
    # and dy/db = -2a/b^2 = -0.5.
    a, b = (Varible(v, required_grad=True) for v in (1.0, 2.0))
    y = 2 * a / b + 3 * a * a
    y.backward()
    for shown in (y, b, a):
        print(shown)

总结

以上代码即实现了标量的自动微分,主要由两个类组成:`Function` 和 `Varible`(即 Variable,沿用代码中的拼写)。核心是 `backward` 方法:它沿计算图把梯度按链式法则逐层传回各输入变量,从而实现反向传播。


最后更新: March 21, 2024
创建日期: March 21, 2024