Try changing that line to `tstate = main(tstate, vjp_rule, (reshape(x, 1, 50), reshape(y, 1, 50)), 10000)`.
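Here is a minimal sketch of what the `reshape` calls do. It assumes that `x` and `y` are plain length-50 vectors and that the model follows the usual Julia ML convention: the first dimension holds features and the last holds the batch. The data below is hypothetical and only illustrates the shapes.

```julia
# Hypothetical data (assumption): 50 scalar samples stored as plain vectors.
x = collect(range(-2.0f0, 2.0f0; length = 50))
y = x .^ 2

# A Vector has size (50,). A layer expecting (features, batch) input would
# not read this as 50 samples of 1 feature each.
println(size(x))   # (50,)

# Reshape to 1 feature × 50 samples. For an Array, reshape shares the same
# underlying data instead of copying it.
xb = reshape(x, 1, 50)
yb = reshape(y, 1, 50)
println(size(xb))  # (1, 50)

# The values are unchanged. Only the shape differs.
@assert vec(xb) == x
```

Passing `(xb, yb)` instead of `(x, y)` gives the training loop one row per feature and one column per sample.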