Differentiable vectorized piecewise function in Julia?

How could I define this function in Julia?

$$f(x) = \begin{cases} x^2 & \text{if } x < 2 \\ (x-2)(x-4) & \text{otherwise} \end{cases}$$

function f(x)
    if x < 2
        return x^2
    else
        return (x-2)*(x-4)
    end
end

?


Thank you. But I said vectorized and differentiable. Its input must be a vector or matrix.

Linking a related thread that seems to address the differentiability part.
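For reference, differentiating each branch of the definition by hand gives the derivative below. The function also jumps at x = 2 (the limit from the left is 4, while f(2) = 0), so it is not differentiable there. AD tools such as ForwardDiff or Zygote differentiate whichever branch actually executes, so at x = 2 they report the slope of the second branch, -2.

```latex
f'(x) = \begin{cases} 2x & \text{if } x < 2 \\ 2x - 6 & \text{if } x > 2 \end{cases}
\qquad
\lim_{x \to 2^-} f(x) = 4 \neq 0 = f(2)
```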

As for vectorization, shouldn't broadcasting with f.(x) work with the definition above?
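A minimal sketch of that idea, assuming the scalar `f` posted earlier (redefined here so the block runs on its own). Broadcasting preserves the shape of its input, so the same `f.(...)` call handles a vector and a matrix. The inputs `v` and `A` are arbitrary examples chosen for this sketch.

```julia
# Scalar piecewise definition: the branch is chosen separately for each number.
function f(x::Real)
    if x < 2
        return x^2
    else
        return (x - 2) * (x - 4)
    end
end

v = [1.0, 2.0, 3.0]          # example vector input
A = [1.0 2.0; 3.0 0.5]       # example matrix input

fv = f.(v)   # elementwise on a vector
fA = f.(A)   # elementwise on a matrix; the result keeps the 2×2 shape
```

The scalar definition stays readable, and the dot at the call site is what makes it work on arrays of any dimension.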

Just add dots to the operators to make it work on multidimensional data (or simply broadcast the scalar version with f.(x)):

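using Zygote

# Vectorized piecewise function: ifelse.() picks a branch for each element,
# so mixed inputs such as [1.0, 2.0, 3.0] are handled correctly
# (testing all(x .< 2) would apply one branch to the whole array).
f(x) = ifelse.(x .< 2, x .^ 2, (x .- 2) .* (x .- 4))

# Zygote.jacobian returns a tuple with one Jacobian per argument;
# the first entry here is the 3×3 diagonal matrix with diagonal (2, -2, 0).
jacobian(f, [1.0, 2.0, 3.0])[1]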
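If Zygote is not a requirement, here is a minimal sketch that assumes ForwardDiff instead, with the piecewise rule written as an elementwise `ifelse.` broadcast. The name `piecewise` is hypothetical, chosen only for this sketch. Both branches are evaluated for every element and `ifelse` keeps one per element, which is safe here because both formulas are defined everywhere.

```julia
using ForwardDiff

# Hypothetical helper for this sketch: elementwise piecewise rule via ifelse.
piecewise(x) = ifelse.(x .< 2, x .^ 2, (x .- 2) .* (x .- 4))

v = [1.0, 2.0, 3.0]

# Full Jacobian; it is diagonal because each output depends only on its own input.
J = ForwardDiff.jacobian(piecewise, v)

# For a scalar loss built on top of it, ForwardDiff.gradient works directly.
g = ForwardDiff.gradient(x -> sum(piecewise(x)), v)
```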

I do not understand what you want. For me, this works:

julia> function f(x)
           if x < 2
               return x^2
           else
               return (x-2)*(x-4)
           end
       end
f (generic function with 1 method)

julia> x = [1,2,3]
3-element Vector{Int64}:
 1
 2
 3

julia> f.(x)
3-element Vector{Int64}:
  1
  0
 -1

julia> using ForwardDiff

julia> ForwardDiff.derivative.(f,x)
3-element Vector{Int64}:
  2
 -2
  0

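Since the question asks for matrix input as well, the same broadcast of `ForwardDiff.derivative` carries over to a matrix. This is a minimal sketch; the 2×2 matrix `A` is an arbitrary example, and `f` is redefined so the block is self-contained. The result has the same shape as the input and holds the derivative at each entry.

```julia
using ForwardDiff

# Scalar definition, as in the REPL session above.
function f(x::Real)
    if x < 2
        return x^2
    else
        return (x - 2) * (x - 4)
    end
end

A = [1.0 2.0; 3.0 0.5]   # example matrix input

# Elementwise derivatives; broadcasting keeps the 2×2 shape of A.
dA = ForwardDiff.derivative.(f, A)
```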

This is unclear to me, as the condition “x < 2” is not defined for a vector or matrix.
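To illustrate the point: comparing a whole vector with a number has no method, and broadcasting the comparison gives a vector of Bools, which an `if` cannot use either. This is a small sketch with an arbitrary example vector `x`.

```julia
x = [1, 2, 3]   # example input

# `x < 2` has no method for a Vector and a number, so it throws a MethodError.
lt_error = try
    x < 2
    nothing
catch err
    err
end

# Broadcasting the comparison yields one Bool per element instead.
mask = x .< 2

# An `if` needs a single Bool, so a BitVector condition raises a TypeError.
if_error = try
    if mask
        :first_branch
    end
    nothing
catch err
    err
end
```

This is why the branch has to be selected per element, either inside a scalar function that is broadcast or with an elementwise selection.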
