Avoiding array mutation with vector combinations for use in Flux

Hi all,

I’m working on a parameter estimation problem for an ODE system with two variables, u1 and u2. As one part of the objective used in the loss function, I want to calculate a vector of elementwise products over all combinations with repetition of u1 and u2.

For example, I want v where v = [[u1 .* u1 .* u1]; [u1 .* u1 .* u2]; [u1 .* u2 .* u2]; ... ; [u2 .* u2 .* u2]]. I then take the mean of each element of v and append it to the vector stats = mean(Array(sol), dims=2), which holds the mean of each output variable over time. The end goal is a vector of means of the many different combinations of output variables. Ultimately, I will be using an ODE system with many more than two output variables and might need larger combination groups too, so I would prefer not to hardcode these combinations with repetition.

My issue is writing code that avoids array mutation, which Zygote cannot differentiate through inside a Flux.train!() call. I am using the Combinatorics.jl library. This is my current attempt:

using Statistics, Combinatorics

# sol is the output of solve(ODEProblem)
# sliceMatrix and numComboElements are defined elsewhere
function compute_objective(sol)
  # The first elements of the vector are the means of each output variable over time
  stats = mean(Array(sol), dims=2)
  # Slice matrix so each row (output) becomes a vector in a vector of vectors
  M = sliceMatrix(sol)

  # Do all product combinations with repetition
  for c in with_replacement_combinations(M, numComboElements)
    temp = ones(size(c[1]))
    for i in c
      temp = temp .* i
    end
    stats = vcat(stats, mean(temp))
  end

  return stats
end
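
For comparison, the same loop can be written without reassigning or growing any array, by folding each combination with `reduce` inside a comprehension. This is only a minimal sketch: `combo_means` and `rows` are hypothetical names, `rows` stands in for whatever `sliceMatrix(sol)` returns (assumed here to be a vector of equal-length vectors), and whether Zygote can differentiate through the Combinatorics.jl iterator in this form is not confirmed anywhere in this thread.

```julia
using Statistics
using Combinatorics

# `rows` is assumed to be a vector of equal-length vectors, one per output variable.
# `combo_means` is a hypothetical helper name, not part of any library.
function combo_means(rows::AbstractVector{<:AbstractVector}, k::Integer)
    # Each `c` is a length-k vector of rows; `reduce` multiplies them elementwise,
    # and `mean` collapses the product over time to a single number.
    return [mean(reduce((a, b) -> a .* b, c))
            for c in with_replacement_combinations(rows, k)]
end

u1 = [1.0, 2.0]
u2 = [3.0, 4.0]
pair_means = combo_means([u1, u2], 2)   # means of u1 .* u1, u1 .* u2, u2 .* u2
```

The result could then be joined to the per-variable means with something like `vcat(vec(mean(Array(sol), dims=2)), combo_means(M, numComboElements))`, so that `stats` is built in a single expression.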

I think Tullio might solve this?

Chris,

Tullio looks promising! Unfortunately, when I tried implementing it I ran into another issue with how I’m using Tullio: there is no gradient definition for a Vector of Vectors. Is there a good way to get around this?

Another option I’m considering is wrapping the Combinatorics.jl calls in Zygote.ignore().
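
One way the Zygote.ignore() idea could look is to generate combinations of row indices (plain integers) inside the ignored block, and keep only the arithmetic on the data visible to Zygote. This is a sketch under assumptions: `combo_means_ignore` is a hypothetical name, `U` is assumed to be a plain matrix such as `Array(sol)` with one row per output variable, and newer Zygote versions recommend `ChainRulesCore.ignore_derivatives` in place of `Zygote.ignore`.

```julia
using Statistics
using Combinatorics
using Zygote

# Hypothetical helper: means of products over all index combinations with repetition.
function combo_means_ignore(U::AbstractMatrix, k::Integer)
    # The index combinations do not depend on the values in U,
    # so they are built outside of differentiation.
    idx = Zygote.ignore() do
        collect(with_replacement_combinations(1:size(U, 1), k))
    end
    # U[c, :] selects the k rows named by c; prod multiplies them column by column.
    return [mean(prod(U[c, :], dims=1)) for c in idx]
end

U = [1.0 2.0; 3.0 4.0]
grad = Zygote.gradient(u -> sum(combo_means_ignore(u, 2)), U)[1]
```

Because only integer indices pass through Combinatorics.jl, no gradient ever has to flow through the iterator itself.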

My current function is now as follows and the error is attached below the code snippet:

# sol is the output of solve(ODEProblem)
function compute_objective(sol)
  # Slice matrix so each row (output) is a vector in a matrix
  M = sliceMatrix(sol)

  # Do all product permutations
  # Using Tullio, this does 3 combinations with repetition, hardcoded
  @tullio stats[i,j,k] := sum(M[i] .* M[j] .* M[k])
  println(stats)

  # Using combinatorics.jl, can you do this with pure Tullio? Error seems to occur with Zygote trying to pull gradient back through combinatorics library
  # counter = 0
  # for c in with_replacement_combinations(M,numComboElements) 
  #   counter += 1
  #   println("c", c)
  #   @tullio temp[i] := sum(c[i] .* c[j] .* c[k])
  #   println("temp #", counter, temp)
  #   # temp = ones(size(c[1]))
  #   # for i in c 
  #   #     temp = temp .* i
  #   # end 
  #   # temp = mean(temp)
  #   # stats = vcat(stats, temp)
  # end 

  # Return stats as 1d column vector
  return vec(stats)
end

ERROR: LoadError: no gradient definition here!
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] (::Tullio.var"#tullio_back#156"{Tullio.Eval{var"#ℳ𝒶𝓀ℯ#265"{var"#𝒜𝒸𝓉!#264"}, Nothing}, Tuple{Vector{Vector{Float64}}}, Array{Float64, 3}})(Δ::Array{Float64, 3})
    @ Tullio ~/.julia/packages/Tullio/u7Tk0/src/eval.jl:52
  [3] ZBack

Reduce that to a matrix?

This worked like a charm! My final solution is

function compute_objective(sol)
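  # Using Tullio, this does 3 combinations with repetition, hardcoded.
  # The index c (time) appears only on the right, so Tullio sums over it;
  # divide by the number of time points to get the mean.
  # Is there a good way to keep each combination with repetition only once (e.g. only i <= j <= k)?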
  @tullio stats[i,j,k] := sol[i,c] * sol[j,c] * sol[k,c]

  # Return stats as 1d column vector
  return vec(stats)
end
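
On the open question in the comment above: the full tensor holds every ordering of (i, j, k), so each combination with repetition appears several times. One possible way to keep each combination once is to index the tensor with only the entries where i <= j <= k. This is a sketch, not a tested Flux workflow: `unique_combo_sums` is a hypothetical name, `X` is assumed to be a plain matrix such as `Array(sol)`, and it has not been checked against Zygote in this thread.

```julia
using Tullio

# Hypothetical helper: one entry per combination with repetition of three rows.
function unique_combo_sums(X::AbstractMatrix)
    n = size(X, 1)
    # c appears only on the right-hand side, so Tullio sums over it (time).
    @tullio stats[i, j, k] := X[i, c] * X[j, c] * X[k, c]
    # Keep only sorted index triples, so each combination is returned once.
    keep = [CartesianIndex(i, j, k) for i in 1:n for j in i:n for k in j:n]
    return stats[keep]
end

X = [1.0 2.0; 3.0 4.0]
sums = unique_combo_sums(X)
```

Dividing the result by `size(X, 2)` would turn the sums over time into means. The index list depends only on the number of rows, so it could also be built once outside the loss function.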