I wonder if there is a way in Julia to “peel away” an outer function that calls an inner function. For example:
julia> f(x)=sqrt(abs2(x))
The task would be to generate a function g(x)=abs2(x) given only access to f and the knowledge that the outer part is in fact sqrt(). I would like to “peel” this function because it could help me overload the broadcast mechanism for an array type: that would avoid first calculating the square root of something known to be positive (the radius to the center), only to square it again.
Can this be done in Julia? Any help would be appreciated.
I’m going to go ahead and say “no”. There might technically be some shenanigans you can do, but it’s going to produce illegible code. If the cost of squaring the number is too high, you had better get access to the inner function.
You could generate the symbolic expression of the function and manipulate that yourself:
julia> using ModelingToolkit
julia> @variables x
(x,)
julia> f(x) = sqrt(abs2(x))
f (generic function with 1 method)
julia> expression = f(x)
sqrt(abs2(x))
julia> dump(expression)
Operation
op: sqrt (function of type typeof(sqrt))
args: Array{Expression}((1,)) Expression[abs2(x)]
Interesting approach using symbolic variables. The problem is that this violates the initial idea that only access to an already existing f is provided. I was hoping that Julia internally also stores something along the lines of what your ‘dump’ call shows.
If you only have access to the function f and not the expression of f, then there is nothing you can do using high-level code. It may be possible to access the internals through low-level representations (e.g. the lowered code), but that would be a big waste of energy for a problem that could potentially be solved differently. Maybe you could share your motivation for trying to access the inner function?
Notice that if you had access to the expression of f you could manipulate it without any package:
julia> f = :(sqrt(abs2(x)))
:(sqrt(abs2(x)))
julia> dump(f)
Expr
head: Symbol call
args: Array{Any}((2,))
1: Symbol sqrt
2: Expr
head: Symbol call
args: Array{Any}((2,))
1: Symbol abs2
2: Symbol x
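As a rough sketch of what that manipulation could look like (the names `body`, `inner` and `g` are just made up for illustration), the outer call can be peeled off by taking its argument and evaluating the rest into a new function:

```julia
# The body of f as an expression. This only works because we hold the
# expression itself, not an already compiled function.
body = :(sqrt(abs2(x)))

# A :call expression stores the callee first, followed by its arguments.
@assert body.head === :call && body.args[1] === :sqrt

# "Peel" the outer sqrt by taking its single argument.
inner = body.args[2]          # :(abs2(x))

# Turn the peeled expression into an anonymous function of x.
g = eval(:(x -> $inner))

g(3.0)                        # 9.0
```

Note that this relies on `eval` at top level, so it is more of a REPL trick than something to do in an inner loop.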
OK, here is some more on the motivation: @roflmaostc and I wrote a package called IndexFunArrays.jl, which lets you very conveniently work with expressions on indices that look like arrays but behave lazily. One such convenience function is rr(sizeTuple). It really is a shorthand for sqrt.(rr2()). All such functions can be used in a number of convenient ways, which come down to different ways of defining the offset and scale parameters. However, these user parameters are currently NOT stored in the array structure but are baked into the generator function which is stored by the array.
Since many users may naively write things like rr((100,100)) .< 20.4 and are probably not aware that rr2((100,100)) .< abs2(20.4) would be a lot faster, I thought about overloading the corresponding broadcast function so that it removes the completely unnecessary sqrt(). Similar arguments apply to users writing abs2.(rr((100,100))), which should be replaced by rr2((100,100)). Anyway, one solution would be to store offset and scale, as this information would be enough to synthesize the necessary position conversion which is stored in the generator. But I would prefer not to bloat the array type with such unnecessary information.
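To make the two variants concrete without depending on the package, here is a small sketch on plain arrays. The array `r2` is only a stand-in for what rr2 would produce, and the choice of center (`sz .÷ 2 .+ 1`) is an assumption made for this example:

```julia
# Squared distance to the center of a 100×100 grid (plain-array stand-in for rr2).
sz = (100, 100)
center = sz .÷ 2 .+ 1          # assumed center convention, for illustration only
r2 = [abs2(i - center[1]) + abs2(j - center[2]) for i in 1:sz[1], j in 1:sz[2]]

mask_slow = sqrt.(r2) .< 20.4   # what rr(sz) .< 20.4 amounts to: one sqrt per element
mask_fast = r2 .< abs2(20.4)    # same mask, but the threshold is squared once instead

mask_slow == mask_fast          # true for this grid
```

The second form squares the scalar threshold once instead of taking a square root for each of the 10,000 elements.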
Using symbolic variables is the only way to do this at a high level. It doesn’t violate the original idea as long as the functions are written in Julia.
You pass symbolic variables in and you get a symbolic form of the function back, provided various technical requirements are met. This is often called “tracing” through the function.
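To illustrate the tracing idea without any package, here is a toy sketch. The `Traced` type is hypothetical and only overloads the two functions needed here; a real symbolic package covers far more cases:

```julia
# Hypothetical minimal tracer: instead of a number it carries an expression.
struct Traced
    ex::Union{Symbol,Expr}
end

# Each overloaded function records itself instead of computing anything.
Base.abs2(t::Traced) = Traced(:(abs2($(t.ex))))
Base.sqrt(t::Traced) = Traced(:(sqrt($(t.ex))))

# An existing function, written in Julia, that we only call from the outside.
f(x) = sqrt(abs2(x))

traced = f(Traced(:x))
traced.ex                 # :(sqrt(abs2(x)))

# The outer call can now be stripped from the recorded expression.
inner = traced.ex.args[2] # :(abs2(x))
```

This only works if f accepts arbitrary argument types, which is one of the “technical requirements” mentioned above; a method restricted to `f(x::Float64)` would reject the tracer.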
rr() generates the N-dimensional distance to the center (as specified by an optional offset). rr2() generates the square of the distance to the center, i.e. it does not require a square root. Sorry, your suggestion misses the point of trying to fuse functions depending on how they are used. However, one can hope that at least the case sqrt(abs2()) may be picked up at a low level and converted to abs. Yet I doubt that the compiler can pick up the other case with the comparison.
I am not sure if this can really help here. The array must look like an array in all aspects, yet internally it stores a function. In the inner loop the getindex resolves the calculation. I suspect “resolving” the symbolic variable there may be a bad idea. Of course, the symbolic expression could be stored in addition, but then I can also store offset and scale…
I don’t understand this. So you are using rr and rr2 from some other package? And why is this sqrt(abs2(x)) cancellation something a user would expect your array type to do, rather than an optimization in the package that defines rr (or something they would do manually)?
If you know x to be positive, it might be worthwhile to make that information available to dispatch (e.g. via a type parameter) and write an abs2 and a sqrt method that just pass those types through. This could also be done via a small wrapper.
E.g.
struct KnownPositive{T}
    x::T
end
abs2(m::MyMatrix) = ispositive(m) ? abs2(KnownPositive(m)) : real_abs2(m)
abs2(k::KnownPositive) = k
sqrt(k::KnownPositive) = k.x
You need ownership/control over the type of your matrix though, which I assume you have since you mention customizing broadcasting. You’ll also have to customize other parts of your code, depending on where the result of abs2 ends up.
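A runnable variation on the same wrapper idea, purely as a sketch: here the wrapper stores the squared value, so abs2 cancels the pending square root and a comparison can be done on squares. `LazySqrt` and `materialize` are hypothetical names, and the sketch assumes the stored value is a non-negative real:

```julia
# Hypothetical wrapper: stands for sqrt(sq) without computing it yet.
# Assumes sq is a non-negative real number.
struct LazySqrt{T<:Real}
    sq::T
end

# Squaring cancels the pending square root.
Base.abs2(r::LazySqrt) = r.sq

# sqrt(sq) < t holds exactly when t is positive and sq < t^2.
Base.:<(r::LazySqrt, t::Real) = t > 0 && r.sq < abs2(t)

# Only compute the square root when a plain number is really needed.
materialize(r::LazySqrt) = sqrt(r.sq)

r = LazySqrt(25.0)
abs2(r)          # 25.0
r < 5.1          # true
materialize(r)   # 5.0
```

Every other operation the result may flow into would need its own method or a fallback through `materialize`, which is the extra customization mentioned above.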
Yes. This boils down to storing the unpeeled function explicitly. It would work but is not the solution of choice for me. Storing the offsets and scales is then a more generic option, I guess.
Not really. abs2 has to take the square, and the sqrt has to cancel against the abs2. But using a positive type is the right choice if one wanted to be generic here. However, maybe Julia could also do something about such a mechanism at the lowering level to speed up performance? Maybe it already does so?
There’s nothing special about sqrt and abs2 - they are just regular functions. Eliminating those specific constructs would mean special casing them, which (apart from the fact that it depends on the types in question whether it’s allowed in the first place) would require making lowering more complicated for questionable general gain.
Generalising this behaviour to eliminate “redundant” function application would require incorporating something like Symbolics.jl into the julia compiler, which would increase compile times dramatically (and you’d also have to have all those transforms defined for different types where it’s valid and make sure you don’t make a mistake - e.g. sqrt(abs2(1//21)) == 1//21 is false).
You basically need the wrapper type to decide this via dispatch at all, since for various edge cases in terms of value, the transform is not always valid:
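A few such edge cases as a quick illustration (the inputs are picked for this sketch; the prevfloat(Inf) one is the case discussed further down):

```julia
# sqrt(abs2(x)) is not the identity for many perfectly ordinary inputs.
sqrt(abs2(-2.0)) == -2.0          # false: the sign is lost, the result is 2.0
sqrt(abs2(prevfloat(Inf)))        # Inf: abs2 overflows before sqrt is applied
sqrt(abs2(1e-200))                # 0.0: abs2 underflows to zero
sqrt(abs2(1//21)) isa Float64     # true: the Rational type is not preserved
sqrt(abs2(3 + 4im))               # 5.0: a real number, not the complex input
```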
Good point. I guess this comes down to the user’s choice between strict IEEE floating-point behavior and fast optimization.
I have to admit that I know next to nothing about Julia’s or LLVM’s internal optimizations.
If I write 1 .*A .+ 5 .+ 2 .*A .+7 with array type A does it internally get transformed to 12 .+ 3 .*A?
If not, maybe packages like Tullio.jl can do this?
If yes, then such optimizations should have similar issues. I think the user would in almost all cases want the performant version, and the case prevfloat(Inf) |> abs2 |> sqrt is a problem rather than a solution.
I heard rumors that a for-loop adding a number is sometimes entirely replaced by a simple multiplication by LLVM. If this is the case, then this also violates your implicit suggestion to obtain unmodified results at all costs.
No, though depending on the type of A, A may be replaced by a constant if the creator of that type allows for that specialization (for example if A is not a materialized matrix at all but an iterator behaving like a matrix).
That transform again requires something like Symbolics.jl and very specific semantics to be allowed to do (remember, * is also just a regular function and may have side effects for some custom types).
Yes, a macro could be written to do this, since what you’re asking for is a syntactic transform (though that again is limited to code the macro sees: @remove_redundancy sqrt(abs2(x)) could just spit out x, whereas @remove_redundancy f(x) with f(x) = sqrt(abs2(x)) cannot, since the macro can’t see “into” f).
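A minimal sketch of such a macro, assuming it only needs to recognize the literal pattern sqrt(abs2(...)). Note that this version rewrites to abs(x) rather than x, which is the more conservative choice for real arguments (and still skips the overflow and underflow behavior of the original):

```julia
# Hypothetical macro: rewrites the literal pattern sqrt(abs2(arg)) to abs(arg).
macro remove_redundancy(ex)
    if ex isa Expr && ex.head === :call && length(ex.args) == 2 && ex.args[1] === :sqrt
        inner = ex.args[2]
        if inner isa Expr && inner.head === :call && length(inner.args) == 2 &&
           inner.args[1] === :abs2
            return esc(:(abs($(inner.args[2]))))
        end
    end
    return esc(ex)   # anything else is passed through unchanged
end

f(x) = sqrt(abs2(x))

@macroexpand @remove_redundancy sqrt(abs2(y))   # :(abs(y))
@macroexpand @remove_redundancy f(y)            # :(f(y)), the macro cannot see into f

@remove_redundancy sqrt(abs2(-3.0))             # 3.0
```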
Indeed they have, though that’s not what Tullio.jl is doing specifically. Those transforms are basically what --fastmath is doing in some places.
Those transforms are few and far between and only applied if all involved behaviour can’t be distinguished (in terms of returned values). You’re probably thinking of the compiler replacing a naive sum over 1:n with the summation formula (n*(n+1))/2 or even the resulting constant, if n is known as well. Since they’re only applied if you can’t distinguish them, the results (in that sense) are “unmodified”. If I recall correctly, that specific transform is only applied for integer arguments, since floating point summation is not associative and you will get a different result.
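A small sketch of both points. Whether the loop below is actually replaced by a closed form depends on the Julia and LLVM versions, so treat that as something to check with `@code_llvm naive_sum(100)` rather than a guarantee; `naive_sum` is just an illustrative name:

```julia
# Naive integer sum over 1:n. With integer arithmetic, a closed-form replacement
# gives exactly the same result, so the compiler is free to use one.
function naive_sum(n::Int)
    s = 0
    for i in 1:n
        s += i
    end
    return s
end

naive_sum(100) == 100 * 101 ÷ 2           # true

# Floating-point addition is not associative, so reordering changes the result
# and the same kind of rewrite would be observable.
(0.1 + 0.2) + 0.3 == 0.1 + (0.2 + 0.3)    # false
```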
This is not exclusive to LLVM, by the way - GCC does the same kinds of transforms, as do most sufficiently advanced compilers.