自动微分
前言
自动微分是一种对程序中的函数计算导数的技术。相比符号微分,其避免了表达式膨胀的性能问题;而相比数值微分,其解决了近似误差的正确性问题。
关于自动微分技术本身的相关背景知识,我们将不在本手册中进行详细介绍。推荐用户阅读以下参考文献进行详细了解。
- How to Differentiate with a Computer
- Automatic Differentiation in Machine Learning: a Survey
- Evaluating Derivatives: Principles and Techniques of Algorithmic Differentiation
在仓颉编程语言中,自动微分将作为原生语言特性被提供给用户,用户可以通过配置编译选项 --enable-ad
来开启自动微分特性。
cjc --enable-ad helloworld.cj
需要说明的是,目前仓颉自动微分特性仅支持反向模式自动微分。
可微类型
在仓颉自动微分的库中,我们提供了Differentiable
interface 用于定义仓颉中数据类型的微分规则。所有参与自动微分中导数计算的数据类型均应实现该 interface 以确保被自动微分系统认为是合法的可微类型。
// Defined in the AD library
public interface Differentiable<T> where T <: Differentiable<T> {
// Initialize a zero tangent
func TangentZero(): T
// Sum up the tangent value `y`
func TangentAdd(y: T): T
}
对于可微类型(例如浮点数类型、元组类型、struct
类型),编译器会自动生成合理的 interface 实例,对于可微struct
类型,用户也可以通过手动实现该 interface 实例。
可微数值类型
仓颉中的可微数值类型包括三种:
Float16
Float32
Float64
可微元组类型
当元组中所有元素均为可微类型时,则该元组是可微元组类型。
let a: (Float64, Float64) = (1.0, 1.0) // differentiable
let b: (Float64, String) = (1.0, "foo") // NOT differentiable
可微 struct 类型
默认情况下,struct
类型不可微,当 struct
类型不包含静态成员变量时,用户可在 struct
类型定义上方增加 @Differentiable
标注,将其定义为可微并通过 except
列表配置该 struct
类型的微分行为。
给定 except
列表后,该 struct
类型的成员变量将分为两类。
- 不在
except
列表中的成员变量,必须为不可变变量,并且其类型为可微类型,成员变量将参与该struct
类型对象的微分过程 - 在
except
列表中的成员变量将不参与该struct
类型对象的微分过程。在微分过程中,该成员变量将被保持为未初始化状态。因此用户应确保不访问微分结果中的这些成员变量值,否则将导致未定义行为
// Differentiable
@Differentiable
struct Point {
let x: Float64
let y: Float64
init(x: Float64, y: Float64) {
this.x = x
this.y = y
}
}
// Differentiable, but no back-propagation will happen in the excepted field `tag`
@Differentiable [except: [tag]]
struct TaggedPoint {
let x: Float64
let y: Float64
let tag: String
init(x: Float64, y: Float64, tag: String) {
this.x = x
this.y = y
this.tag = tag
}
}
// Compilation Error: variable `tag` has non-differentiable type String,
// but it does not appear in an except list
@Differentiable
struct TaggedWrong {
let x: Float64
let y: Float64
let tag: String
init(x: Float64, y: Float64, tag: String) {
this.x = x
this.y = y
this.tag = tag
}
}
我们也提供了定义 except
列表的另外一个语法糖版本,用户可以通过以下语法定义 struct
类型的 include
列表。此时,struct
所有不在该 include
列表中的成员变量都被定义为在 except
列表中。
// Differentiable, but no back-propagation will happen in the excepted field `tag`
@Differentiable [include: [x, y]]
struct TaggedPoint {
let x: Float64
let y: Float64
let tag: String
init(x: Float64, y: Float64, tag: String) {
this.x = x
this.y = y
this.tag = tag
}
}
可微 unit 类型
Unit
类型是一种特殊的可微类型。对任何 Unit
类型对象的微分操作所得结果均仍是 Unit
类型。
不可微类型
仓颉自动微分暂不支持对 String
、Array
、enum
、Class
、 Int16
、 Int32
、 Int64
和 Interface
类型数据的微分,故这些类型均为不可微类型。
可微函数
默认情况下,函数均为不可微,但用户可以在函数定义上方增加 @Differentiable
标注,将其定义为可微并通过 except
列表配置该函数的微分行为。
给定 except
列表后,该函数的参数将分为两类。
- 不在
except
列表中的参数,必须为可微类型。该参数将参与函数的微分过程,微分结果中将包含函数相对该参数的导数 - 在
except
列表中的参数将不参与函数的微分过程。在微分过程中,该类参数及依赖该参数的中间变量均将被忽略,微分结果中将不包含函数相对该参数的导数
// Differentiable function, its derivatives will be calculated with respect to `x` and `y`
@Differentiable
func f(x: Float64, y: Float64) {
return x * y
}
// Differentiable function, its derivative will only be calculated with respect to `x`
@Differentiable [except: [y]]
func f(x: Float64, y: Float64) {
return x * y
}
// Compilation Error: input `z` has non-differentiable type String,
// but it does not appear in an except list
@Differentiable
func f(x: Float64, y: Float64, z: String) {
return x + y
}
相似地,我们也提供了定义 except
列表的另外一个语法糖版本,用户可以通过以下语法定义函数的 include
列表。此时,函数所有不在该 include
列表中的参数都被定义为在 except
列表中。except
和 include
只能出现一个。
// Differentiable function, its derivative will only be calculated with respect to `x`
@Differentiable [include: [x]]
func f(x: Float64, y: Float64) {
return x * y
}
在使用上述语法进行可微函数定义时,用户还需要确保函数满足以下条件。
- 函数的返回类型为可微类型
- 函数参数为可微类型,且不在
except
列表中(或在include
列表中) - 函数中未使用全局变量
- 函数中未使用静态变量
- 函数体中所有表达式均为可微函数合法表达式,即表达式需满足以下任一条件
- 表达式可微,即其微分规则已知。目前仓颉支持的可微表达式如下:
- 赋值表达式
- 算术表达式
- 流表达式
- 条件表达式(
if
) - 循环表达式(仅
while
,do-while
),并且不支持使用continue
,break
,return
- lambda 表达式
- 可微
Tuple
类型对象的初始化表达式 - 可微
Tuple
类型对象的解构和下标访问表达式 - 对可微函数的函数调用表达式,且在编译期可以确定哪一个可微函数定义被调用
- 表达式不直接或间接地与不在
except
列表中的函数参数(包括嵌套函数的外层作用域函数参数)有数据依赖关系。若必须产生数据依赖关系,则可以使用 [stopGradient
函数接口] 来中止梯度传播 - 可微函数(嵌套函数或匿名函数)中
return
表达式(包括自动插入的return
表达式)只能出现一次 - 嵌套函数或匿名函数未捕获可变变量
- 表达式可微,即其微分规则已知。目前仓颉支持的可微表达式如下:
// Not differentiable
func f1(x: Float64, y: Float64): Float64 {
return x + y
}
// Compilation Error: expression x is marked for back-propagation,
// but is used in a context that cannot be back-propagated
@Differentiable
func f2(x: Float64, y: Float64): Float64 {
return f1(x, y)
}
// Compilation Error: match expression is not supported in differentiable function
@Differentiable
func f3(x: Float64) {
var a = match (x) {
case _ => 1.0
}
return a
}
@Differentiable
func f4 (x: Float64) {
var y = 1.0
// Compilation Error: capture mutable variable y in nested function is not supported by AD yet
func nested (x: Float64) {
y = y + x
}
nested(x)
}
@Differentiable [except: [n]]
func g(m: Float64, n: Float64) {
return m + n
}
@Differentiable
func f5(x: Float64, y: Float64) {
// Use `stopGradient` to stop the gradient propagation so we can
// pass the `y` in `f5` to `n` in `g` even `n` is marked as `except`
return g(x, stopGradient<Float64>(y))
}
func g(m: Float64) {
return m == 1.0
}
@Differentiable
func f6(x: Float64) {
// Use `stopGradient` to stop the gradient propagation so we can
// pass the `x` in `f6` to non-differentiable function `g`
let cond = g(stopGradient<Float64>(x))
if (cond) {
x * 2.0
} else {
x * 3.0
}
}
注意:为了方便计算 lambda 表达式捕获变量的导数,我们在仓颉自动微分的库中预定义了以下两个接口,这两个接口将被仓颉的自动微分实现模块调用,不建议用户直接使用。
/* lambda 表达式捕获变量的导数组成的环境变量类型的基类 */
public abstract class EnvironmentTangent {
/* 合并两部分环境变量导数 */
public func TangentAdd(other: EnvironmentTangent): EnvironmentTangent
}
/* 环境变量类型对应的 0 梯度类型 */
public class EnvironmentZero <: EnvironmentTangent {
public func TangentAdd(other: EnvironmentTangent): EnvironmentTangent {
return other
}
}
自定义函数微分规则
用户还可以通过手动提供伴随函数来为可微函数自定义微分规则。给定一个可微的全局函数 f
(称为:源函数),用户可通过在另一个函数 g
的函数定义上方添加 @Adjoint
标注,将其指定为 f
的自定义伴随函数。
伴随函数需满足以下条件:
- 伴随函数的输入参数的数量、类型和顺序与原可微函数完全一致
- 伴随函数的输出是一个包含两个子元素的元组
- 元组第一个子元素为原可微函数在当前伴随函数的参数输入下的输出结果
- 元组第二个子元素为一个函数,作为原可微函数的梯度反向传播器
- 梯度反向传播器的输入类型与原可微函数的输出数量、类型和顺序一致
- 梯度反向传播器的输出类型与原可微函数的输入数量、类型和顺序一致
// Function `f` is differentiable and its custom adjoint function is `g`
@Differentiable
func f(x: Float64): Float64 {
return x * x * x
}
@Adjoint [primal: f]
func g(x: Float64): ((Float64), ((Float64) -> Float64)) {
let xSquare = x * x
return (
xSquare * x,
{ dy: Float64 =>
return dy * xSquare * 3.0
}
)
}
当用户通过上述语法为源函数自定义伴随函数后,自动微分系统将直接使用该自定义伴随函数进行微分求导。需要注意的是我们当前暂不支持用户对参数或返回值包含函数类型的可微函数自定义伴随函数。
非全局可微函数
需要注意的是,上述可微函数规则默认应用于全局函数。事实上,仓颉编程语言中存在多种非全局函数。在自动微分系统中,我们支持以下非全局函数被定义为可微函数,其他类型的非全局函数均不支持定义为可微函数。
struct 构造函数和成员函数
用户可以使用相同的可微函数标注 @Differentiable
将 struct
中的构造函数和成员函数定义为可微。
struct
构造函数标注为可微。定义构造函数为可微的前提是struct
类型本身可微。在这种情况下,调用该构造函数的struct
初始化表达式可微。相反地,若构造函数未被标注为可微,则调用该构造函数的struct
初始化表达式不可微struct
成员函数标注为可微。根据仓颉语言定义,struct
成员函数隐藏包含了一个标识符为this
的函数参数用于表示struct
对象本身,故用户也可以在except
列表中配置this
参数来决定是否需要对this
进行微分。
@Differentiable
struct Point {
let a: Float64
let b: Float64
// Differentiable constructor
@Differentiable
init(x: Float64) {
a = x
b = x
}
}
@Differentiable
func foo(x: Float64) {
// The initialization will call the differentiable constructor.
// Therefore it is differentiable here.
return Point(x)
}
@Differentiable
struct Point {
let a: Float64
let b: Float64
init(x: Float64) {
a = x
b = x
}
// Differentiable method
@Differentiable
public func sum(bias: Float64): Float64 {
return a + b + bias
}
}
Class 成员函数
用户也可以使用相同的可微函数标注 @Differentiable
将 class
中的非open
成员函数定义为可微。同样根据仓颉语言定义,class
成员函数隐藏包含了一个标识符为 this
的函数参数用于表示 class
对象本身。由于在仓颉中由于 class
类型本身不可微,故除静态成员函数外,用户必须将 this
加入到可微成员函数的 except
列表中。另外需要注意的是,返回类型为This
的成员函数因返回类型不可微,所以无法被定义成可微。
class Point {
let a: Float64
let b: Float64
init(x: Float64) {
a = x
b = x
}
// Function foo can not be defined as differentiable
func foo(): This {
return this
}
// Differentiable method. But user must add `this` in the `except` list here.
@Differentiable [except: [this]]
func sum(bias: Float64): Float64 {
return a + b + bias
}
}
拓展的成员函数
用户也可以使用相同的可微函数标注语法将扩展的成员函数定义为可微,支持直接扩展和接口扩展。
interface I {
func goo(a: Float64, b: Float64): Float64
}
@Differentiable
struct Foo {}
// direct extensions
extend Foo {
@Differentiable
func foo(a: Float64, b: Float64) {
return a + b
}
}
// interface extensions
extend Foo <: I {
@Differentiable
public func goo(a: Float64, b: Float64) {
return a + b
}
}
如果需要扩展的类型不可微,用户必须将this
加入到扩展的可微成员函数的except
列表中。
interface I {
func goo(a: Float64, b: Float64): Float64
}
struct Foo {}
// direct extensions
extend Foo {
@Differentiable[except: [this]]
func foo(a: Float64, b: Float64) {
return a + b
}
}
// interface extensions
extend Foo <: I {
@Differentiable[except: [this]]
public func goo(a: Float64, b: Float64) {
return a + b
}
}
微分表达式
@Grad
表达式
给定一个可微函数和相应的输入参数值,用户可使用 @Grad
表达式来获取该函数在该输入值处的梯度值。
@Differentiable [except: [negate]]
func product(x: Float64, y: Float64, negate: Bool): Float64 {
return if (negate) { - x * y } else { x * y }
}
main(): Int64 {
// Since `negate` is excepted, the gradient of product only has two components
let productGrad = @Grad(product, 2.0, 3.0, true)
print(productGrad[0].toString()) // Prints -3.000000
print(productGrad[1].toString()) // Prints -2.000000
return 0
}
使用@Grad
表达式时有以下注意事项
@Grad
表达式只能作为初始值用于 var 或 let 变量初始化表达式中- 给定的
diffFunc
标识符必须表示一个函数,且满足以下条件- 必须是全局函数
- 必须是可微函数
- 函数不能有命名参数和参数默认值
- 函数的返回类型只能是一下类型之一:
Float16
,Float32
,Float64
inputVal
组成的集合必须与diffFunc
表示的函数的参数匹配,匹配规则与形为diffFunc(inputVal, ...)
的函数调用匹配规则相同- 给定上述类型为
(X1, X2, ..., Xm)->Y
的函数diffFunc
- 若其
except
列表为空,则@Grad
表达式的类型为(X1, X2, ..., Xm)
。若该Tuple
类型元素数量为 0,则退化为Unit
类型,若该Tuple
类型仅包含一个元素Xj
,则退化为Xj
类型 - 若其
except
列表不为空,则@Grad
表达式的类型中需相应地剔除在except
列表中的函数参数类型。假定except
列表中包含的参数为Xj
,则@Grad
表达式的类型为(X1, X2, ..., Xj-1, Xj+1, Xm)
- 若其
@Differentiable
func foo(x: Float64, y: Float64): Float64 {
return x * y
}
main() {
let res = @Grad(foo, 2.0, 3.0) // Ok
// Compilation Error: @Grad expr must be used in var decl with an identifier
print(@Grad(foo, 2.0, 3.0)[0].toString())
// Compilation Error: @Grad expr must be used in var decl with an identifier
let (gradx, grady) = @Grad(foo, 2.0, 3.0)
return 0
}
@Differentiable
struct A {
@Differentiable
public func foo(x: Float64, y: Float64) {
return x + y
}
}
main() {
let a = A()
// Compilation Error: the function is not a top-level differentiable function identifier
let temp = @Grad(a.foo, 1.0, 1.0)
return 0
}
@ValWithGrad
表达式
给定一个可微函数和相应的输入参数值,用户可使用 @ValWithGrad
表达式来获取该函数在该输入值处的结果和梯度值。
@Differentiable [except: [negate]]
func product(x: Float64, y: Float64, negate: Bool): Float64 {
return if (negate) { - x * y } else { x * y }
}
main(): Int64 {
let productValWithGrad = @ValWithGrad(product, 2.0, 3.0, true)
let (productRes, productGrad) = productValWithGrad
print(productRes.toString()) // Prints -6.000000
print(productGrad[0].toString()) // Prints -3.000000
print(productGrad[1].toString()) // Prints -2.000000
return 0
}
使用@ValWithGrad
表达式时有以下注意事项。
@ValWithGrad
表达式只能作为初始值用于 var 或 let 变量初始化表达式中- 给定的
diffFunc
标识符必须表示一个函数,且满足以下条件- 必须是全局函数
- 必须是可微函数
- 函数不能有命名参数和参数默认值
- 函数的返回类型只能是一下类型之一:
Float16
,Float32
,Float64
inputVal
组成的集合必须与diffFunc
表示的函数的参数匹配,匹配规则与形为diffFunc(inputVal, ...)
的函数调用匹配规则相同- 给定上述类型为
(X1, X2, ..., Xm)->Y
的函数diffFunc
- 若其
except
列表为空,则@ValWithGrad
表达式的类型为(Y, (X1, X2, ..., Xm))
。若该Tuple
类型元素数量为 0,则退化为(Y, Unit)
类型,若该Tuple
类型仅包含一个元素Xj
,则退化为(Y, Xj)
类型 - 若其
except
列表不为空,则@ValWithGrad
表达式的类型中需相应地剔除在except
列表中的函数参数类型。假定except
列表中包含的参数为Xj
,则@ValWithGrad
表达式的类型为(Y, (X1, X2, ..., Xj-1, Xj+1, Xm))
- 若其
@Differentiable
func foo(x: Float64, y: Float64): Float64 {
return x * y
}
main() {
let res = @ValWithGrad(foo, 2.0, 3.0) // Ok
// Compilation Error: @ValWithGrad expr must be used in var decl with an identifier
print(@ValWithGrad(foo, 2.0, 3.0)[0].toString())
// Compilation Error: @ValWithGrad expr must be used in var decl with an identifier
let (res, (gradx, grady)) = @ValWithGrad(foo, 2.0, 3.0)
return 0
}
@Differentiable
struct A {
@Differentiable
public func foo(x: Float64, y: Float64) {
return x + y
}
}
main() {
let a = A()
// Compilation Error: the function is not a top-level differentiable function identifier
let temp = @ValWithGrad(a.foo, 1.0, 1.0)
return 0
}
@AdjointOf
表达式
给定一个可微函数,用户还可以使用 @AdjointOf
表达式来获取对该函数微分产生的伴随函数。
@AdjointOf
表达式只能作为初始值用于 var 或 let 变量初始化表达式中- 给定的
diffFunc
标识符必须表示一个函数,且满足以下条件- 必须是全局函数
- 必须是可微函数
- 函数不能有命名参数和参数默认值
- 给定上述类型为
(X1, X2, ..., Xm)->Y
的函数diffFunc
- 若其
except
列表为空,则@AdjointOf
表达式的类型为(X1, X2, ..., Xm)->(Y, (Y) -> (X1, X2, ..., Xm))
。若该Tuple
类型元素数量为 0,则退化为(X1, X2, ..., Xm)->(Y, (Y) -> Unit)
类型,若该Tuple
类型仅包含一个元素Xj
,则退化为(X1, X2, ..., Xm)->(Y, (Y) -> Xj)
类型 - 若其
except
列表不为空,则@AdjointOf
表达式的类型中需相应地剔除在except
列表中的函数参数类型。假定except
列表中包含的参数为Xj
,则@AdjointOf
表达式的类型为(X1, X2, ..., Xm)->(Y, (Y) -> (X1, X2, ..., Xj-1, Xj+1, Xm))
- 若其
@Differentiable
func foo(x: Float64, y: Float64): Float64 {
return x * y
}
main() {
// Get the adjoint function of `foo`
let fooAdj = @AdjointOf(foo)
// Given the value of `x` as 2.0 and `y` as 3.0, the adjoint function
// will return:
// 1) the result of `foo` when `x = 2.0` and `y = 3.0`
// 2) an back-propagator which propagates the gradient from output to input for `foo`
let res = fooAdj(2.0, 3.0)
let fooRes = res[0] // Prints 6.000000
let fooBP = res[1] // The back-propagator
let (dx, dy) = fooBP(1.0)
print(dx.toString()) // Prints 3.000000
print(dy.toString()) // Prints 2.000000
}
@VJP
表达式
给定一个可微函数和相应的输入参数值,用户可使用@VJP
表达式来获取该函数在该输入值处的结果和反向传播函数。
使用@VJP
表达式时有以下注意事项。
- 给定的
diffFunc
标识符必须表示一个函数,且满足以下条件- 必须是全局函数
- 必须是可微函数
- 函数不能有命名参数和参数默认值
inputVal
组成的集合必须与diffFunc
表示的函数的参数匹配,匹配规则与形为diffFunc(inputVal, ...)
的函数调用匹配规则相同- 给定上述类型为
(X1, X2, ..., Xm)->Y
的函数diffFunc
- 若其
except
列表为空,则@VJP
表达式的类型为(Y, (Y) -> (X1, X2, ..., Xm))
。若该Tuple
类型元素数量为 0,则退化为(Y, (Y) -> Unit)
类型,若该Tuple
类型仅包含一个元素Xj
,则退化为(Y, (Y) -> Xj)
类型 - 若其
except
列表不为空,则@VJP
表达式的类型中需相应地剔除在except
列表中的函数参数类型。假定except
列表中包含的参数为Xj
,则@VJP
表达式的类型为(Y, (Y) -> (X1, X2, ..., Xj-1, Xj+1, Xm))
- 若其
@Differentiable [except: [negate]]
func product(x: Float64, y: Float64, negate: Bool): Float64 {
return if (negate) { - x * y } else { x * y }
}
main(): Int64 {
let productVJP = @VJP(product, 2.0, 3.0, true)
let (productRes, productBP) = productVJP
print(productRes.toString()) // Prints -6.000000
let productGrad = productBP(1.0)
print(productGrad[0].toString()) // Prints -3.000000
print(productGrad[1].toString()) // Prints -2.000000
return 0
}
stopGradient
函数接口
在可微函数中,用户还可以使用 stopGradient
函数接口来强制中止某个变量或中间结果上的梯度传播。stopGradient
函数实现为一个泛型函数,可接受任意类型数据输入并将其直接返回。因此,该函数接口的使用不影响原可微函数的执行逻辑,但自动微分系统将识别该函数接口,并中止函数参数 x
对应变量或中间结果的梯度传播。此外,stopGradient
函数作用于函数时,可以把可微函数变成不可微函数。
public func stopGradient<T>(x: T) {
return x
}
@Differentiable
func foo(x: Float64) {
let t0 = x * 2.0
let t1 = x * 3.0
return t0 + t1 // Both `t0` and `t1` will propagate gradient to `x`
}
@Differentiable
func goo(x: Float64) {
let t0 = x * 2.0
let t1 = x * 3.0
return t0 + stopGradient<Float64>(t1) // Only `t0` will propagate gradient to `x`
}
main() {
let res0 = @Grad(foo, 1.0) // `res0` will be 5.0
let res1 = @Grad(goo, 1.0) // `res1` will be 2.0
}
伴随函数的导入/导出
给定一个在包中定义的源函数,自动微分系统将对其微分并在该包中生成它的伴随函数。该伴随函数将具有和源函数相同的 public
属性。用户可以通过使用 import a.*
确保伴随函数与源函数拥有一致的导入/导出行为,即当用户导入源函数时其伴随函数将一并被导入,从而允许用户在当前包对该函数进行微分操作。
//================================= file A
package a
// AD system will generate `fooAdj` as the adjoint of `foo`
// `fooAdj` also has `public` attribute
@Differentiable
public func foo(x: Float64) {
return x
}
//================================= file B
package b
// Will also import a.fooAdj implicitly
import a.*
main() {
// The AD system will use the `a.fooAdj` as the adjoint of `a.foo` for differentiation
let gradRes = @Grad(foo, 2.0)
print(gradRes.toString()) // Prints 1.000000
0
}
需要注意的是,若用户实现了自定义伴随函数,则也需要为该伴随函数手动配置 public
属性,从而确保该伴随函数也会被同步导入。除此之外,在导入/导出场景下,给定一个源函数,多个来自不同来源的伴随函数有可能同时出现。在这种情况下,我们定义了如下规则来确定不同版本伴随函数的优先级和使用规则。
- 当前包内定义的本地伴随函数比从其他包导入的伴随函数优先级更高,后者将被前者屏蔽
- 若出现多个从其他包导入的不同版本伴随函数且无本地伴随函数时,将触发编译器报错
//================================= file A
package a
@Differentiable
public func foo(x: Float64): Float64 {
return x
}
@Adjoint [primal: foo]
public func dfoo(x: Float64): (Float64, (Float64)->Float64) {
return (
x,
{ dy: Float64 =>
return 1.0 * dy
}
)
}
//================================= file B
import a.*
@Adjoint [primal: foo]
func localDFoo(x: Float64): (Float64, (Float64)->Float64) {
return (
x,
{ dy: Float64 =>
return 1.0 * dy
}
)
}
main() {
// The AD system will use the `localDFoo` as the
// adjoint of `a.foo` for differentiation, since
// it has higher priority than the `a.dfoo`
let gradRes = @Grad(foo, 2.0)
print(gradRes.toString()) // Prints 1.000000
0
}
//================================= file A
package a
@Differentiable
public func foo(x: Float64): Float64 {
return x
}
//================================= file B
package b
import a.*
@Adjoint [primal: foo]
public func dfooB(x: Float64): (Float64, (Float64)->Float64) {
return (
x,
{ dy: Float64 =>
return 1.0 * dy
}
)
}
//================================= file C
package c
import a.*
@Adjoint [primal: foo]
public func dfooC(x: Float64): (Float64, (Float64)->Float64) {
return (
x,
{ dy: Float64 =>
return 1.0 * dy
}
)
}
//================================= file D
import a.*, b.*, c.*
main() {
// Compilation Error: multiple imported adjoint for function foo are found
let grad_res = @Grad(foo, 2.0)
0
}
高阶微分
用户可以在微分函数中使用 @Grad
,@ValWithGrad
,@AdjointOf
,@VJP
等表达式来实现高阶微分的效果。
此时用户需要在 @Differentiable
标注中额外提供 stage
信息来标记该微分函数的最高微分阶数(若未提供则默认最高微分阶数为一阶),若阶数信息不正确则将引发编译器报错。
注:目前高阶微分最高支持 2 阶,即
stage
取值只能是1
和2
。 阶数检查规则:
stage=1
的函数中可以使用stage=1,2
的可微函数调用表达式,可以包含stage=2
的函数的微分表达式stage=2
的函数中可以使用stage=2
的函数调用表达式
@Differentiable [stage: 2]
func f1(x: Float64) {
x * x * x // Will be differentiated as `df1/dx = 3 * x * x`
}
@Differentiable
func f2(x: Float64) {
let dx = @Grad(f1, x) // Will be differentiated as `df2/dx = d(df1/dx)/dx = 6 * x`
dx
}
main() {
let x: Float64 = 1.0
let firstOrderGrad = @Grad(f1, x)
let secondOrderGrad = @Grad(f2, x)
println(firstOrderGrad.toString()) // Prints 3.000000
println(secondOrderGrad.toString()) // Prints 6.000000
return 0
}
需要注意的是,不允许在微分函数中对函数本身或将调用函数本身的函数使用 @Grad
,@ValWithGrad
,@AdjointOf
, @VJP
表达式,否则将产生循环依赖从而引发编译器报错。