Get values of symbols inside a function

Actually, I just tested it. It doesn’t inline the constant correctly:

type Test
  x::Vector{Float64}
end

import Base: getindex
getindex(s::Test,sym::Symbol) = getindex(s,Val{sym})
getindex(s::Test,::Type{Val{:test_sym}}) = reshape(view(s.x,1:9),3,3)

function test_dispatch(s)
  s[:test_sym]
end

function test_dispatch2(s)
  s[Val{:test_sym}]
end

s = Test(collect(1:10))
@code_llvm test_dispatch(s)
@code_llvm test_dispatch2(s)

The first gives


; Function Attrs: uwtable
define %jl_value_t* @julia_test_dispatch_63361(%jl_value_t*) #0 {
top:
  %1 = call %jl_value_t*** @jl_get_ptls_states() #4
  %2 = alloca [5 x %jl_value_t*], align 8
  %.sub = getelementptr inbounds [5 x %jl_value_t*], [5 x %jl_value_t*]* %2, i64 0, i64 0
  %3 = getelementptr [5 x %jl_value_t*], [5 x %jl_value_t*]* %2, i64 0, i64 2
  %4 = bitcast %jl_value_t** %3 to i8*
  call void @llvm.memset.p0i8.i32(i8* %4, i8 0, i32 24, i32 8, i1 false)
  %5 = bitcast [5 x %jl_value_t*]* %2 to i64*
  store i64 6, i64* %5, align 8
  %6 = getelementptr [5 x %jl_value_t*], [5 x %jl_value_t*]* %2, i64 0, i64 1
  %7 = bitcast %jl_value_t*** %1 to i64*
  %8 = load i64, i64* %7, align 8
  %9 = bitcast %jl_value_t** %6 to i64*
  store i64 %8, i64* %9, align 8
  store %jl_value_t** %.sub, %jl_value_t*** %1, align 8
  %10 = getelementptr [5 x %jl_value_t*], [5 x %jl_value_t*]* %2, i64 0, i64 4
  %11 = getelementptr [5 x %jl_value_t*], [5 x %jl_value_t*]* %2, i64 0, i64 3
  store %jl_value_t* inttoptr (i64 2148817472 to %jl_value_t*), %jl_value_t** %3, align 8
  store %jl_value_t* %0, %jl_value_t** %11, align 8
  store %jl_value_t* inttoptr (i64 2240320304 to %jl_value_t*), %jl_value_t** %10, align 8
  %12 = call %jl_value_t* @jl_apply_generic(%jl_value_t** %3, i32 3)
  %13 = load i64, i64* %9, align 8
  store i64 %13, i64* %7, align 8
  ret %jl_value_t* %12
}

while the second gives

; Function Attrs: uwtable
define %jl_value_t* @julia_test_dispatch2_63423(%jl_value_t*) #0 {
top:
  %1 = call %jl_value_t* @julia_getindex_63362(%jl_value_t* %0, %jl_value_t* inttoptr (i64 2240320304 to %jl_value_t*)) #1
  ret %jl_value_t* %1
}

So I guess you do need to use s[Val{:test_sym}] instead, even if it’s a constant (I wonder what was different last time I tried?). That sucks… a macro would be a quick fix for this, but is there a way to force the compiler to infer this correctly? All of the information is there to do it, I guess it just misses this case.

With the following simplified (perhaps too simplified?) code I’m not seeing that behaviour:

julia> immutable T
           x::Int
       end

julia> Base.getindex(t::T, s::Symbol) = getindex(t, Val{s})

julia> Base.getindex(t::T, ::Type{Val{:x}}) = t.x

julia> function test(t)
           1 + t[:x]
       end
test (generic function with 1 method)

julia> t = T(1)
T(1)

julia> test(t)
2

Output from @code_llvm test(t) is definitely a dynamic call for me. Would be interested if you have got an example (even a long one is fine) where it does behave as you’ve described?

Yeah, macro is usually what I’d do with something like this.
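For anyone who wants to see what such a macro could look like, here is a minimal sketch. It is written in current Julia (1.x) syntax rather than the 0.5 syntax used above, and the names `TestData`, `@val`, and `test_dispatch3` are made up for illustration. The macro only rewrites `s[:name]` into `s[Val{:name}]` when the code is expanded, so the symbol is already part of the type by the time dispatch happens.

```julia
struct TestData                      # stands in for the `Test` type above
    x::Vector{Float64}
end

Base.getindex(s::TestData, ::Type{Val{:test_sym}}) = reshape(view(s.x, 1:9), 3, 3)

# Hypothetical helper: turns `s[:name]` into `s[Val{:name}]` at macro-expansion time.
macro val(ex)
    if ex isa Expr && ex.head === :ref && length(ex.args) == 2 &&
       ex.args[2] isa QuoteNode
        obj  = esc(ex.args[1])       # the indexed object, from the caller's scope
        name = ex.args[2]            # QuoteNode(:name)
        return Expr(:ref, obj, Expr(:curly, :Val, name))
    end
    error("@val expects an expression of the form s[:name]")
end

test_dispatch3(s) = @val s[:test_sym]

s = TestData(collect(1.0:10.0))
m = test_dispatch3(s)                # same result as s[Val{:test_sym}]
```

The call site keeps the short `s[:test_sym]` look, at the cost of having to remember the macro.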

Tried with 0.3 - 0.6dev and they all do pretty much the same thing with that code.

No idea, one of our compiler gurus would need to answer that one :smile:

Just using straight type dispatch or Enum types would work here too, and have slightly nicer syntax:

s[Continuous_Field1]

That could quickly fill the namespace, but for this use case it would probably be a nice way to handle it (of course, don’t export all of these types).
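A rough sketch of the type-dispatch variant, in current Julia (1.x) syntax. The tag types `ContinuousField1` and `ContinuousField2`, the `SatData` container, and the index ranges are all assumptions made up for this example.

```julia
# Hypothetical empty tag types, one per field (keep them unexported).
struct ContinuousField1 end
struct ContinuousField2 end

struct SatData
    x::Vector{Float64}
end

# Dispatch on the tag type itself, so no Symbol lookup is needed at run time.
Base.getindex(s::SatData, ::Type{ContinuousField1}) = reshape(view(s.x, 1:9), 3, 3)
Base.getindex(s::SatData, ::Type{ContinuousField2}) = view(s.x, 10:24)

s  = SatData(collect(1.0:24.0))
f1 = s[ContinuousField1]             # 3x3 view into the first 9 entries
f2 = s[ContinuousField2]             # view into the next 15 entries
```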

But the best way is to probably build the full type

type SatelliteData <: AbstractVector
  continuous_field1
  continuous_field2
  ...
  discrete_field1
  discrete_field2
  ...
end

and just put a linear index on it. That’ll work in most Julia packages as an “array”, though it’s slightly more complicated, since what you want to do is:

function getindex(s::SatelliteData,i::Int)
  if i<10 # first 9 is a matrix
    return s.continuous_field1[i] # use the linear index on the matrix
  elseif i<25 # next is a vector of length 15
    return s.continuous_field2[i-9] # shift the index so it starts at 1 in this field
  ... # Keep going to get every field
  end
end
getindex(s::SatelliteData,i::Int...) = getindex(s,i[1]) # you really should check size and throw an error

Then s.continuous_field1 and all of that will work correctly, but s[i] will loop through the fields which are thought of as the continuous variables and so it will work correctly in the ODE solvers (this is essentially what MultiScaleModels.jl does).
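To make the linear-index idea concrete, here is a small self-contained sketch in current Julia (1.x) syntax. It assumes just two continuous fields (a 3x3 matrix and a length-15 vector) and one discrete field. The field names follow the outline above, but the concrete types and sizes are assumptions.

```julia
mutable struct SatelliteData <: AbstractVector{Float64}
    continuous_field1::Matrix{Float64}  # 3x3, linear indices 1:9
    continuous_field2::Vector{Float64}  # length 15, linear indices 10:24
    discrete_field1::Int                # deliberately not part of the index
end

# Only the continuous fields count towards the array length.
Base.size(s::SatelliteData) =
    (length(s.continuous_field1) + length(s.continuous_field2),)
Base.IndexStyle(::Type{SatelliteData}) = IndexLinear()

function Base.getindex(s::SatelliteData, i::Int)
    n1 = length(s.continuous_field1)
    return i <= n1 ? s.continuous_field1[i] : s.continuous_field2[i - n1]
end

function Base.setindex!(s::SatelliteData, v, i::Int)
    n1 = length(s.continuous_field1)
    if i <= n1
        s.continuous_field1[i] = v
    else
        s.continuous_field2[i - n1] = v
    end
    return s
end

s = SatelliteData(reshape(collect(1.0:9.0), 3, 3), collect(10.0:24.0), 0)
s[10] = -1.0                            # writes into continuous_field2[1]
```

Generic code that only uses `length`, `eachindex`, and `s[i]` then sees a flat vector of 24 numbers, while `s.discrete_field1` stays available by name.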

And there’s probably a million other ways to do this too. But any way which ends up with a type with a linear index will work, however you make it.

I moved this discussion out to a thread in the development category.

Hi guys!

Thanks very much for this help.

Chris,

I decided to go with the approach of S[Val{:lat}] because I came up with a very simple way to define all those getindex and setindex! methods. Hence, it will be very easy to populate everything with the 200+ variables. The code I used is:

importall Base.Operators

################################################################################
#                           Symbols at the Workspace
################################################################################

# Define the symbols in the workspace.
vals = ["JD"       (1,);
        "year"     (1,);
        "s_i"      (1,);
        "a"        (1,);
        "e"        (1,);
        "i"        (1,);
        "w"        (1,);
        "RAAN"     (1,);
        "f"        (1,);
        "lat"      (1,);
        "lon"      (1,);
        "h"        (1,);
        "r_i"      (3,);
        "rt_i"     (3,);
        "eclipse"  (1,);
        "B_NED"    (3,);
        "Dbi"      (3,3)];


################################################################################
#                             Simulator Workspace
################################################################################

# Define the workspace.
type SimWorkspace{T}
    # State vector with the continuous variables.
    x::Array{T,1}

    function SimWorkspace(vals)
        total = 0

        for i=1:size(vals)[1]
            v = vals[i,:]

            if length(v[2]) == 1
                total += v[2][1]
            else
                total += v[2][1]*v[2][2]
            end
        end

        return new(Array(T, total))
    end
end

# Interfaces with SimWorkspace.
getindex( A::SimWorkspace,    i::Int...) = (A.x[i...])
setindex!(A::SimWorkspace, x, i::Int...) = (A.x[i...] = x)

# Define easy access to each variable in the state vector.
getindex( A::SimWorkspace,    sym::Symbol) = getindex( A,    Val{sym})
setindex!(A::SimWorkspace, x, sym::Symbol) = setindex!(A, x, Val{sym})

visize = 1
a = 0
b = 0
s = 0
sym = []
v = []

for i = 1:size(vals)[1]
    v = vals[i,:]
    sym = v[1]

    # Verify if the variable is scalar.
    if (length(v[2]) == 1 && v[2][1] == 1)
        eval(quote
             getindex(A::SimWorkspace, ::Type{Val{Symbol($sym)}}) =
                A.x[$visize]
             end)

        eval(quote
             setindex!(A::SimWorkspace, x, ::Type{Val{Symbol($sym)}}) =
                (A.x[$visize] = x)
             end)

        visize += 1

    # Verify if the variable is a vector.
    elseif (length(v[2]) == 1 && v[2][1] > 1)
        a = v[2][1]
        s = visize + a - 1
        eval(quote
             getindex(A::SimWorkspace, ::Type{Val{Symbol($sym)}}) =
                A.x[$visize:$s]
             end)

        eval(quote
             setindex!(A::SimWorkspace, x, ::Type{Val{Symbol($sym)}}) =
                (A.x[$visize:$s] = x)
             end)

        visize += v[2][1]

    # Verify if the variable is a matrix.
    elseif length(v[2]) == 2
        a = v[2][1]
        b = v[2][2]
        s = visize + a*b - 1
        eval(quote
             getindex(A::SimWorkspace, ::Type{Val{Symbol($sym)}}) =
                reshape(view(A.x,$visize:$s),$a, $b)
             end)

        eval(quote
             setindex!(A::SimWorkspace, x, ::Type{Val{Symbol($sym)}}) =
                (A.x[$visize:$s] = reshape(x, $a, $b))
             end)
        visize += v[2][1]*v[2][2]
    end
end

Everything seems to be working. I did some benchmarks using [:Dbi] and [Val{:Dbi}] and, indeed, the latter is MUCH faster.
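For readers on a recent Julia version, the same generation loop can be sketched more compactly with `@eval`. This is only an illustration under a few assumptions: it uses Julia 1.x syntax, a three-entry subset of the table (the `LAYOUT` constant is made up for the example), it generates getters only, and it returns vectors as views instead of copies.

```julia
# Hypothetical subset of the variable table: (name, dimensions).
const LAYOUT = [(:lat, (1,)), (:r_i, (3,)), (:Dbi, (3, 3))]

struct SimWorkspace{T}
    x::Vector{T}
end

# Symbol access forwards to the Val-based methods generated below.
Base.getindex(A::SimWorkspace, sym::Symbol) = A[Val{sym}]

let offset = 1
    for (name, dims) in LAYOUT
        n   = prod(dims)
        rng = offset:(offset + n - 1)
        tag = QuoteNode(name)          # so that Val{$tag} becomes e.g. Val{:lat}
        if dims == (1,)
            # Scalar: return the single element.
            @eval Base.getindex(A::SimWorkspace, ::Type{Val{$tag}}) = A.x[$offset]
        else
            # Vector or matrix: return a reshaped view into the state vector.
            @eval Base.getindex(A::SimWorkspace, ::Type{Val{$tag}}) =
                reshape(view(A.x, $rng), $(dims...))
        end
        offset += n
    end
end

A = SimWorkspace(collect(1.0:13.0))
```

With this layout, `A[Val{:Dbi}]` is a 3x3 view over entries 5 to 13 of `A.x`.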

Thanks!!

Hi Chris,

After implementing tons of operations using SimWorkspace, I could not solve the following error:

ERROR: LoadError: TypeError: ODEIntegrator: in uType, expected uType<:Union{AbstractArray{T,N},Number}, got Type{SimWorkspace}
 in #solve#120(::Float64, ::Bool, ::Int64, ::DiffEqBase.ExplicitRKTableau, ::Bool, ::Void, ::Symbol, ::Bool, ::Bool, ::Array{Float64,1}, ::Array{Float64,1}, ::Bool, ::Float64, ::Rational{Int64}, ::Rational{Int64}, ::Void, ::Void, ::Rational{Int64}, ::Bool, ::Void, ::Void, ::Int64, ::Float64, ::Float64, ::Bool, ::OrdinaryDiffEq.#ODE_DEFAULT_NORM, ::OrdinaryDiffEq.#ODE_DEFAULT_ISOUTOFDOMAIN, ::Bool, ::Int64, ::String, ::OrdinaryDiffEq.#ODE_DEFAULT_PROG_MESSAGE, ::Void, ::Array{Any,1}, ::DiffEqBase.#solve, ::DiffEqBase.ODEProblem{SimWorkspace,Float64,true,#sur_mode_dyn}, ::OrdinaryDiffEq.Tsit5, ::Array{Any,1}, ::Array{Any,1}, ::Array{Any,1}) at /home/ronan.arraes/.julia/v0.5/OrdinaryDiffEq/src/solve.jl:186
 in (::DiffEqBase.#kw##solve)(::Array{Any,1}, ::DiffEqBase.#solve, ::DiffEqBase.ODEProblem{SimWorkspace,Float64,true,#sur_mode_dyn}, ::OrdinaryDiffEq.Tsit5, ::Array{Any,1}, ::Array{Any,1}, ::Array{Any,1}) at ./<missing>:0 (repeats 2 times)
 in #solve#4(::Bool, ::Array{Any,1}, ::Function, ::DiffEqBase.ODEProblem{SimWorkspace,Float64,true,#sur_mode_dyn}) at /home/ronan.arraes/.julia/v0.5/DifferentialEquations/src/DifferentialEquations.jl:34
 in solve(::DiffEqBase.ODEProblem{SimWorkspace,Float64,true,#sur_mode_dyn}) at /home/ronan.arraes/.julia/v0.5/DifferentialEquations/src/DifferentialEquations.jl:27
 in include_from_node1(::String) at ./loading.jl:488
while loading /home/ronan.arraes/Área de trabalho/Work/AM1_SUR.jl/sur_mode.jl, in expression starting on line 125

Can you help me?

I managed to fix it! I needed to make SimWorkspace a subtype of AbstractArray. Additionally, I had to create a recursivecopy! method for this new type.

Thank you very much for the help

Hi @ChrisRackauckas,

I need your help one more time. I have almost finished the first version using your approach, but I have a very big problem. All my discrete variables are updated in the callback function, and their values should influence the integration in the next steps. But, with the approach you proposed, the discrete values are not stored to be used in the following integration steps.

For example, the value S.Tc (which defines a torque) is computed inside the callback. However, in the next update step inside the dynamic function, S.Tc = [0;0;0]. I think I am missing some kind of operation with my type SimWorkspace. Can you help me?

Oh I think I had a typo in this. If you use the helper macro, you have to reference your current state as u. So:

sat_callback = @ode_callback begin
  if t in tstop
    # Apply discrete changes to `u` here
    # u.discrete_field1 = 2
    ...
  end
  @ode_savevalues # make sure you save after!
end

then pass this with the keyword callback = sat_callback. (One thing we’re working on is making this callback syntax simpler so there’s less “macro magic” involved).

Hi @ChrisRackauckas ! Thanks for the prompt answer! I was already using u. But, in this case, the discrete_field1 remains 0 (the initial value) in the solver.

Hmm, I know that mutating u inside of the callback will work because that’s how events are applied. Maybe there’s something more subtle going on. Would you mind sharing the callback code?

My guess for what’s happening is you may be changing the reference somewhere?

function f!(u)
  u = [1.0]
  nothing
end 

would mean that f! doesn’t actually mutate u, since it only rebinds the local name u to a new array. This kind of thing can lead to a subtle bug which has this kind of “signature”. If I had to guess what could be going on, that would be the first thing I’d check.
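A quick way to see the difference, as a small sketch (the function names `rebind!` and `mutate!` are made up for the example):

```julia
# Rebinding: `u` inside the function now names a brand-new array;
# the caller's array is untouched.
function rebind!(u)
    u = [1.0]
    nothing
end

# Mutation: writes into the array the caller passed in.
function mutate!(u)
    u .= 1.0
    nothing
end

u = [0.0]
rebind!(u)
after_rebind = copy(u)   # still [0.0]
mutate!(u)
after_mutate = copy(u)   # now [1.0]
```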

The code is pretty simple now:

sat_callback = @ode_callback begin
  if t in tstop
    u.Tc = [0.5;0.5;0.5]
  end
  @ode_savevalues # make sure you save after!
end

I think I know what the problem is, but I could not find how to solve it. I created a new type SimWorkspace and it seems that the solver does a lot of operations between this type and arrays. It works because SimWorkspace is a subtype of AbstractArray and I defined all the methods to index the components.

However, let’s say that s0.Tc = [0.5;0.5;0.5]. If I do an operation with an array: s1 = s0.*a for example, then the continuous variables in x will be updated correctly. But the discrete values in s1 will be the defaults and not the ones in s0.

I have tried to define many operations between arrays and SimWorkspace but I could not fix it yet.

Notice that if I modify a value in u.x it indeed persists for the next operation.

Yes, this is because it only broadcasts over the indexed variables. One way to handle this is to do the following:

s1 = deepcopy(s0)
s1 .= s0.*a

This will make all of the fields of s1 match s0, and then update the indexed ones (the continuous ones).
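Here is a self-contained sketch of that pattern in current Julia (1.x) syntax. The `Workspace` type is a made-up stand-in for SimWorkspace, with one indexed field `x` and one discrete field `Tc`.

```julia
mutable struct Workspace <: AbstractVector{Float64}
    x::Vector{Float64}    # continuous variables (what indexing sees)
    Tc::Vector{Float64}   # discrete field (invisible to indexing)
end

Base.size(w::Workspace) = size(w.x)
Base.IndexStyle(::Type{Workspace}) = IndexLinear()
Base.getindex(w::Workspace, i::Int) = w.x[i]
Base.setindex!(w::Workspace, v, i::Int) = (w.x[i] = v)

s0 = Workspace([1.0, 2.0], [0.5, 0.5, 0.5])
a  = 2.0

s1 = deepcopy(s0)   # copies every field, including the discrete ones
s1 .= s0 .* a       # in-place broadcast overwrites only the indexed values
```

After this, `s1.x` holds the scaled continuous values while `s1.Tc` still matches `s0.Tc`.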

In general though you shouldn’t be needing to allocate new SimWorkspace types, just mutating old ones. Would you mind sharing where this occurs? There’s probably a more efficient way to handle it that shouldn’t have much more difficult syntax.

I am not doing this. My guess is that somewhere inside the solver code the references to the non-indexed fields are lost. I will try to do a minimal working example.

For example, the solver calls utilde = similar(u). Can you tell me how I can define the similar function for this new type?

Ah! I also need to define the function recursivecopy!.
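On the similar question, one possible shape for it is sketched below, in current Julia (1.x) syntax with a made-up `Workspace` type standing in for SimWorkspace. Copying the discrete field into the new object is a design choice assumed here, not something the solver requires.

```julia
mutable struct Workspace <: AbstractVector{Float64}
    x::Vector{Float64}    # continuous variables
    Tc::Vector{Float64}   # discrete field
end

Base.size(w::Workspace) = size(w.x)
Base.IndexStyle(::Type{Workspace}) = IndexLinear()
Base.getindex(w::Workspace, i::Int) = w.x[i]
Base.setindex!(w::Workspace, v, i::Int) = (w.x[i] = v)

# `similar` returns the same wrapper type: uninitialized continuous storage,
# plus an independent copy of the discrete field (an assumption of this sketch).
Base.similar(w::Workspace) = Workspace(similar(w.x), copy(w.Tc))

u      = Workspace([1.0, 2.0], [0.5, 0.5, 0.5])
utilde = similar(u)
```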

Hi @ChrisRackauckas

Well, I have an update. The non-indexed fields are carrying what was defined in the initial value in s0. The modifications inside the callback are not working.

Oh I know the problem. The problem isn’t the similar: that just creates the caches. Indeed, you found the problem is the update loops. They tend to look like this:

      for i in eachindex(u)
        tmp[i] = u[i]+dt*(a41*k1[i]+a42*k2[i]+a43*k3[i])
      end

which gives exactly the problem you described (updates the indices, but not the discrete variables). The solution is to update the discrete variables in the cached variables like tmp when you do the discrete update. A tuple with all of the caches is provided in the callback and is named cache. So you just need to do the following:

sat_callback = @ode_callback begin
  if t in tstop
    for c in cache # updates the discrete variables on `u` and all caches
       c.Tc = [0.5;0.5;0.5]
    end
  end
  @ode_savevalues # make sure you save after!
end

(Edit: I’m dumb, u is already in the cache tuple, so you just need to loop over that)
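To see the effect outside the solver, here is a toy reproduction in current Julia (1.x) syntax. Everything in it is an assumption made for illustration: the `Workspace` type, the single-stage update, and the hand-built `cache` tuple that stands in for the one the callback provides.

```julia
mutable struct Workspace <: AbstractVector{Float64}
    x::Vector{Float64}    # continuous variables
    Tc::Vector{Float64}   # discrete field
end

Base.size(w::Workspace) = size(w.x)
Base.IndexStyle(::Type{Workspace}) = IndexLinear()
Base.getindex(w::Workspace, i::Int) = w.x[i]
Base.setindex!(w::Workspace, v, i::Int) = (w.x[i] = v)

u   = Workspace([1.0, 2.0], zeros(3))
tmp = Workspace(zeros(2), zeros(3))
k1  = [0.1, 0.2]
dt  = 0.5

u.Tc = [0.5, 0.5, 0.5]          # discrete update applied to `u` only

# Index-wise update loop: touches tmp.x, never tmp.Tc.
for i in eachindex(u)
    tmp[i] = u[i] + dt * k1[i]
end
stale_Tc = copy(tmp.Tc)         # still all zeros at this point

# Apply the discrete update to every cached workspace.
cache = (u, tmp)                # hypothetical stand-in for the solver's tuple
for c in cache
    c.Tc = [0.5, 0.5, 0.5]
end
```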

Excellent! This indeed solved my problem! Is it a bug in the package? Should I open a bug issue somewhere?

It’s not a bug as much as it is something that should have a better API and documentation. Would you mind opening an issue in the DifferentialEquations.jl repository about hybrid differential equations?

I think I’d like to create a HybridDiffEq.jl package to make this all easier. As demonstrated here, you can do it all by hand, but it wouldn’t be hard to make some macros so that way you just

  • Declare the continuous variables
  • Declare the discrete variables
  • Declare the ODE
  • Declare the discrete updates

and it generates all of this for you. Syntactic sugar plus some documentation could make this more user-friendly.