Hi, I implemented an input convex neural network (ICNN) with Lux.jl, but the Hessian at the origin does not come out with all non-negative eigenvalues. Below is a minimal working example. Any comments are appreciated.
using Lux
using Random
using LinearAlgebra
using ForwardDiff
function huber_nonneg(x)
    return @. ifelse(abs(x) < 1.0, 0.5 * x^2, abs(x) - 0.5)
end
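# Quick sanity check (illustrative only, not part of the model): huber_nonneg is
# continuously differentiable, even, and non-negative, which is why it is used
# below as a soft non-negativity transform.
@assert all(huber_nonneg([-2.0, 0.0, 0.5]) .>= 0)  # values: 1.5, 0.0, 0.125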
"""
    ICNNLayer
Input Convex Neural Network layer with two weight matrices:
- W_z: non-negative weights from previous layer (enforces convexity)
- W_x: unconstrained weights for input passthrough
"""
struct ICNNLayer{T, F} <: Lux.AbstractLuxLayer
    in_dims::Int
    hidden_dims::Int
    out_dims::Int
    activation::F
    use_bias::Bool
end
function ICNNLayer(in_dims::Int, hidden_dims::Int, out_dims::Int, 
                   activation=softplus; use_bias=true, T::Type=Float64)
    return ICNNLayer{T, typeof(activation)}(in_dims, hidden_dims, out_dims, activation, use_bias)
end
function Lux.initialparameters(rng::AbstractRNG, layer::ICNNLayer{T,F}) where {T,F}
    W_z = rand(rng, T, layer.out_dims, layer.hidden_dims)  # non-negative at init (convexity)
    W_x = randn(rng, T, layer.out_dims, layer.in_dims)     # unconstrained input passthrough
    
    if layer.use_bias
        bias = zeros(T, layer.out_dims)
        return (W_z=W_z, W_x=W_x, bias=bias)
    else
        return (W_z=W_z, W_x=W_x)
    end
end
Lux.initialstates(rng::AbstractRNG, ::ICNNLayer) = NamedTuple()
function Lux.parameterlength(layer::ICNNLayer)
    n = layer.out_dims * layer.hidden_dims + layer.out_dims * layer.in_dims
    return layer.use_bias ? n + layer.out_dims : n
end
Lux.statelength(::ICNNLayer) = 0
function (layer::ICNNLayer{T,F})((x, z)::Tuple, ps, st) where {T,F}
    y = ps.W_z * z + ps.W_x * x
    
    if layer.use_bias
        y = y .+ ps.bias
    end
    
    return layer.activation.(y), st
end
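# Standalone sketch of a single hidden layer call (hypothetical sizes, not used
# in the model below): the layer consumes the tuple (original input x, previous
# hidden activation z).
_layer = ICNNLayer(2, 4, 3)
_ps, _st = Lux.setup(Random.default_rng(), _layer)
_z, _ = _layer((randn(2), rand(4)), _ps, _st)  # length-3 softplus activations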
"""
    ICNNFirstLayer
First layer of ICNN (no previous hidden layer)
"""
struct ICNNFirstLayer{T,F} <: Lux.AbstractLuxLayer
    in_dims::Int
    out_dims::Int
    activation::F
    use_bias::Bool
end
function ICNNFirstLayer(in_dims::Int, out_dims::Int, 
                        activation=softplus; use_bias=true, T::Type=Float64)
    return ICNNFirstLayer{T, typeof(activation)}(in_dims, out_dims, activation, use_bias)
end
function Lux.initialparameters(rng::AbstractRNG, layer::ICNNFirstLayer{T,F}) where {T,F}
    weight = randn(rng, T, layer.out_dims, layer.in_dims)
    
    if layer.use_bias
        bias = zeros(T, layer.out_dims)
        return (weight=weight, bias=bias)
    else
        return (weight=weight,)
    end
end
Lux.initialstates(rng::AbstractRNG, ::ICNNFirstLayer) = NamedTuple()
function Lux.parameterlength(layer::ICNNFirstLayer)
    n = layer.out_dims * layer.in_dims
    return layer.use_bias ? n + layer.out_dims : n
end
Lux.statelength(::ICNNFirstLayer) = 0
function (layer::ICNNFirstLayer{T,F})(x, ps, st) where {T,F}
    y = ps.weight * x
    
    if layer.use_bias
        y = y .+ ps.bias
    end
    
    return layer.activation.(y), st
end
"""
    FinalICNNLayer
Final output layer of ICNN with optional quadratic terms
"""
struct FinalICNNLayer{T, F} <: Lux.AbstractLuxLayer
    in_dims::Int
    hidden_dims::Int
    out_dims::Int
    activation::F
    use_bias::Bool
    use_quadratic::Bool
end
function FinalICNNLayer(in_dims::Int, hidden_dims::Int, out_dims::Int=1,
                        activation=identity; use_bias=false, use_quadratic=false, T::Type=Float64)
    return FinalICNNLayer{T, typeof(activation)}(in_dims, hidden_dims, out_dims, activation, use_bias, use_quadratic)
end
function Lux.initialparameters(rng::AbstractRNG, layer::FinalICNNLayer{T,F}) where {T,F}
    W_z = rand(rng, T, layer.out_dims, layer.hidden_dims)
    
    if layer.use_quadratic
        W_x = randn(rng, T, layer.out_dims, layer.in_dims)
        d = rand(rng, T, layer.in_dims)
        
        if layer.use_bias
            bias = zeros(T, layer.out_dims)
            return (W_z=W_z, W_x=W_x, d=d, bias=bias)
        else
            return (W_z=W_z, W_x=W_x, d=d)
        end
    else
        if layer.use_bias
            bias = zeros(T, layer.out_dims)
            return (W_z=W_z, bias=bias)
        else
            return (W_z=W_z,)
        end
    end
end
Lux.initialstates(rng::AbstractRNG, ::FinalICNNLayer) = NamedTuple()
function Lux.parameterlength(layer::FinalICNNLayer)
    n = layer.out_dims * layer.hidden_dims
    if layer.use_quadratic
        n += layer.out_dims * layer.in_dims + layer.in_dims
    end
    return layer.use_bias ? n + layer.out_dims : n
end
Lux.statelength(::FinalICNNLayer) = 0
function (layer::FinalICNNLayer{T,F})((x, z)::Tuple, ps, st) where {T,F}
    y = ps.W_z * z
    
    if layer.use_quadratic
        quad_term = sum(x .^ 2 .* huber_nonneg(ps.d), dims=1)
        linear_term = ps.W_x * x
        y = y .+ quad_term .+ linear_term
    end
    
    if layer.use_bias
        y = y .+ ps.bias
    end
    
    return layer.activation.(y), st
end
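# With use_quadratic = true the final layer evaluates
#     y = W_z * z + W_x * x + sum(huber_nonneg(d) .* x.^2)  (+ bias),
# where the Huber transform keeps the quadratic coefficients non-negative, so the
# extra term remains convex in x. The quadratic path is disabled in create_icnn below.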
"""
    ICNNChain
Chain of ICNN layers with proper input/hidden state passing
"""
struct ICNNChain{L<:Tuple} <: Lux.AbstractLuxLayer
    layers::L
end
ICNNChain(layers...) = ICNNChain(layers)
function Base.show(io::IO, chain::ICNNChain)
    print(io, "ICNNChain(")
    for (i, layer) in enumerate(chain.layers)
        i > 1 && print(io, ", ")
        print(io, layer)
    end
    print(io, ")")
end
function Lux.initialparameters(rng::AbstractRNG, chain::ICNNChain)
    params = map(chain.layers) do layer
        Lux.initialparameters(rng, layer)
    end
    return (layers=params,)
end
function Lux.initialstates(rng::AbstractRNG, chain::ICNNChain)
    states = map(chain.layers) do layer
        Lux.initialstates(rng, layer)
    end
    return (layers=states,)
end
function Lux.parameterlength(chain::ICNNChain)
    return sum(Lux.parameterlength, chain.layers)
end
function Lux.statelength(chain::ICNNChain)
    return sum(Lux.statelength, chain.layers)
end
function (chain::ICNNChain)(x, ps, st)
    # First layer - only takes x
    z, st_first = chain.layers[1](x, ps.layers[1], st.layers[1])
    states = (st_first,)
    
    # Subsequent layers - take both x and the previous hidden state
    for i in 2:length(chain.layers)
        z, st_new = chain.layers[i]((x, z), ps.layers[i], st.layers[i])
        states = (states..., st_new)
    end
    
    return z, (layers=states,)
end
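# Data flow: z_1 = σ(W x + b), then z_k = σ(W_z z_{k-1} + W_x x + b) for the hidden
# layers, and the final layer maps (x, z_L) to the scalar output. Convexity in x
# relies on W_z staying non-negative and σ being convex and non-decreasing
# (softplus satisfies both).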
"""
    create_icnn(n_vars, hidden_dims=[32, 32]; T=Float64, rng=Random.default_rng())
Create an Input Convex Neural Network for approximating a scalar entropy function
"""
function create_icnn(n_vars::Int, hidden_dims::Vector{Int}=[32, 32]; 
                     T::Type=Float64, rng::AbstractRNG=Random.default_rng())
    @assert !isempty(hidden_dims) "hidden_dims cannot be empty"
    
    layers = []
    
    # First layer
    push!(layers, ICNNFirstLayer(n_vars, hidden_dims[1], softplus; T=T))
    
    # Hidden layers
    for i in 1:(length(hidden_dims)-1)
        push!(layers, ICNNLayer(n_vars, hidden_dims[i], 
                                hidden_dims[i+1], softplus; T=T))
    end
    
    # Output layer (identity activation; bias and quadratic terms disabled)
    push!(layers, FinalICNNLayer(n_vars, hidden_dims[end], 1, identity; 
                                 use_bias=false, use_quadratic=false, T=T))
    
    model = ICNNChain(layers...)
    ps, st = Lux.setup(rng, model)
    
    return model, ps, st
end
rng = Random.default_rng()
model, ps, st = create_icnn(2, [16, 16]; T=Float64, rng=rng)
u = Float64[0.0, 0.0]
    
# Forward pass
η, _ = model(u, ps, st)
println(η[1]) # 38.30048
    
# Hessian
H = ForwardDiff.hessian(u_vec -> model(u_vec, ps, st)[1][1], u)
println(isapprox(H, H', rtol=1e-15))  # false
println(all(eigvals(H) .> -1e-15))  # false
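For reference, here is the same check with the Hessian symmetrized first and with a tolerance scaled to its magnitude; this is only my guess at a fairer comparison, since as far as I know ForwardDiff Hessians are symmetric only up to floating-point error.
Hs = (H + H') / 2                                          # symmetrized Hessian
println(norm(H - H'))                                      # size of the asymmetry
println(all(eigvals(Symmetric(Hs)) .> -1e-8 * norm(Hs)))   # scaled tolerance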