自动微分

前言

自动微分是一种对程序中的函数计算导数的技术。相比符号微分,其避免了表达式膨胀的性能问题;而相比数值微分,其解决了近似误差的正确性问题。

关于自动微分技术本身的相关背景知识,我们将不在本手册中进行详细介绍。推荐用户阅读以下参考文献进行详细了解。

在仓颉编程语言中,自动微分将作为原生语言特性被提供给用户,用户可以通过配置编译选项 --enable-ad 来开启自动微分特性。

cjc --enable-ad helloworld.cj

需要说明的是,目前仓颉自动微分特性仅支持反向模式自动微分。

可微类型

在仓颉自动微分的库中,我们提供了Differentiable interface 用于定义仓颉中数据类型的微分规则。所有参与自动微分中导数计算的数据类型均应实现该 interface 以确保被自动微分系统认为是合法的可微类型

// Defined in the AD library
public interface Differentiable<T> where T <: Differentiable<T> {
    // Initialize a zero tangent
    func TangentZero(): T
    // Sum up the tangent value `y`
    func TangentAdd(y: T): T
}

对于可微类型(例如浮点数类型、元组类型、struct类型),编译器会自动生成合理的 interface 实例,对于可微struct类型,用户也可以通过手动实现该 interface 实例。

可微数值类型

仓颉中的可微数值类型包括三种:

  • Float16
  • Float32
  • Float64

可微元组类型

当元组中所有元素均为可微类型时,则该元组是可微元组类型。

let a: (Float64, Float64) = (1.0, 1.0)   // differentiable
let b: (Float64, String) = (1.0, "foo")  // NOT differentiable

可微 struct 类型

默认情况下,struct 类型不可微,当 struct 类型不包含静态成员变量时,用户可在 struct 类型定义上方增加 @Differentiable 标注,将其定义为可微并通过 except 列表配置该 struct 类型的微分行为。

给定 except 列表后,该 struct 类型的成员变量将分为两类。

  • 不在 except 列表中的成员变量,必须为不可变变量,并且其类型为可微类型,成员变量将参与该 struct 类型对象的微分过程
  • except 列表中的成员变量将不参与该 struct 类型对象的微分过程。在微分过程中,该成员变量将被保持为未初始化状态。因此用户应确保不访问微分结果中的这些成员变量值,否则将导致未定义行为
// Differentiable
@Differentiable
struct Point {
    let x: Float64
    let y: Float64
    init(x: Float64, y: Float64) {
        this.x = x
        this.y = y
    }
}

// Differentiable, but no back-propagation will happen in the excepted field `tag`
@Differentiable [except: [tag]]
struct TaggedPoint {
    let x: Float64
    let y: Float64
    let tag: String
    init(x: Float64, y: Float64, tag: String) {
        this.x = x
        this.y = y
        this.tag = tag
    }
}

// Compilation Error: variable `tag` has non-differentiable type String,
// but it does not appear in an except list
@Differentiable
struct TaggedWrong {
    let x: Float64
    let y: Float64
    let tag: String
    init(x: Float64, y: Float64, tag: String) {
        this.x = x
        this.y = y
        this.tag = tag
    }
}

我们也提供了定义 except 列表的另外一个语法糖版本,用户可以通过以下语法定义 struct 类型的 include 列表。此时,struct 所有不在该 include 列表中的成员变量都被定义为在 except 列表中。

// Differentiable, but no back-propagation will happen in the excepted field `tag`
@Differentiable [include: [x, y]]
struct TaggedPoint {
    let x: Float64
    let y: Float64
    let tag: String
    init(x: Float64, y: Float64, tag: String) {
        this.x = x
        this.y = y
        this.tag = tag
    }
}

可微 unit 类型

Unit 类型是一种特殊的可微类型。对任何 Unit 类型对象的微分操作所得结果均仍是 Unit 类型。

不可微类型

仓颉自动微分暂不支持对 StringArrayenumClassInt16Int32Int64Interface 类型数据的微分,故这些类型均为不可微类型。

可微函数

默认情况下,函数均为不可微,但用户可以在函数定义上方增加 @Differentiable 标注,将其定义为可微并通过 except 列表配置该函数的微分行为。

给定 except 列表后,该函数的参数将分为两类。

  • 不在 except 列表中的参数,必须为可微类型。该参数将参与函数的微分过程,微分结果中将包含函数相对该参数的导数
  • except 列表中的参数将不参与函数的微分过程。在微分过程中,该类参数及依赖该参数的中间变量均将被忽略,微分结果中将不包含函数相对该参数的导数
// Differentiable function, its derivatives will be calculated with respect to `x` and `y`
@Differentiable
func f(x: Float64, y: Float64) {
    return x * y
}

// Differentiable function, its derivative will only be calculated with respect to `x`
@Differentiable [except: [y]]
func f(x: Float64, y: Float64) {
    return x * y
}

// Compilation Error: input `z` has non-differentiable type String,
// but it does not appear in an except list
@Differentiable
func f(x: Float64, y: Float64, z: String) {
    return x + y
}

相似地,我们也提供了定义 except 列表的另外一个语法糖版本,用户可以通过以下语法定义函数的 include 列表。此时,函数所有不在该 include 列表中的参数都被定义为在 except 列表中。exceptinclude 只能出现一个。

// Differentiable function, its derivative will only be calculated with respect to `x`
@Differentiable [include: [x]]
func f(x: Float64, y: Float64) {
    return x * y
}

在使用上述语法进行可微函数定义时,用户还需要确保函数满足以下条件。

  • 函数的返回类型为可微类型
  • 函数参数为可微类型,且不在 except 列表中(或在 include 列表中)
  • 函数中未使用全局变量
  • 函数中未使用静态变量
  • 函数体中所有表达式均为可微函数合法表达式,即表达式需满足以下任一条件
    • 表达式可微,即其微分规则已知。目前仓颉支持的可微表达式如下:
      • 赋值表达式
      • 算术表达式
      • 流表达式
      • 条件表达式(if
      • 循环表达式(仅 while, do-while),并且不支持使用 continuebreakreturn
      • lambda 表达式
      • 可微 Tuple 类型对象的初始化表达式
      • 可微 Tuple 类型对象的解构和下标访问表达式
      • 对可微函数的函数调用表达式,且在编译期可以确定哪一个可微函数定义被调用
    • 表达式不直接或间接地与不在 except 列表中的函数参数(包括嵌套函数的外层作用域函数参数)有数据依赖关系。若必须产生数据依赖关系,则可以使用 [stopGradient 函数接口] 来中止梯度传播
    • 可微函数(嵌套函数或匿名函数)中 return 表达式(包括自动插入的 return 表达式)只能出现一次
    • 嵌套函数或匿名函数未捕获可变变量
// Not differentiable
func f1(x: Float64, y: Float64): Float64 {
    return x + y
}

// Compilation Error: expression x is marked for back-propagation,
// but is used in a context that cannot be back-propagated
@Differentiable
func f2(x: Float64, y: Float64): Float64 {
    return f1(x, y)
}
// Compilation Error: match expression is not supported in differentiable function
@Differentiable
func f3(x: Float64) {
    var a = match (x) {
        case _ => 1.0
    }
    return a
}
@Differentiable
func f4 (x: Float64) {
   var y = 1.0
   // Compilation Error: capture mutable variable y in nested function is not supported by AD yet
   func nested (x: Float64) {
       y = y + x
   }
   nested(x)
}
@Differentiable [except: [n]]
func g(m: Float64, n: Float64) {
    return m + n
}

@Differentiable
func f5(x: Float64, y: Float64) {
    // Use `stopGradient` to stop the gradient propagation so we can
    // pass the `y` in `f5` to `n` in `g` even `n` is marked as `except`
    return g(x, stopGradient<Float64>(y))
}
func g(m: Float64) {
    return m == 1.0
}

@Differentiable
func f6(x: Float64) {
    // Use `stopGradient` to stop the gradient propagation so we can
    // pass the `x` in `f6` to non-differentiable function `g`
    let cond = g(stopGradient<Float64>(x))
    if (cond) {
        x * 2.0
    } else {
        x * 3.0
    }
}

注意:为了方便计算 lambda 表达式捕获变量的导数,我们在仓颉自动微分的库中预定义了以下两个接口,这两个接口将被仓颉的自动微分实现模块调用,不建议用户直接使用。

/* lambda 表达式捕获变量的导数组成的环境变量类型的基类 */
public abstract class EnvironmentTangent {
    /* 合并两部分环境变量导数 */
    public func TangentAdd(other: EnvironmentTangent): EnvironmentTangent
}
/* 环境变量类型对应的 0 梯度类型 */
public class EnvironmentZero <: EnvironmentTangent {
    public func TangentAdd(other: EnvironmentTangent): EnvironmentTangent {
        return other
    }
}

自定义函数微分规则

用户还可以通过手动提供伴随函数来为可微函数自定义微分规则。给定一个可微的全局函数 f (称为:源函数),用户可通过在另一个函数 g 的函数定义上方添加 @Adjoint 标注,将其指定为 f 的自定义伴随函数。

伴随函数需满足以下条件:

  • 伴随函数的输入参数的数量、类型和顺序与原可微函数完全一致
  • 伴随函数的输出是一个包含两个子元素的元组
    • 元组第一个子元素为原可微函数在当前伴随函数的参数输入下的输出结果
    • 元组第二个子元素为一个函数,作为原可微函数的梯度反向传播器
      • 梯度反向传播器的输入类型与原可微函数的输出数量、类型和顺序一致
      • 梯度反向传播器的输出类型与原可微函数的输入数量、类型和顺序一致
// Function `f` is differentiable and its custom adjoint function is `g`
@Differentiable
func f(x: Float64): Float64 {
    return x * x * x
}

@Adjoint [primal: f]
func g(x: Float64): ((Float64), ((Float64) -> Float64)) {
    let xSquare = x * x
    return (
        xSquare * x,
        { dy: Float64 =>
            return dy * xSquare * 3.0
        }
    )
}

当用户通过上述语法为源函数自定义伴随函数后,自动微分系统将直接使用该自定义伴随函数进行微分求导。需要注意的是我们当前暂不支持用户对参数或返回值包含函数类型的可微函数自定义伴随函数。

非全局可微函数

需要注意的是,上述可微函数规则默认应用于全局函数。事实上,仓颉编程语言中存在多种非全局函数。在自动微分系统中,我们支持以下非全局函数被定义为可微函数,其他类型的非全局函数均不支持定义为可微函数。

struct 构造函数和成员函数

用户可以使用相同的可微函数标注 @Differentiablestruct 中的构造函数和成员函数定义为可微。

  • struct 构造函数标注为可微。定义构造函数为可微的前提是 struct 类型本身可微。在这种情况下,调用该构造函数的 struct 初始化表达式可微。相反地,若构造函数未被标注为可微,则调用该构造函数的 struct 初始化表达式不可微
  • struct 成员函数标注为可微。根据仓颉语言定义,struct 成员函数隐藏包含了一个标识符为 this 的函数参数用于表示 struct 对象本身,故用户也可以在 except 列表中配置 this 参数来决定是否需要对 this 进行微分。
@Differentiable
struct Point {
    let a: Float64
    let b: Float64
    // Differentiable constructor
    @Differentiable
    init(x: Float64) {
        a = x
        b = x
    }
}

@Differentiable
func foo(x: Float64) {
    // The initialization will call the differentiable constructor.
    // Therefore it is differentiable here.
    return Point(x)
}
@Differentiable
struct Point {
    let a: Float64
    let b: Float64
    init(x: Float64) {
        a = x
        b = x
    }

    // Differentiable method
    @Differentiable
    public func sum(bias: Float64): Float64 {
        return a + b + bias
    }
}

Class 成员函数

用户也可以使用相同的可微函数标注 @Differentiableclass 中的非open成员函数定义为可微。同样根据仓颉语言定义,class 成员函数隐藏包含了一个标识符为 this 的函数参数用于表示 class 对象本身。由于在仓颉中由于 class 类型本身不可微,故除静态成员函数外,用户必须将 this 加入到可微成员函数的 except 列表中。另外需要注意的是,返回类型为This的成员函数因返回类型不可微,所以无法被定义成可微。

class Point {
    let a: Float64
    let b: Float64
    init(x: Float64) {
        a = x
        b = x
    }

    // Function foo can not be defined as differentiable
    func foo(): This {
        return this
    }

    // Differentiable method. But user must add `this` in the `except` list here.
    @Differentiable [except: [this]]
    func sum(bias: Float64): Float64 {
        return a + b + bias
    }
}

拓展的成员函数

用户也可以使用相同的可微函数标注语法将扩展的成员函数定义为可微,支持直接扩展和接口扩展。

interface I {
    func goo(a: Float64, b: Float64): Float64
}

@Differentiable
struct Foo {}

// direct extensions
extend Foo {
    @Differentiable
    func foo(a: Float64, b: Float64) {
        return a + b
    }
}

// interface extensions
extend Foo <: I {
    @Differentiable
    public func goo(a: Float64, b: Float64) {
        return a + b
    }
}

如果需要扩展的类型不可微,用户必须将this加入到扩展的可微成员函数的except列表中。

interface I {
    func goo(a: Float64, b: Float64): Float64
}

struct Foo {}

// direct extensions
extend Foo {
    @Differentiable[except: [this]]
    func foo(a: Float64, b: Float64) {
        return a + b
    }
}

// interface extensions
extend Foo <: I {
    @Differentiable[except: [this]]
    public func goo(a: Float64, b: Float64) {
        return a + b
    }
}

微分表达式

@Grad表达式

给定一个可微函数和相应的输入参数值,用户可使用 @Grad 表达式来获取该函数在该输入值处的梯度值。

@Differentiable [except: [negate]]
func product(x: Float64, y: Float64, negate: Bool): Float64 {
    return if (negate) { - x * y } else { x * y }
}

main(): Int64 {
    // Since `negate` is excepted, the gradient of product only has two components
    let productGrad = @Grad(product, 2.0, 3.0, true)
    print(productGrad[0].toString())        // Prints -3.000000
    print(productGrad[1].toString())        // Prints -2.000000
    return 0
}

使用@Grad表达式时有以下注意事项

  • @Grad表达式只能作为初始值用于 var 或 let 变量初始化表达式中
  • 给定的diffFunc标识符必须表示一个函数,且满足以下条件
    • 必须是全局函数
    • 必须是可微函数
    • 函数不能有命名参数和参数默认值
    • 函数的返回类型只能是一下类型之一:Float16Float32Float64
  • inputVal组成的集合必须与diffFunc表示的函数的参数匹配,匹配规则与形为diffFunc(inputVal, ...)的函数调用匹配规则相同
  • 给定上述类型为(X1, X2, ..., Xm)->Y的函数diffFunc
    • 若其except列表为空,则@Grad表达式的类型为(X1, X2, ..., Xm)。若该Tuple类型元素数量为 0,则退化为Unit类型,若该Tuple类型仅包含一个元素Xj,则退化为Xj类型
    • 若其except列表不为空,则@Grad表达式的类型中需相应地剔除在except列表中的函数参数类型。假定except列表中包含的参数为Xj,则@Grad表达式的类型为(X1, X2, ..., Xj-1, Xj+1, Xm)
@Differentiable
func foo(x: Float64, y: Float64): Float64 {
    return x * y
}

main() {
    let res = @Grad(foo, 2.0, 3.0)        // Ok
    // Compilation Error: @Grad expr must be used in var decl with an identifier
    print(@Grad(foo, 2.0, 3.0)[0].toString())
    // Compilation Error: @Grad expr must be used in var decl with an identifier
    let (gradx, grady) = @Grad(foo, 2.0, 3.0)
    return 0
}
@Differentiable
struct A {
    @Differentiable
    public func foo(x: Float64, y: Float64) {
        return x + y
    }
}

main() {
    let a = A()
    // Compilation Error: the function is not a top-level differentiable function identifier
    let temp = @Grad(a.foo, 1.0, 1.0)
    return 0
}

@ValWithGrad表达式

给定一个可微函数和相应的输入参数值,用户可使用 @ValWithGrad 表达式来获取该函数在该输入值处的结果和梯度值。

@Differentiable [except: [negate]]
func product(x: Float64, y: Float64, negate: Bool): Float64 {
    return if (negate) { - x * y } else { x * y }
}

main(): Int64 {
    let productValWithGrad = @ValWithGrad(product, 2.0, 3.0, true)
    let (productRes, productGrad) = productValWithGrad
    print(productRes.toString())        // Prints -6.000000
    print(productGrad[0].toString())    // Prints -3.000000
    print(productGrad[1].toString())    // Prints -2.000000
    return 0
}

使用@ValWithGrad表达式时有以下注意事项。

  • @ValWithGrad表达式只能作为初始值用于 var 或 let 变量初始化表达式中
  • 给定的diffFunc标识符必须表示一个函数,且满足以下条件
    • 必须是全局函数
    • 必须是可微函数
    • 函数不能有命名参数和参数默认值
    • 函数的返回类型只能是一下类型之一:Float16Float32Float64
  • inputVal组成的集合必须与diffFunc表示的函数的参数匹配,匹配规则与形为diffFunc(inputVal, ...)的函数调用匹配规则相同
  • 给定上述类型为(X1, X2, ..., Xm)->Y的函数diffFunc
    • 若其except列表为空,则@ValWithGrad表达式的类型为(Y, (X1, X2, ..., Xm))。若该Tuple类型元素数量为 0,则退化为(Y, Unit)类型,若该Tuple类型仅包含一个元素Xj,则退化为(Y, Xj)类型
    • 若其except列表不为空,则@ValWithGrad表达式的类型中需相应地剔除在except列表中的函数参数类型。假定except列表中包含的参数为Xj,则@ValWithGrad表达式的类型为(Y, (X1, X2, ..., Xj-1, Xj+1, Xm))
@Differentiable
func foo(x: Float64, y: Float64): Float64 {
    return x * y
}

main() {
    let res = @ValWithGrad(foo, 2.0, 3.0)        // Ok
    // Compilation Error: @ValWithGrad expr must be used in var decl with an identifier
    print(@ValWithGrad(foo, 2.0, 3.0)[0].toString())
    // Compilation Error: @ValWithGrad expr must be used in var decl with an identifier
    let (res, (gradx, grady)) = @ValWithGrad(foo, 2.0, 3.0)
    return 0
}
@Differentiable
struct A {
    @Differentiable
    public func foo(x: Float64, y: Float64) {
        return x + y
    }
}

main() {
    let a = A()
    // Compilation Error: the function is not a top-level differentiable function identifier
    let temp = @ValWithGrad(a.foo, 1.0, 1.0)
    return 0
}

@AdjointOf表达式

给定一个可微函数,用户还可以使用 @AdjointOf 表达式来获取对该函数微分产生的伴随函数。

  • @AdjointOf表达式只能作为初始值用于 var 或 let 变量初始化表达式中
  • 给定的diffFunc标识符必须表示一个函数,且满足以下条件
    • 必须是全局函数
    • 必须是可微函数
    • 函数不能有命名参数和参数默认值
  • 给定上述类型为(X1, X2, ..., Xm)->Y的函数diffFunc
    • 若其except列表为空,则@AdjointOf表达式的类型为(X1, X2, ..., Xm)->(Y, (Y) -> (X1, X2, ..., Xm))。若该Tuple类型元素数量为 0,则退化为(X1, X2, ..., Xm)->(Y, (Y) -> Unit)类型,若该Tuple类型仅包含一个元素Xj,则退化为(X1, X2, ..., Xm)->(Y, (Y) -> Xj)类型
    • 若其except列表不为空,则@AdjointOf表达式的类型中需相应地剔除在except列表中的函数参数类型。假定except列表中包含的参数为Xj,则@AdjointOf表达式的类型为(X1, X2, ..., Xm)->(Y, (Y) -> (X1, X2, ..., Xj-1, Xj+1, Xm))
@Differentiable
func foo(x: Float64, y: Float64): Float64 {
    return x * y
}

main() {
    // Get the adjoint function of `foo`
    let fooAdj = @AdjointOf(foo)

    // Given the value of `x` as 2.0 and `y` as 3.0, the adjoint function
    // will return:
    //     1) the result of `foo` when `x = 2.0` and `y = 3.0`
    //     2) an back-propagator which propagates the gradient from output to input for `foo`
    let res = fooAdj(2.0, 3.0)

    let fooRes = res[0]       // Prints 6.000000
    let fooBP = res[1]        // The back-propagator
    let (dx, dy) = fooBP(1.0)
    print(dx.toString())      // Prints 3.000000
    print(dy.toString())      // Prints 2.000000
}

@VJP表达式

给定一个可微函数和相应的输入参数值,用户可使用@VJP表达式来获取该函数在该输入值处的结果和反向传播函数。

使用@VJP表达式时有以下注意事项。

  • 给定的diffFunc标识符必须表示一个函数,且满足以下条件
    • 必须是全局函数
    • 必须是可微函数
    • 函数不能有命名参数和参数默认值
  • inputVal组成的集合必须与diffFunc表示的函数的参数匹配,匹配规则与形为diffFunc(inputVal, ...)的函数调用匹配规则相同
  • 给定上述类型为(X1, X2, ..., Xm)->Y的函数diffFunc
    • 若其except列表为空,则@VJP表达式的类型为(Y, (Y) -> (X1, X2, ..., Xm))。若该Tuple类型元素数量为 0,则退化为(Y, (Y) -> Unit)类型,若该Tuple类型仅包含一个元素Xj,则退化为(Y, (Y) -> Xj)类型
    • 若其except列表不为空,则@VJP表达式的类型中需相应地剔除在except列表中的函数参数类型。假定except列表中包含的参数为Xj,则@VJP表达式的类型为(Y, (Y) -> (X1, X2, ..., Xj-1, Xj+1, Xm))
@Differentiable [except: [negate]]
func product(x: Float64, y: Float64, negate: Bool): Float64 {
    return if (negate) { - x * y } else { x * y }
}

main(): Int64 {
    let productVJP = @VJP(product, 2.0, 3.0, true)
    let (productRes, productBP) = productVJP
    print(productRes.toString())        // Prints -6.000000
    let productGrad = productBP(1.0)
    print(productGrad[0].toString())    // Prints -3.000000
    print(productGrad[1].toString())    // Prints -2.000000
    return 0
}

stopGradient函数接口

在可微函数中,用户还可以使用 stopGradient 函数接口来强制中止某个变量或中间结果上的梯度传播。stopGradient 函数实现为一个泛型函数,可接受任意类型数据输入并将其直接返回。因此,该函数接口的使用不影响原可微函数的执行逻辑,但自动微分系统将识别该函数接口,并中止函数参数 x 对应变量或中间结果的梯度传播。此外,stopGradient 函数作用于函数时,可以把可微函数变成不可微函数。

public func stopGradient<T>(x: T) {
    return x
}
@Differentiable
func foo(x: Float64) {
    let t0 = x * 2.0
    let t1 = x * 3.0
    return t0 + t1                          // Both `t0` and `t1` will propagate gradient to `x`
}

@Differentiable
func goo(x: Float64) {
    let t0 = x * 2.0
    let t1 = x * 3.0
    return t0 + stopGradient<Float64>(t1)   // Only `t0` will propagate gradient to `x`
}

main() {
    let res0 = @Grad(foo, 1.0)      // `res0` will be 5.0
    let res1 = @Grad(goo, 1.0)      // `res1` will be 2.0
}

伴随函数的导入/导出

给定一个在包中定义的源函数,自动微分系统将对其微分并在该包中生成它的伴随函数。该伴随函数将具有和源函数相同的 public 属性。用户可以通过使用 import a.* 确保伴随函数与源函数拥有一致的导入/导出行为,即当用户导入源函数时其伴随函数将一并被导入,从而允许用户在当前包对该函数进行微分操作。

//================================= file A
package a

// AD system will generate `fooAdj` as the adjoint of `foo`
// `fooAdj` also has `public` attribute
@Differentiable
public func foo(x: Float64) {
    return x
}

//================================= file B
package b
// Will also import a.fooAdj implicitly
import a.*

main() {
    // The AD system will use the `a.fooAdj` as the adjoint of `a.foo` for differentiation
    let gradRes = @Grad(foo, 2.0)
    print(gradRes.toString()) // Prints 1.000000
    0
}

需要注意的是,若用户实现了自定义伴随函数,则也需要为该伴随函数手动配置 public 属性,从而确保该伴随函数也会被同步导入。除此之外,在导入/导出场景下,给定一个源函数,多个来自不同来源的伴随函数有可能同时出现。在这种情况下,我们定义了如下规则来确定不同版本伴随函数的优先级和使用规则。

  • 当前包内定义的本地伴随函数比从其他包导入的伴随函数优先级更高,后者将被前者屏蔽
  • 若出现多个从其他包导入的不同版本伴随函数且无本地伴随函数时,将触发编译器报错
//================================= file A
package a

@Differentiable
public func foo(x: Float64): Float64 {
    return x
}

@Adjoint [primal: foo]
public func dfoo(x: Float64): (Float64, (Float64)->Float64) {
    return (
        x,
        { dy: Float64 =>
            return 1.0 * dy
        }
    )
}

//================================= file B
import a.*

@Adjoint [primal: foo]
func localDFoo(x: Float64): (Float64, (Float64)->Float64) {
    return (
        x,
        { dy: Float64 =>
            return 1.0 * dy
        }
    )
}

main() {
    // The AD system will use the `localDFoo` as the
    // adjoint of `a.foo` for differentiation, since
    // it has higher priority than the `a.dfoo`
    let gradRes = @Grad(foo, 2.0)
    print(gradRes.toString())     // Prints 1.000000
    0
}
//================================= file A
package a

@Differentiable
public func foo(x: Float64): Float64 {
    return x
}

//================================= file B
package b
import a.*

@Adjoint [primal: foo]
public func dfooB(x: Float64): (Float64, (Float64)->Float64) {
    return (
        x,
        { dy: Float64 =>
            return 1.0 * dy
        }
    )
}

//================================= file C
package c
import a.*

@Adjoint [primal: foo]
public func dfooC(x: Float64): (Float64, (Float64)->Float64) {
    return (
        x,
        { dy: Float64 =>
            return 1.0 * dy
        }
    )
}

//================================= file D
import a.*, b.*, c.*

main() {
    // Compilation Error: multiple imported adjoint for function foo are found
    let grad_res = @Grad(foo, 2.0)
    0
}

高阶微分

用户可以在微分函数中使用 @Grad@ValWithGrad@AdjointOf@VJP等表达式来实现高阶微分的效果。

此时用户需要在 @Differentiable 标注中额外提供 stage 信息来标记该微分函数的最高微分阶数(若未提供则默认最高微分阶数为一阶),若阶数信息不正确则将引发编译器报错。

注:目前高阶微分最高支持 2 阶,即 stage 取值只能是 12。 阶数检查规则:

  1. stage=1 的函数中可以使用 stage=1,2 的可微函数调用表达式,可以包含 stage=2 的函数的微分表达式
  2. stage=2 的函数中可以使用 stage=2 的函数调用表达式
@Differentiable [stage: 2]
func f1(x: Float64) {
    x * x * x       // Will be differentiated as `df1/dx = 3 * x * x`
}

@Differentiable
func f2(x: Float64) {
    let dx = @Grad(f1, x)      // Will be differentiated as `df2/dx = d(df1/dx)/dx = 6 * x`
    dx
}

main() {
    let x: Float64 = 1.0
    let firstOrderGrad = @Grad(f1, x)
    let secondOrderGrad = @Grad(f2, x)
    println(firstOrderGrad.toString())     // Prints 3.000000
    println(secondOrderGrad.toString())    // Prints 6.000000
    return 0
}

需要注意的是,不允许在微分函数中对函数本身或将调用函数本身的函数使用 @Grad@ValWithGrad@AdjointOf@VJP 表达式,否则将产生循环依赖从而引发编译器报错。