Hi, I have been using Julia for a while but am new to the Julia ML ecosystem. I want to rewrite a least-squares Monte Carlo (LSMC) model with neural nets using `Lux` or `Flux`.

Here is the main code, written in Python with `jax`. It simulates paths for an American option, computes the loss at each exercise date, and finally finds the best Θ that predicts the price well at every exercise date:

```python
import jax
import jax.numpy as jnp
import haiku as hk
import optax

optimizer = optax.adam
lr = 1e-4
Spot = jnp.array([38, 36, 35])  # initial stock prices
σ = jnp.array([0.2, .25, .3])  # stock volatilities
K = 40  # strike price
r = 0.06  # risk-free rate
n = 100000  # number of simulated paths
batch_size = 512
m = 50  # number of exercise dates
T = 1  # maturity
Δt = T / m  # interval between two exercise dates
n_stocks = 3

# Simulates one Euler step of the stock price evolution
def step(S, rng):
    ϵ = jax.random.normal(rng, S.shape)
    dB = jnp.sqrt(Δt) * ϵ
    S = S + r * S * Δt + σ * S * dB
    return S, S  # lax.scan expects (new_carry, output_to_stack)

# Put payoff, applied elementwise: with n_stocks > 1 it returns one payoff per stock,
# shape (batch_size, n_stocks), which does not broadcast against the network's
# (batch_size,) output in `core` unless it is reduced to one payoff per path
def payoff_put(S):
    return jnp.maximum(K - S, 0.)

def model(Si):
    # a simple MLP: normalize, two hidden ReLU layers, scalar output
    Si = jnp.column_stack([Si])
    out = (Si - 36.) / 5
    out = hk.Linear(64)(out)
    out = jax.nn.relu(out)
    out = hk.Linear(64)(out)
    out = jax.nn.relu(out)
    out = hk.Linear(1)(out)
    out = jnp.squeeze(out)
    return out

init, model = hk.without_apply_rng(hk.transform(model))
rng = jax.random.PRNGKey(0)
Θ = init(rng, jnp.ones([batch_size, n_stocks]))
# The second call overwrites the Θ above: the first layer then expects 1 input feature, not n_stocks
Θ = init(rng, jnp.ones((batch_size, 1)))

def stack(Θ):
    return jnp.stack([Θ] * (m - 1))  # one copy per exercise date before maturity (49)

Θ = jax.tree_util.tree_map(stack, Θ)
opt_state = optimizer(lr).init(Θ)

# LSMC algorithm
def compute_price(Θ, batch_size, rng):
    S = jnp.column_stack([jnp.ones(batch_size) * Spot[i] for i in range(n_stocks)])  # (batch_size, n_stocks)
    rng_vector = jax.random.split(rng, m)  # (m, 2): one PRNG key per step
    _, S = jax.lax.scan(step, S, rng_vector)  # (m, batch_size, n_stocks)
    discount = jnp.exp(-r * Δt)

    # Very last date: exercise whenever the payoff is positive
    value_if_exercise = payoff_put(S[-1])
    discounted_future_cashflows = value_if_exercise * discount

    def core(state, input):
        discounted_future_cashflows = state
        Si, Θi = input  # prices and network parameters for one exercise date
        value_if_wait = model(Θi, Si)  # continuation value predicted by the network
        mse = jnp.mean((value_if_wait - discounted_future_cashflows)**2)
        value_if_exercise = payoff_put(Si)
        exercise = value_if_exercise >= value_if_wait
        discounted_future_cashflows = discount * jnp.where(
            exercise,
            value_if_exercise,
            discounted_future_cashflows)
        return discounted_future_cashflows, mse

    # Proceed recursively, backwards from date m - 1 to date 1
    S = jnp.flip(S, 0)[1:]
    inputs = S, Θ
    # inputs is a tuple of a jax ndarray and a nested dict of parameters; lax.scan accepts any such pytree
    discounted_future_cashflows, mse = jax.lax.scan(core, discounted_future_cashflows, inputs)
    return discounted_future_cashflows.mean(), mse
```
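For reference, the backward recursion in `compute_price` can be written out as follows. The notation is assumed, not taken from the code: $t_i = i\,\Delta t$, $g$ is the put payoff, $f_{\Theta_i}$ is the network with the parameters for date $t_i$, $Y_i^{(b)}$ is the cashflow of path $b$ discounted to $t_i$, and $B$ is the batch size.

$$
Y^{(b)}_{m-1} = e^{-r\Delta t}\, g\big(S^{(b)}_{t_m}\big), \qquad g(S) = \max(K - S,\, 0)
$$

For $i = m-1, \dots, 1$:

$$
\mathcal{L}_i(\Theta_i) = \frac{1}{B}\sum_{b=1}^{B}\Big(f_{\Theta_i}\big(S^{(b)}_{t_i}\big) - Y^{(b)}_i\Big)^2
$$

$$
Y^{(b)}_{i-1} = e^{-r\Delta t}\begin{cases} g\big(S^{(b)}_{t_i}\big) & \text{if } g\big(S^{(b)}_{t_i}\big) \ge f_{\Theta_i}\big(S^{(b)}_{t_i}\big),\\ Y^{(b)}_i & \text{otherwise,}\end{cases}
\qquad
\hat{V}_0 = \frac{1}{B}\sum_{b=1}^{B} Y^{(b)}_0
$$

`lax.scan` stacks the losses in backward order, so `mse[0]` is $\mathcal{L}_{m-1}$ and `mse[-1]` is $\mathcal{L}_1$; the first returned value is $\hat{V}_0$.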

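My main question is about setting up the parameters and the optimizer state (the code below). The model itself could be defined with `Lux.Chain` or `Flux.Chain`, but I don't know how to create one parameter set per exercise date and initialize the optimizer state on it. (I don't know how to do it in PyTorch either.)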

```python
Θ = init(rng, jnp.ones([batch_size, n_stocks]))
Θ = init(rng, jnp.ones((batch_size, 1)))
def stack(Θ):
    return jnp.stack([Θ] * 49)
Θ = jax.tree_map(stack, Θ)
opt_state = optimizer(lr).init(Θ)
```
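To make the question concrete, here is a small NumPy-only sketch of what the stacking step produces and how `lax.scan` consumes it. Assumptions: `init_params` builds a hypothetical stand-in for the nested dict that haiku's `init` returns (key names are illustrative), `tree_map` is a minimal hand-written replacement for `jax.tree_util.tree_map`, `params_at` is a hypothetical helper, and the explicit loop stands in for `lax.scan`.

```python
import numpy as np

m = 50  # number of exercise dates (as in the question)
n_stocks = 3
hidden = 64

# Hypothetical stand-in for haiku's parameter dict: one entry per Linear layer,
# each holding a weight "w" and a bias "b" (key names are illustrative only).
def init_params(rng, n_in, hidden):
    sizes = [n_in, hidden, hidden, 1]
    params = {}
    for k, (a, b) in enumerate(zip(sizes[:-1], sizes[1:])):
        params[f"linear_{k}"] = {
            "w": rng.normal(scale=1 / np.sqrt(a), size=(a, b)),
            "b": np.zeros(b),
        }
    return params

def tree_map(f, tree):
    """Minimal tree_map for nested dicts of arrays (mimics jax.tree_util.tree_map)."""
    if isinstance(tree, dict):
        return {k: tree_map(f, v) for k, v in tree.items()}
    return f(tree)

def stack(leaf, n=m - 1):
    # Same idea as jnp.stack([Θ] * 49): add a leading axis of length n = m - 1.
    return np.stack([leaf] * n)

def mlp(params, x):
    # Forward pass shaped like the haiku model: normalize, two ReLU layers, linear output.
    out = (x - 36.0) / 5
    out = np.maximum(out @ params["linear_0"]["w"] + params["linear_0"]["b"], 0)
    out = np.maximum(out @ params["linear_1"]["w"] + params["linear_1"]["b"], 0)
    out = out @ params["linear_2"]["w"] + params["linear_2"]["b"]
    return out.squeeze(-1)

def params_at(Θ_stacked, j):
    # Hypothetical helper: the parameter set for the j-th scan step (slice j of every leaf).
    return tree_map(lambda leaf: leaf[j], Θ_stacked)

rng = np.random.default_rng(0)
Θ = tree_map(stack, init_params(rng, n_stocks, hidden))

# lax.scan(core, carry, (S, Θ)) walks the leading axis of both inputs together;
# this loop makes the pairing explicit. Random prices stand in for the flipped,
# simulated paths S in compute_price.
batch_size = 4
S_dates = rng.uniform(30, 45, size=(m - 1, batch_size, n_stocks))
preds = [mlp(params_at(Θ, j), S_dates[j]) for j in range(m - 1)]
```

Every leaf gains a leading axis of length `m - 1`, so all 49 dates start from identical weights, and step `j` reads slice `j` of every leaf. A Lux port could presumably hold the same information as a vector of `m - 1` parameter sets from `Lux.setup`; that is an assumption, not something this sketch tests.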

PS: If you need the full code, I can post a Colab link here.

Really appreciate the community!