Vectorization and passing a matrix as an argument


#1

I have an energy type function that I’ve written:

function energy(r,A)
    return 0.5 * dot(r, A*r)
end

where r is a vector of length n and A is an n×n matrix, which may or may not be sparse. Evaluating this on a single r and A works fine. However, for a fixed A, if I try to evaluate it on many r's via broadcasting (the dot syntax), I get the error:

ERROR: DimensionMismatch("arrays could not be broadcast to a common size")

Here is a complete sample code that produces this error:

function energy(r,A)
    return 0.5 * dot(r, A*r)
end

A = [1. 2.; 2. 1.]
r = [1.,0.5]
energy(r,A) # this works fine

x = linspace(-1,1, 10)|>collect
y = linspace(-1,1, 10)|>collect

rvals = [[x[j],y[i]] for i=1:length(y), j = 1:length(x)];

energy.(rvals,A) # this generates the error

#2

Julia interprets energy.(rvals, A) as a request to broadcast over the elements of rvals and the elements of A simultaneously. Since rvals is 10×10 and A is 2×2, those shapes cannot be matched up, hence the DimensionMismatch. What you actually want is for the whole of A to be used in every call. In other words, rvals should be treated like an array in the broadcast, while A should be treated like a scalar so that it is passed intact to each call to energy. You can make any item behave like a scalar in broadcasting by wrapping it in a one-element tuple:

julia> energy.(rvals, (A,))
10×10 Array{Float64,2}:
  3.0        2.35802    1.76543    1.22222     0.728395    0.283951   -0.111111   -0.45679   -0.753086  -1.0     
  2.35802    1.81481    1.32099    0.876543    0.481481    0.135802   -0.160494   -0.407407  -0.604938  -0.753086
  1.76543    1.32099    0.925926   0.580247    0.283951    0.037037   -0.160494   -0.308642  -0.407407  -0.45679 
  1.22222    0.876543   0.580247   0.333333    0.135802   -0.0123457  -0.111111   -0.160494  -0.160494  -0.111111
  0.728395   0.481481   0.283951   0.135802    0.037037   -0.0123457  -0.0123457   0.037037   0.135802   0.283951
  0.283951   0.135802   0.037037  -0.0123457  -0.0123457   0.037037    0.135802    0.283951   0.481481   0.728395
 -0.111111  -0.160494  -0.160494  -0.111111   -0.0123457   0.135802    0.333333    0.580247   0.876543   1.22222 
 -0.45679   -0.407407  -0.308642  -0.160494    0.037037    0.283951    0.580247    0.925926   1.32099    1.76543 
 -0.753086  -0.604938  -0.407407  -0.160494    0.135802    0.481481    0.876543    1.32099    1.81481    2.35802 
 -1.0       -0.753086  -0.45679   -0.111111    0.283951    0.728395    1.22222     1.76543    2.35802    3.0     
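
For anyone running this on Julia 1.x, here is a minimal sketch of the same example. It assumes two changes from the code in the question: dot now lives in the LinearAlgebra standard library, and linspace has been replaced by range with stop and length keywords. It also shows Ref(A) and a plain comprehension as alternatives to the tuple; the names E_tuple, E_ref and E_comp are only for illustration.

```julia
using LinearAlgebra  # dot is provided by LinearAlgebra on Julia 0.7 and later

energy(r, A) = 0.5 * dot(r, A * r)

A = [1.0 2.0; 2.0 1.0]
x = collect(range(-1, stop=1, length=10))  # stands in for linspace(-1, 1, 10)
y = collect(range(-1, stop=1, length=10))

# 10×10 matrix whose elements are 2-element vectors
rvals = [[x[j], y[i]] for i in 1:length(y), j in 1:length(x)]

# size(rvals) is (10, 10) and size(A) is (2, 2); broadcasting over both
# at once is what raises the DimensionMismatch.

E_tuple = energy.(rvals, (A,))            # one-element tuple, as above
E_ref   = energy.(rvals, Ref(A))          # Ref also makes A act as a scalar
E_comp  = [energy(r, A) for r in rvals]   # comprehension, no broadcasting

@assert E_tuple == E_ref == E_comp
```

All three forms pass the same A to every call, so which one to use is mostly a matter of taste. The comprehension avoids the question of what counts as a scalar altogether.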
