JIT 加速

仓颉 TensorBoost 提供独立的 jit 函数对 Tensor 运算进行加速。jit 函数为高阶函数形式,入参和返回值类型都是函数。返回的新函数相比原函数使能更多底层优化。返回的新函数在第一次调用的时候,会尝试进行图编译,然后执行编译的图得到运算结果。后续调用该函数会复用之前编译的图,达到加速效果。

package ops

// 通用接口(tensorFunc入参为tensor数组,可由任意个数tensor组成)
public func jit(tensorFunc: (Array<Tensor>)->Tensor): (Array<Tensor>)->Tensor

// 特例化接口(tensorFunc入参接受1个、2个、3个、4个、5个tensor)
public func jit(tensorFunc: (Tensor)->Tensor): (Tensor)->Tensor
public func jit(tensorFunc: (Tensor, Tensor)->Tensor): (Tensor, Tensor)->Tensor
public func jit(tensorFunc: (Tensor, Tensor, Tensor)->Tensor): (Tensor, Tensor, Tensor)->Tensor
public func jit(tensorFunc: (Tensor, Tensor, Tensor, Tensor)->Tensor): (Tensor, Tensor, Tensor, Tensor)->Tensor
public func jit(tensorFunc: (Tensor, Tensor, Tensor, Tensor, Tensor)->Tensor): (Tensor, Tensor, Tensor, Tensor, Tensor)->Tensor

使用 JIT 加速功能一般分为两步。

  • 将某函数(可以是个匿名函数)传给 jit 函数得到新函数,新函数和原函数的输入、输出类型保持不变;
  • 调用新函数执行运算,调用方式跟正常函数调用一样,可以多次调用,每次传入相应的 tensor 数据。

使用示例:

// 示例1:对普通函数进行加速
func foo(x: Tensor, y: Tensor): Tensor { return x + y }
let jittedFoo = jit(foo)


// 示例2:对lambda函数进行加速
let jittedLambdaFunc = jit({x, y => x + y})


// 示例3:对包含非Tensor入参的函数进行加速(方法:封装偏函数)
func bar(x: Tensor, y: Tensor, attr: Int64): Tensor {
    return if (attr > 0) { x + y } else { x - y }
}

func partialBar(attr: Int64): (Tensor, Tensor)->Tensor {
    return {x: Tensor, y: Tensor => bar(x, y, attr)}
}

let jittedAddFunc = jit(partialBar(1))
let jittedSubFunc = jit(partialBar(0))

注意以下约束条件,如有违反会导致 jit 返回的新函数被调用时抛异常:

  • 调用新函数之前,不能存在一些 tensor 尚未求值;
  • 新函数调用的入参 tensor 必须是有值的,不能是 parameter;
  • 函数返回值必须是来自某个算子,不能是常量 tensor 或 parameter。

另外,仓颉 TensorBoost 还提供更底层的加速 API runWith,可以免去jit API 产生的缓存查询开销,直接基于底层图执行加速。

package ops

public interface Jitable {
    // 无参计算图
    public func runWith(tensorFunc: ()->Tensor): Tensor
    // 有参计算图
    public func runWith(ins: Array<Tensor>, tensorFunc: (Array<Tensor>)->Tensor): Tensor
}

// 仓颉 TensorBoost 计算图数据类型
public class Graph {
    public init()
}

extend Graph <: Jitable {
    public func runWith(tensorFunc: ()->Tensor): Tensor
    public func runWith(ins: Array<Tensor>, tensorFunc: (Array<Tensor>)->Tensor): Tensor
}

使用示例:

let g0 = Graph()
var acc = Tensor(0.0)
for (i in 0..4) {
    acc = g0.runWith([acc]) {ins => ins[0] + Tensor(1.0)}
}

jit 底层调用 runWith。比较两个 API 之间异同:

【相同点】

都是高阶函数形式,支持对传入的函数以静态图方式执行,底层和某个图对象进行绑定起到图编译加速作用。

【不同点】

  • runWith 不会从动态图切换到静态图。如果上下文是动态图模式,runWith 会按照动态图方式执行。而 jit 不管上下文,固定对传入函数进行静态图优化。
  • runWith 基于 Graph 绑定某段 tensor 运算,绑定成功后可以直接运行该图得到结果。jit 的图对象缓存在全局 hash 表,通过 tensor shape、dtype 等查找图缓存。因此,runWithjit 更高效,介意图缓存开销的可以选择用 runWith