Input Convex Neural Network is not Convex at Origin in Lux.jl

Hi, I implemented an input convex neural network (ICNN) with Lux.jl, but the Hessian at the origin fails the symmetry and non-negative-eigenvalue checks shown below. Here is a minimal working example; any comments are appreciated.


using Lux
using Random
using LinearAlgebra
using ForwardDiff


# Huber-style map: C^1 and non-negative everywhere. Used below to keep the
# quadratic coefficients d of the final layer non-negative.
function huber_nonneg(x)
    return @. ifelse(abs(x) < 1.0, 0.5 * x^2, abs(x) - 0.5)
end

"""
    ICNNLayer
Input Convex Neural Network layer with two weight matrices:
- W_z: non-negative weights from previous layer (enforces convexity)
- W_x: unconstrained weights for input passthrough
"""
struct ICNNLayer{T, F} <: Lux.AbstractLuxLayer
    in_dims::Int
    hidden_dims::Int
    out_dims::Int
    activation::F
    use_bias::Bool
end

function ICNNLayer(in_dims::Int, hidden_dims::Int, out_dims::Int, 
                   activation=softplus; use_bias=true, T::Type=Float64)
    return ICNNLayer{T, typeof(activation)}(in_dims, hidden_dims, out_dims, activation, use_bias)
end

function Lux.initialparameters(rng::AbstractRNG, layer::ICNNLayer{T,F}) where {T,F}
    W_z = rand(rng, T, layer.out_dims, layer.hidden_dims)  # uniform in [0, 1): non-negative at initialization
    W_x = randn(rng, T, layer.out_dims, layer.in_dims)     # unconstrained input-passthrough weights
    
    if layer.use_bias
        bias = zeros(T, layer.out_dims)
        return (W_z=W_z, W_x=W_x, bias=bias)
    else
        return (W_z=W_z, W_x=W_x)
    end
end

Lux.initialstates(rng::AbstractRNG, ::ICNNLayer) = NamedTuple()

function Lux.parameterlength(layer::ICNNLayer)
    n = layer.out_dims * layer.hidden_dims + layer.out_dims * layer.in_dims
    return layer.use_bias ? n + layer.out_dims : n
end

Lux.statelength(::ICNNLayer) = 0

function (layer::ICNNLayer{T,F})((x, z)::Tuple, ps, st) where {T,F}
    y = ps.W_z * z + ps.W_x * x
    
    if layer.use_bias
        y = y .+ ps.bias
    end
    
    return layer.activation.(y), st
end
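
A side note on the convexity constraint: W_z is only non-negative because `rand` draws it from [0, 1) at initialization; nothing in the forward pass keeps it non-negative once the parameters are trained. A common remedy in the ICNN literature is to reparameterize W_z through a non-negative map inside the forward pass. A minimal sketch of such a variant (the softplus reparameterization is my own choice, not part of the MWE, and defining it would override the method above):

# Hypothetical variant of the ICNNLayer forward pass: the effective weight on z
# is softplus(W_z) >= 0, so convexity in x survives arbitrary parameter updates.
function (layer::ICNNLayer{T,F})((x, z)::Tuple, ps, st) where {T,F}
    y = softplus.(ps.W_z) * z + ps.W_x * x
    if layer.use_bias
        y = y .+ ps.bias
    end
    return layer.activation.(y), st
end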

"""
    ICNNFirstLayer
First layer of ICNN (no previous hidden layer)
"""
struct ICNNFirstLayer{T,F} <: Lux.AbstractLuxLayer
    in_dims::Int
    out_dims::Int
    activation::F
    use_bias::Bool
end

function ICNNFirstLayer(in_dims::Int, out_dims::Int, 
                        activation=softplus; use_bias=true, T::Type=Float64)
    return ICNNFirstLayer{T, typeof(activation)}(in_dims, out_dims, activation, use_bias)
end

function Lux.initialparameters(rng::AbstractRNG, layer::ICNNFirstLayer{T,F}) where {T,F}
    weight = randn(rng, T, layer.out_dims, layer.in_dims)
    
    if layer.use_bias
        bias = zeros(T, layer.out_dims)
        return (weight=weight, bias=bias)
    else
        return (weight=weight,)
    end
end

Lux.initialstates(rng::AbstractRNG, ::ICNNFirstLayer) = NamedTuple()

function Lux.parameterlength(layer::ICNNFirstLayer)
    n = layer.out_dims * layer.in_dims
    return layer.use_bias ? n + layer.out_dims : n
end

Lux.statelength(::ICNNFirstLayer) = 0

function (layer::ICNNFirstLayer{T,F})(x, ps, st) where {T,F}
    y = ps.weight * x
    
    if layer.use_bias
        y = y .+ ps.bias
    end
    
    return layer.activation.(y), st
end

"""
    FinalICNNLayer
Final output layer of ICNN with optional quadratic terms
"""
struct FinalICNNLayer{T, F} <: Lux.AbstractLuxLayer
    in_dims::Int
    hidden_dims::Int
    out_dims::Int
    activation::F
    use_bias::Bool
    use_quadratic::Bool
end

function FinalICNNLayer(in_dims::Int, hidden_dims::Int, out_dims::Int=1,
                        activation=identity; use_bias=false, use_quadratic=false, T::Type=Float64)
    return FinalICNNLayer{T, typeof(activation)}(in_dims, hidden_dims, out_dims, activation, use_bias, use_quadratic)
end

function Lux.initialparameters(rng::AbstractRNG, layer::FinalICNNLayer{T,F}) where {T,F}
    W_z = rand(rng, T, layer.out_dims, layer.hidden_dims)
    
    if layer.use_quadratic
        W_x = randn(rng, T, layer.out_dims, layer.in_dims)
        d = rand(rng, T, layer.in_dims)
        
        if layer.use_bias
            bias = zeros(T, layer.out_dims)
            return (W_z=W_z, W_x=W_x, d=d, bias=bias)
        else
            return (W_z=W_z, W_x=W_x, d=d)
        end
    else
        if layer.use_bias
            bias = zeros(T, layer.out_dims)
            return (W_z=W_z, bias=bias)
        else
            return (W_z=W_z,)
        end
    end
end

Lux.initialstates(rng::AbstractRNG, ::FinalICNNLayer) = NamedTuple()

function Lux.parameterlength(layer::FinalICNNLayer)
    n = layer.out_dims * layer.hidden_dims
    if layer.use_quadratic
        n += layer.out_dims * layer.in_dims + layer.in_dims
    end
    return layer.use_bias ? n + layer.out_dims : n
end

Lux.statelength(::FinalICNNLayer) = 0

function (layer::FinalICNNLayer{T,F})((x, z)::Tuple, ps, st) where {T,F}
    y = ps.W_z * z
    
    if layer.use_quadratic
        quad_term = sum(x .^ 2 .* huber_nonneg(ps.d), dims=1)
        linear_term = ps.W_x * x
        y = y .+ quad_term .+ linear_term
    end
    
    if layer.use_bias
        y = y .+ ps.bias
    end
    
    return layer.activation.(y), st
end
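
For reference, the quadratic term is a diagonal quadratic form with non-negative coefficients huber_nonneg(d), so its Hessian in x is 2 * Diagonal(huber_nonneg(d)) and hence positive semi-definite. A quick stand-alone check (illustrative only; d_test, q, and Hq are my own names and are not used elsewhere):

# Check the quadratic term of FinalICNNLayer in isolation.
d_test = randn(3)
q(x) = sum(x .^ 2 .* huber_nonneg(d_test))
Hq = ForwardDiff.hessian(q, zeros(3))
println(Hq ≈ 2 * Diagonal(huber_nonneg(d_test)))  # true; eigenvalues are 2 * huber_nonneg(d_test) ≥ 0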

"""
    ICNNChain
Chain of ICNN layers with proper input/hidden state passing
"""
struct ICNNChain{L<:Tuple} <: Lux.AbstractLuxLayer
    layers::L
end

ICNNChain(layers...) = ICNNChain(layers)

function Base.show(io::IO, chain::ICNNChain)
    print(io, "ICNNChain(")
    for (i, layer) in enumerate(chain.layers)
        i > 1 && print(io, ", ")
        print(io, layer)
    end
    print(io, ")")
end

function Lux.initialparameters(rng::AbstractRNG, chain::ICNNChain)
    params = map(chain.layers) do layer
        Lux.initialparameters(rng, layer)
    end
    return (layers=params,)
end

function Lux.initialstates(rng::AbstractRNG, chain::ICNNChain)
    states = map(chain.layers) do layer
        Lux.initialstates(rng, layer)
    end
    return (layers=states,)
end

function Lux.parameterlength(chain::ICNNChain)
    return sum(Lux.parameterlength, chain.layers)
end

function Lux.statelength(chain::ICNNChain)
    return sum(Lux.statelength, chain.layers)
end

function (chain::ICNNChain)(x, ps, st)
    # First layer - only takes x
    z, st_first = chain.layers[1](x, ps.layers[1], st.layers[1])
    states = [st_first]
    
    # Subsequent layers - take both x and previous hidden state
    for i in 2:length(chain.layers)
        z, st_new = chain.layers[i]((x, z), ps.layers[i], st.layers[i])
        push!(states, st_new)
    end
    
    return z, (layers=tuple(states...),)
end
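
To make the data flow concrete, here is a small sanity check (purely illustrative, not part of the MWE) that unrolls a two-layer ICNNChain by hand and compares it against the chain call:

# Build a tiny chain, then reproduce its output by threading (x, z) manually.
let rng = Random.default_rng()
    chain = ICNNChain(ICNNFirstLayer(2, 4, softplus),
                      FinalICNNLayer(2, 4, 1, identity))
    p, s = Lux.setup(rng, chain)
    x0 = randn(2)
    z1, _ = chain.layers[1](x0, p.layers[1], s.layers[1])              # first layer sees only x
    y_manual, _ = chain.layers[2]((x0, z1), p.layers[2], s.layers[2])  # later layers see (x, z)
    y_chain, _ = chain(x0, p, s)
    println(y_manual ≈ y_chain)  # true
end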

"""
    create_icnn(n_vars, hidden_dims=[32, 32]; T=Float64, rng=Random.default_rng())
Create an Input Convex Neural Network for entropy function
"""
function create_icnn(n_vars::Int, hidden_dims::Vector{Int}=[32, 32]; 
                     T::Type=Float64, rng::AbstractRNG=Random.default_rng())
    @assert !isempty(hidden_dims) "hidden_dims cannot be empty"
    
    layers = []
    
    # First layer
    push!(layers, ICNNFirstLayer(n_vars, hidden_dims[1], softplus; T=T))
    
    # Hidden layers
    for i in 1:(length(hidden_dims)-1)
        push!(layers, ICNNLayer(n_vars, hidden_dims[i], 
                                hidden_dims[i+1], softplus; T=T))
    end
    
    # Output layer (identity activation, no bias, no quadratic terms)
    push!(layers, FinalICNNLayer(n_vars, hidden_dims[end], 1, identity;
                                 use_bias=false, use_quadratic=false, T=T))
    
    model = ICNNChain(layers...)
    ps, st = Lux.setup(rng, model)
    
    return model, ps, st
end


rng = Random.default_rng()
model, ps, st = create_icnn(2, [16, 16]; T=Float64, rng=rng)
u = Float64[0.0, 0.0]

# Forward pass
η, _ = model(u, ps, st)
println(η[1])  # 38.30048

# Hessian
H = ForwardDiff.hessian(u_vec -> model(u_vec, ps, st)[1][1], u)
println(isapprox(H, H', rtol=1e-15))  # false
println(all(eigvals(H) .> -1e-15))    # false
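
As an extra diagnostic (not part of the output above; the 1e-8 tolerance is an arbitrary choice on my side), one can quantify the asymmetry and look at the eigenvalues of the symmetrized Hessian:

println(norm(H - H'))                         # size of the asymmetry
println(all(eigvals((H + H') / 2) .> -1e-8))  # eigenvalues of the symmetric part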

I then made a few changes: the first layer now uses relu and the final layer enables its quadratic term. With this version the Hessian at the origin is symmetric positive definite.

function create_icnn(n_vars::Int, hidden_dims::Vector{Int}=[32, 32]; 
                     T::Type=Float64, rng::AbstractRNG=Random.default_rng())
    @assert !isempty(hidden_dims) "hidden_dims cannot be empty"
    
    layers = []
    
    # First layer
    push!(layers, ICNNFirstLayer(n_vars, hidden_dims[1], relu; T=T))
    
    # Hidden layers
    for i in 1:(length(hidden_dims)-1)
        push!(layers, ICNNLayer(n_vars, hidden_dims[i], 
                                hidden_dims[i+1], softplus; T=T))
    end
    
    # Output layer (no activation, with quadratic terms)
    push!(layers, FinalICNNLayer(n_vars, hidden_dims[end], 1, identity; 
                                 use_bias=false, use_quadratic=true, T=T))
    
    model = ICNNChain(layers...)
    ps, st = Lux.setup(rng, model)
    
    return model, ps, st
end
julia> include("examples/mwe.jl")
6.539322516230143
true
true

Is this a bug somewhere, or am I misunderstanding something?