PINN in Julia

Hello Flux Experts /Users

I have been working recently with Flux.jl in Julia and have been trying to code an ANN that uses strain (from FEM as input to the ANN) and then calculates the residual based on the output of the ANN (as seen in the code, below). Somehow I am unable to perform the training, all I want to do is make the residual go to zero. I did the exact same way as I’d write in tensor flow. But the Flux.train! is something that I am new to, and unable to figure out what to write in there.

Any leads/help is highly appreciated! Thanks in advance.

# Parameters
E = 1.0  # Young's Modulus
num_elem = 5
bar_length = 1.0

# Solve linear elasticity
disp, stress , strains, quad_points= solve_linear_elasticity(E, num_elem, bar_length)

println("Nodal Displacements evaluated")
println("Quadrature Points collected")
println("Strains and Stresses at quadrature points evaluted")

# Data Pre-processing 

# Quadrature Points: Use vcat if you want to convert quadpoints into a single vector : vcat(quad_points...)
# Stress Values: Also can be done with the same setup , but it comes as column vector instead of a row vector
# So maybe use adjoint(vcat(stress...))

function doassemble(cellvalues_u, facevalues_u, grid, dh, stress_predicted)
    nu = getnbasefunctions(cellvalues_u)

    total_residual = zeros(Float64, ndofs(dh))  # total residual vector
    for cell in CellIterator(dh)
        re = zeros(Float64, nu)  # element residual vector
        assemble_linear_elasticity!(re, cell, cellvalues_u, facevalues_u, grid, stress_predicted)
        # Add element residual to total_residual using a loop
        cell_dofs = celldofs(cell)
        for i in 1:nu
            total_residual[cell_dofs[i]] += re[i]

    return total_residual

function assemble_linear_elasticity!(re, cell, cellvalues_u, facevalues_u, grid, stress_predicted)
    n_basefuncs_u = getnbasefunctions(cellvalues_u)

    reinit!(cellvalues_u, cell)
    for q_point in 1:getnquadpoints(cellvalues_u)
        dΩ = getdetJdV(cellvalues_u, q_point)

        for i in 1:n_basefuncs_u
            ∇δN = shape_symmetric_gradient(cellvalues_u, q_point, i)
            for dim in eachindex(∇δN)
                re[dim] += ∇δN[dim] * stress_predicted[i] * dΩ

    for face in 1:nfaces(cell)
        if onboundary(cell, face) && (cellid(cell), face) ∈ getfaceset(grid, "traction")
            reinit!(facevalues_u, cell, face)
            for q_point in 1:getnquadpoints(facevalues_u)
                dΓ = getdetJdV(facevalues_u, q_point)
                for i in 1:n_basefuncs_u
                    δu = shape_value(facevalues_u, q_point, i)
                    for dim in eachindex(δu)
                        prescribed_traction = 0.01
                        re[dim] += - δu[dim] * prescribed_traction * dΓ
                        #re[dim] += - δu[dim] * stress_predicted[i] * dΓ

function create_ANN(input_dim)

    model = Chain(
            Dense(input_dim, 64, tanh),
            Dense(64, input_dim)
    return model

using Flux
using Plots

function residual_ANN(num_elem, bar_length; epochs=10, lr=0.01)
    grid = create_1d_grid(num_elem, bar_length)

    interpolation = Lagrange{1, RefCube, 1}()
    dh = create_dofhandler(grid, interpolation)
    dbc = create_bc(dh)
    linear = Lagrange{1, RefCube, 1}()

    cellvalues_u, facevalues_u, qr = create_values(linear)

    # Convert to a single array
    input_data = vcat([vec(mat) for mat in strains]...)

    model = create_ANN(size(input_data)[1])

    # Define the residual function
    function residual(x)
        return doassemble(cellvalues_u, facevalues_u, grid, dh, model(x))

    # Loss function: mean of sum of squares of the residual
    function loss(x)
        R = residual(x)
        return mean(R.^2)

    # Optimizer (Adam)
    opt = ADAM(lr)

    # Training loop
    for epoch in 1:epochs
        Flux.train!(loss, Flux.params(model), [(input_data,)], opt)
        print("Epoch: ", epoch) 
        print("     ")
        print("Loss: ", loss(input_data))

    stress_predicted = model(input_data)
    stress_target = vcat([vec(mat) for mat in stress]...)

    # Create a plot
    plot(input_data, stress_predicted, label="Predicted Stress", xlabel="Strain", ylabel="Stress", linewidth=2)
    plot!(input_data, stress_target, label="Target Stress", linestyle=:dash, linewidth=2)

    # Set the output folder path
    output_folder = "/Users/gagankaushikmanyam/Desktop/Projects/DFG-KIStoff/RESULTS/UniAxial/1D"

    # Create the folder if it doesn't exist
    if !isdir(output_folder)

    # Set the full path for saving the PNG file
    output_file = joinpath(output_folder, "plot_result.png")

    # Save the plot as a PNG file

    println("Plot saved at: $output_file")


residual_ANN(num_elem, bar_length; epochs=10, lr=0.01)