# PINN in Julia

Hello Flux experts and users,

I have recently been working with Flux.jl in Julia. I am trying to code an ANN that takes the strain (computed with FEM) as input and then calculates the FEM residual from the ANN's output (see the code below). All I want is to drive the residual to zero, but I cannot get the training to work. I implemented it exactly as I would in TensorFlow, but `Flux.train!` is new to me and I cannot figure out what to pass to it.

``````julia
# Parameters
E = 1.0           # Young's modulus
num_elem = 5      # number of finite elements along the bar
bar_length = 1.0  # total length of the bar

# Solve the linear-elastic FEM problem. Its strains are the ANN input in
# `residual_ANN`; its stresses serve as the reference curve in the final plot.
disp, stress, strains, quad_points = solve_linear_elasticity(E, num_elem, bar_length)

println("----------------------------------------------------------------")
println("Nodal Displacements evaluated")
println("Strains and Stresses at quadrature points evaluated")

# Data pre-processing notes
#
# NOTE(review): `quad_points`, `strains` and `stress` presumably hold one entry
# per element (values at that element's quadrature points) — TODO confirm
# against `solve_linear_elasticity`.
# - Quadrature points: flatten into a single vector with `vcat(quad_points...)`.
# - Stress values: can be flattened the same way; the result is a column
#   vector rather than a row vector.

"""
    doassemble(cellvalues_u, facevalues_u, grid, dh, stress_predicted)

Assemble the global residual vector from element residuals.

For every cell of the DoF handler `dh`, an element residual is computed with
`assemble_linear_elasticity!` from `stress_predicted`. It is then scatter-added
into the global vector at the cell's degrees of freedom (`celldofs(cell)`).

# Returns
- `Vector{Float64}` of length `ndofs(dh)`: the assembled global residual.

NOTE(review): both this function and `assemble_linear_elasticity!` fill arrays
in place. Zygote (the AD backend used by `Flux.train!`) does not support array
mutation, so gradients of a loss built on this residual are expected to fail
("Mutating arrays is not supported"). The assembly will need a non-mutating
formulation before the network can be trained through it — TODO.
"""
function doassemble(cellvalues_u, facevalues_u, grid, dh, stress_predicted)
    nu = getnbasefunctions(cellvalues_u)

    total_residual = zeros(Float64, ndofs(dh))  # global residual vector
    re = zeros(Float64, nu)                     # element residual, reused per cell

    for cell in CellIterator(dh)
        fill!(re, 0.0)  # the element kernel accumulates with +=, so start from zero
        assemble_linear_elasticity!(re, cell, cellvalues_u, facevalues_u, grid, stress_predicted)

        # Scatter-add the element contribution into the global vector.
        cell_dofs = celldofs(cell)
        for i in 1:nu
            total_residual[cell_dofs[i]] += re[i]
        end
    end

    # Without this explicit return the function evaluated to `nothing`.
    return total_residual
end

"""
    assemble_linear_elasticity!(re, cell, cellvalues_u, facevalues_u, grid, stress_predicted;
                                prescribed_traction = 0.01)

Accumulate the element residual of a 1D bar into `re` (modified in place) and return it.

Internal-force term, for every basis function `i` and quadrature point `q` of `cell`:

    re[i] += dN_i/dx(q) * sigma(q) * dOmega(q)

External-traction term, for faces of `cell` that lie in the face set `"traction"`:

    re[i] -= N_i(q) * prescribed_traction * dGamma(q)

# Arguments
- `re`: element residual vector of length `getnbasefunctions(cellvalues_u)`.
  Values are added with `+=`/`-=`, so the caller must zero it beforehand.
- `stress_predicted`: flat vector of stresses at quadrature points. The value
  for quadrature point `q` of cell `c` is read from index `(c - 1) * nqp + q`,
  where `nqp` is the number of quadrature points per cell.
  NOTE(review): this assumes stresses are ordered cell-by-cell, then by
  quadrature point, i.e. the order produced by flattening `strains` with
  `vcat([vec(mat) for mat in strains]...)` — TODO confirm against
  `solve_linear_elasticity`.

# Keywords
- `prescribed_traction`: traction applied on the `"traction"` boundary (default `0.01`).
"""
function assemble_linear_elasticity!(re, cell, cellvalues_u, facevalues_u, grid,
                                     stress_predicted; prescribed_traction = 0.01)
    n_basefuncs_u = getnbasefunctions(cellvalues_u)
    n_qpoints = getnquadpoints(cellvalues_u)
    # Position of this cell's first quadrature point in the flat stress vector.
    offset = (cellid(cell) - 1) * n_qpoints

    reinit!(cellvalues_u, cell)

    # Internal forces: integral of dN_i/dx * sigma over the element.
    for q_point in 1:n_qpoints
        dΩ = getdetJdV(cellvalues_u, q_point)
        σ = stress_predicted[offset + q_point]
        for i in 1:n_basefuncs_u
            ∇δN = shape_gradient(cellvalues_u, q_point, i)  # 1-component gradient in 1D
            re[i] += ∇δN[1] * σ * dΩ
        end
    end

    # External forces: integral of N_i * traction over the "traction" boundary.
    for face in 1:nfaces(cell)
        if onboundary(cell, face) && (cellid(cell), face) ∈ getfaceset(grid, "traction")
            reinit!(facevalues_u, cell, face)
            for q_point in 1:getnquadpoints(facevalues_u)
                dΓ = getdetJdV(facevalues_u, q_point)
                for i in 1:n_basefuncs_u
                    δu = shape_value(facevalues_u, q_point, i)
                    re[i] -= δu * prescribed_traction * dΓ
                end
            end
        end
    end

    return re
end

"""
    create_ANN(input_dim)

Build a two-layer fully connected network that maps a length-`input_dim`
vector to a vector of the same length:
`input_dim -> 64 (tanh) -> input_dim (identity activation)`.
"""
function create_ANN(input_dim)
    hidden_width = 64
    hidden = Dense(input_dim, hidden_width, tanh)   # constructed first, as before
    readout = Dense(hidden_width, input_dim)
    return Chain(hidden, readout)
end

using Flux
using Plots

"""
    residual_ANN(num_elem, bar_length; epochs=10, lr=0.01, output_folder=...)

Train a neural network that predicts quadrature-point stresses from strains,
using the mean squared FEM residual assembled from its predictions as the loss.
Then plot predicted vs. FEM stress over strain and save the figure as
`plot_result.png` in `output_folder`.

Reads the script-level globals `strains` (network input) and `stress`
(reference stresses, used only for the plot).

# Keywords
- `epochs`: number of calls to `Flux.train!` over the single training sample.
- `lr`: learning rate of the Adam optimiser.
- `output_folder`: directory for the saved plot. It is created, including any
  missing parent directories, if it does not exist. The default is the
  original author's results folder.

# Returns
- The trained Flux model.

NOTE(review): the network maps the whole strain vector to the whole stress
vector, so it is not a pointwise constitutive law. Applying a 1 -> 1 network
to each strain value may generalise better; this is a design choice to confirm.
"""
function residual_ANN(num_elem, bar_length; epochs=10, lr=0.01,
                      output_folder="/Users/gagankaushikmanyam/Desktop/Projects/DFG-KIStoff/RESULTS/UniAxial/1D")
    grid = create_1d_grid(num_elem, bar_length)

    interpolation = Lagrange{1, RefCube, 1}()
    dh = create_dofhandler(grid, interpolation)
    # NOTE(review): `dbc` is never used below. Residual entries at
    # Dirichlet-constrained dofs equal reaction forces, which are not zero at
    # equilibrium. Consider excluding those dofs from the loss — TODO confirm
    # what `create_bc` returns.
    dbc = create_bc(dh)
    cellvalues_u, facevalues_u, _ = create_values(interpolation)

    # Flatten the per-element strains into one input vector.
    input_data = reduce(vcat, [vec(mat) for mat in strains])

    model = create_ANN(length(input_data))
    opt = Flux.Adam(lr)  # optimiser passed to Flux.train!; `lr` is its step size

    # Global FEM residual for the stresses the network predicts from `x`.
    residual(x) = doassemble(cellvalues_u, facevalues_u, grid, dh, model(x))

    # Mean squared residual. Computed directly because `Statistics.mean` is
    # not imported in the visible code.
    function loss(x)
        R = residual(x)
        return sum(abs2, R) / length(R)
    end

    # Implicit-parameter training: train! calls loss(input_data) once per epoch.
    ps = Flux.params(model)
    data = [(input_data,)]
    for epoch in 1:epochs
        Flux.train!(loss, ps, data, opt)
        print("Epoch: ", epoch)
        print("     ")
        print("Loss: ", loss(input_data))
        println("--------------------------------")
    end

    stress_predicted = model(input_data)
    stress_target = reduce(vcat, [vec(mat) for mat in stress])

    plot(input_data, stress_predicted, label="Predicted Stress", xlabel="Strain", ylabel="Stress", linewidth=2)
    plot!(input_data, stress_target, label="Target Stress", linestyle=:dash, linewidth=2)

    # Unlike mkdir, mkpath also creates missing parent directories and is a
    # no-op when the folder already exists.
    mkpath(output_folder)
    output_file = joinpath(output_folder, "plot_result.png")
    savefig(output_file)

    println("Plot saved at: $output_file")

    return model
end

# Entry point: train and plot, using the bar parameters defined at the top of the script.
residual_ANN(num_elem, bar_length; lr=0.01, epochs=10)

``````