Hello Flux Experts /Users
I have been working recently with Flux.jl in Julia and have been trying to code an ANN that uses strain (from FEM as input to the ANN) and then calculates the residual based on the output of the ANN (as seen in the code, below). Somehow I am unable to perform the training, all I want to do is make the residual go to zero. I did the exact same way as I’d write in tensor flow. But the Flux.train! is something that I am new to, and unable to figure out what to write in there.
Any leads/help is highly appreciated! Thanks in advance.
# Parameters
E = 1.0 # Young's Modulus
num_elem = 5
bar_length = 1.0
# Solve linear elasticity
disp, stress , strains, quad_points= solve_linear_elasticity(E, num_elem, bar_length)
println("----------------------------------------------------------------")
println("Nodal Displacements evaluated")
println("Quadrature Points collected")
println("Strains and Stresses at quadrature points evaluted")
# Data Pre-processing
# Quadrature Points: Use vcat if you want to convert quadpoints into a single vector : vcat(quad_points...)
# Stress Values: Also can be done with the same setup , but it comes as column vector instead of a row vector
# So maybe use adjoint(vcat(stress...))
function doassemble(cellvalues_u, facevalues_u, grid, dh, stress_predicted)
nu = getnbasefunctions(cellvalues_u)
total_residual = zeros(Float64, ndofs(dh)) # total residual vector
for cell in CellIterator(dh)
re = zeros(Float64, nu) # element residual vector
assemble_linear_elasticity!(re, cell, cellvalues_u, facevalues_u, grid, stress_predicted)
# Add element residual to total_residual using a loop
cell_dofs = celldofs(cell)
for i in 1:nu
total_residual[cell_dofs[i]] += re[i]
end
end
return total_residual
end
function assemble_linear_elasticity!(re, cell, cellvalues_u, facevalues_u, grid, stress_predicted)
n_basefuncs_u = getnbasefunctions(cellvalues_u)
reinit!(cellvalues_u, cell)
for q_point in 1:getnquadpoints(cellvalues_u)
dΩ = getdetJdV(cellvalues_u, q_point)
for i in 1:n_basefuncs_u
∇δN = shape_symmetric_gradient(cellvalues_u, q_point, i)
for dim in eachindex(∇δN)
re[dim] += ∇δN[dim] * stress_predicted[i] * dΩ
end
end
end
for face in 1:nfaces(cell)
if onboundary(cell, face) && (cellid(cell), face) ∈ getfaceset(grid, "traction")
reinit!(facevalues_u, cell, face)
for q_point in 1:getnquadpoints(facevalues_u)
dΓ = getdetJdV(facevalues_u, q_point)
for i in 1:n_basefuncs_u
δu = shape_value(facevalues_u, q_point, i)
for dim in eachindex(δu)
prescribed_traction = 0.01
re[dim] += - δu[dim] * prescribed_traction * dΓ
#re[dim] += - δu[dim] * stress_predicted[i] * dΓ
end
end
end
end
end
end
function create_ANN(input_dim)
model = Chain(
Dense(input_dim, 64, tanh),
Dense(64, input_dim)
)
return model
end
using Flux
using Plots
function residual_ANN(num_elem, bar_length; epochs=10, lr=0.01)
grid = create_1d_grid(num_elem, bar_length)
interpolation = Lagrange{1, RefCube, 1}()
dh = create_dofhandler(grid, interpolation)
dbc = create_bc(dh)
linear = Lagrange{1, RefCube, 1}()
cellvalues_u, facevalues_u, qr = create_values(linear)
# Convert to a single array
input_data = vcat([vec(mat) for mat in strains]...)
model = create_ANN(size(input_data)[1])
# Define the residual function
function residual(x)
return doassemble(cellvalues_u, facevalues_u, grid, dh, model(x))
end
# Loss function: mean of sum of squares of the residual
function loss(x)
R = residual(x)
return mean(R.^2)
end
# Optimizer (Adam)
opt = ADAM(lr)
# Training loop
for epoch in 1:epochs
Flux.train!(loss, Flux.params(model), [(input_data,)], opt)
print("Epoch: ", epoch)
print(" ")
print("Loss: ", loss(input_data))
println("--------------------------------")
end
stress_predicted = model(input_data)
stress_target = vcat([vec(mat) for mat in stress]...)
# Create a plot
plot(input_data, stress_predicted, label="Predicted Stress", xlabel="Strain", ylabel="Stress", linewidth=2)
plot!(input_data, stress_target, label="Target Stress", linestyle=:dash, linewidth=2)
# Set the output folder path
output_folder = "/Users/gagankaushikmanyam/Desktop/Projects/DFG-KIStoff/RESULTS/UniAxial/1D"
# Create the folder if it doesn't exist
if !isdir(output_folder)
mkdir(output_folder)
end
# Set the full path for saving the PNG file
output_file = joinpath(output_folder, "plot_result.png")
# Save the plot as a PNG file
savefig(output_file)
println("Plot saved at: $output_file")
end
residual_ANN(num_elem, bar_length; epochs=10, lr=0.01)