在 Julia 中对符号矩阵进行羔羊化

计算科学 微分方程 符号计算 朱莉娅 同情
2021-12-13 02:50:12

如果我有一个如下定义的符号矩阵T,是否有任何方法可以将其作为变量的函数,例如σ...,并以不分配的方式返回一个矩阵(除了返回的矩阵的分配)?

using SymPy, BenchmarkTools

σ = [symbols("σ_$i$j") for i in 1:4, j in 1:4];
T = Array{Sym}(undef,4,4);
for i in 1:4
    for j in 1:4
        T[i,j] = 1 + σ[i,j];
    end
end

直接使用SymPy'slambdifyT导致分配,而lambdify在单个元素上T不会:

f_mat = lambdify(T, invoke_latest=false);
@benchmark $f_mat(1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1)
BenchmarkTools.Trial: 
  memory estimate:  5.20 KiB
  allocs estimate:  76
  --------------
  minimum time:     17.900 μs (0.00% GC)
  median time:      19.699 μs (0.00% GC)
  mean time:        25.533 μs (22.15% GC)
  maximum time:     44.889 ms (99.92% GC)
  --------------
  samples:          10000
  evals/sample:     1

f = lambdify(T[1,1], invoke_latest=false)
@benchmark f(1)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     13.726 ns (0.00% GC)
  median time:      14.729 ns (0.00% GC)
  mean time:        14.931 ns (0.00% GC)
  maximum time:     75.751 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     998

对于上下文,我正在尝试使用 DifferentialEquations.jl 实现一组 ODE,并且正在寻找一种快速方法来更新表示 ODE 的矩阵(具体来说,我希望在 ODE 随时间变化时更新它们)。我当然可以简单地循环 ODElambdify的所有单个函数T,但我的理解是,这种方法通常会随着 ODE 的数量(或等效地,中的元素数量T)线性缩放。一个警告是,各个功能在T一般来说可能会有很大的不同,所以我不确定这里可以使用任何矢量化,所以非矢量化循环可能是我能做的最好的。我对 Julia 和这类元编程都很陌生,所以我很可能以错误的方式思考这个问题,如果有人能在这里阐明我的问题,那就太好了。谢谢!

1个回答

在这种情况下,很难绕过 SymPy 隐含的分配。它想要分配矩阵,所以最简单的做法是,如您所展示的,构建单独的标量函数。但是将它们组合在一起可能会有点麻烦,因为您不想将它们放入数组中,因为它们都是不同的类型,这会破坏循环调用的潜在优化。这意味着您还必须在其上创建函数和元程序的大元组,以便通过直接构建索引元组来完全摆脱分配......这并不有趣,所以我不打算在那里展示代码但是,如果您真的需要这样做,那么解释就足够了。

进行这种元编程但最终使用非分配 ODE 函数的更简单方法是通过 ModelingToolkit。DifferentialEquations.jl DSLs 最近从 SymPy/SymEngine 切换到作为后端的 ModelingToolkit,DifferentialEquations.jl 文档很快就会建议使用这种方法来处理此类问题,因此我将解释建议的工作流程。与 SymPy 类似,您只需以编程方式创建符号表达式:

using ModelingToolkit, BenchmarkTools

@variables t σ[1:4,1:4](t)
@derivatives D'~t
eqs = Array{Equation}(undef,4,4)
for i in 1:4
    for j in 1:4
        eqs[i,j] =  D(σ[i,j]) ~ 1 + σ[i,j]
    end
end

从那里你可以告诉它从ODESystem方程中构建一个并告诉它为DifferentialEquations.jl构建原语:

sys = ODESystem(vec(eqs))
f = ODEFunction(sys)

这些原语被制成非分配快速函数:

u = rand(4,4)
du = similar(u)
@benchmark f(du,u,nothing,0.0)

BenchmarkTools.Trial:
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     78.041 ns (0.00% GC)
  median time:      91.340 ns (0.00% GC)
  mean time:        108.952 ns (0.00% GC)
  maximum time:     829.795 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     970

你去那里,你也可以直接调查生成的代码:

generate_function(sys)[2]

## Generated:
:((var"##MTIIPVar#353", var"##MTKArg#349", var"##MTKArg#350", var"##MTKArg#351")->begin
          @inbounds begin
                  let (σ₁ˏ₁, σ₂ˏ₁, σ₃ˏ₁, σ₄ˏ₁, σ₁ˏ₂, σ₂ˏ₂, σ₃ˏ₂, σ₄ˏ₂, σ₁ˏ₃, σ₂ˏ₃, σ₃ˏ₃, σ₄ˏ₃, σ₁ˏ₄, σ₂ˏ₄, σ₃ˏ₄, σ₄ˏ₄, t) = (var"##MTKArg#349"[1], var"##MTKArg#349"[2], var"##MTKArg#349"[3], var"##MTKArg#349"[4], var"##MTKArg#349"[5], var"##MTKArg#349"[6], var"##MTKArg#349"[7], var"##MTKArg#349"[8], var"##MTKArg#349"[9], var"##MTKArg#349"[10], var"##MTKArg#349"[11], var"##MTKArg#349"[12], var"##MTKArg#349"[13], var"##MTKArg#349"[14], var"##MTKArg#349"[15], var"##MTKArg#349"[16], var"##MTKArg#351")
                      var"##MTIIPVar#353"[1] = 1 + σ₁ˏ₁
                      var"##MTIIPVar#353"[2] = 1 + σ₂ˏ₁
                      var"##MTIIPVar#353"[3] = 1 + σ₃ˏ₁
                      var"##MTIIPVar#353"[4] = 1 + σ₄ˏ₁
                      var"##MTIIPVar#353"[5] = 1 + σ₁ˏ₂
                      var"##MTIIPVar#353"[6] = 1 + σ₂ˏ₂
                      var"##MTIIPVar#353"[7] = 1 + σ₃ˏ₂
                      var"##MTIIPVar#353"[8] = 1 + σ₄ˏ₂
                      var"##MTIIPVar#353"[9] = 1 + σ₁ˏ₃
                      var"##MTIIPVar#353"[10] = 1 + σ₂ˏ₃
                      var"##MTIIPVar#353"[11] = 1 + σ₃ˏ₃
                      var"##MTIIPVar#353"[12] = 1 + σ₄ˏ₃
                      var"##MTIIPVar#353"[13] = 1 + σ₁ˏ₄
                      var"##MTIIPVar#353"[14] = 1 + σ₂ˏ₄
                      var"##MTIIPVar#353"[15] = 1 + σ₃ˏ₄
                      var"##MTIIPVar#353"[16] = 1 + σ₄ˏ₄
                  end
              end
          nothing
      end)

只是为了演示,你可以告诉它计算稀疏雅可比行列式和多线程代码:

generate_jacobian(sys,sparse=true,multithread=true)[2]

## Generated:
:((var"##MTIIPVar#363", var"##MTKArg#359", var"##MTKArg#360", var"##MTKArg#361")->begin
          @inbounds begin
                  let (σ₁ˏ₁, σ₂ˏ₁, σ₃ˏ₁, σ₄ˏ₁, σ₁ˏ₂, σ₂ˏ₂, σ₃ˏ₂, σ₄ˏ₂, σ₁ˏ₃, σ₂ˏ₃, σ₃ˏ₃, σ₄ˏ₃, σ₁ˏ₄, σ₂ˏ₄, σ₃ˏ₄, σ₄ˏ₄, t) = (var"##MTKArg#359"[1], var"##MTKArg#359"[2], var"##MTKArg#359"[3], var"##MTKArg#359"[4], var"##MTKArg#359"[5], var"##MTKArg#359"[6], var"##MTKArg#359"[7], var"##MTKArg#359"[8], var"##MTKArg#359"[9], var"##MTKArg#359"[10], var"##MTKArg#359"[11], var"##MTKArg#359"[12], var"##MTKArg#359"[13], var"##MTKArg#359"[14], var"##MTKArg#359"[15], var"##MTKArg#359"[16], var"##MTKArg#361")
                      begin
                          Threads.@spawn begin
                                  (var"##MTIIPVar#363").nzval[1] = 1
                                  (var"##MTIIPVar#363").nzval[2] = 1
                                  (var"##MTIIPVar#363").nzval[3] = 1
                                  (var"##MTIIPVar#363").nzval[4] = 1
                              end
                      end
                      begin
                          Threads.@spawn begin
                                  (var"##MTIIPVar#363").nzval[5] = 1
                                  (var"##MTIIPVar#363").nzval[6] = 1
                                  (var"##MTIIPVar#363").nzval[7] = 1
                                  (var"##MTIIPVar#363").nzval[8] = 1
                              end
                      end
                      begin
                          Threads.@spawn begin
                                  (var"##MTIIPVar#363").nzval[9] = 1
                                  (var"##MTIIPVar#363").nzval[10] = 1
                                  (var"##MTIIPVar#363").nzval[11] = 1
                                  (var"##MTIIPVar#363").nzval[12] = 1
                              end
                      end
                      begin
                          Threads.@spawn begin
                                  (var"##MTIIPVar#363").nzval[13] = 1
                                  (var"##MTIIPVar#363").nzval[14] = 1
                                  (var"##MTIIPVar#363").nzval[15] = 1
                                  (var"##MTIIPVar#363").nzval[16] = 1
                              end
                      end
                  end
              end
          nothing
      end)

为了结束讨论,您可以ODEProblem在这样的系统上使用构造函数来生成和求解 ODE。请注意,与普通的DifferentialEquations.jl 语法不同,您在这里为初始条件提供一个数组,在这里,为了从符号转换为数字,您提供了一个对数组来告诉它如何将符号与初始条件匹配。

using OrdinaryDiffEq
u0 = [σ[i,j]=>rand() for i in 1:4, j in 1:4]
p = nothing
tspan = (0.0,1.0)
prob = ODEProblem(sys,u0,tspan,p)
solve(prob,Tsit5())

虽然这条路线对于 SymPy 的功能还不完整,但如果需要缺少功能,它可以通过 SymPy 往返。

编辑:在我回答添加“(具体来说,我希望在 ODE 随时间变化时更新 ODE)”后,对问题进行了编辑。我强烈建议生成一个与时间相关的函数或使用回调来更改参数,而不是尝试在每个时间点生成一个新函数。请注意,您将无法避免在每个新函数上点击编译器的开销,因此,如果您尝试将此作为某种优化,它实际上不会使事情变得更快,而且很可能会减慢速度. 也就是说,如果这是您需要的,您可以使用ModelingToolkit.build_function界面的一部分来完成此操作,如果您描述更多关于您正在尝试做的事情,我可以举一个例子。

鉴于此处进行编辑的时间安排,我的猜测是这可能是一个 XY 问题,如果我们可以聊聊您要完成的工作,它会更容易提供帮助