# How to rewrite a JAX model with stacked parameters

Hi, I have been using Julia for a while but am new to the Julia ML ecosystem. I want to rewrite a least-squares Monte Carlo (LSMC) model with neural nets using `Lux` or `Flux`.

Here is the main code, written with Python's `jax`. It generates the price paths for an American option, computes the loss for each exercise date, and finally finds the Θ that best predicts the price at every exercise date:

```python
import jax.numpy as jnp
import jax
import haiku as hk
import optax

lr = 1e-4
optimizer = optax.adam  # assumption: the post does not show which optax optimizer is used

Spot = jnp.array([38, 36, 35])   # stock price
σ = jnp.array([0.2, .25, .3])     # stock volatility
K = 40      # strike price
r = 0.06    # risk free rate
n = 100000  # number of simulated paths
batch_size = 512
m = 50      # number of exercise dates
T = 1       # maturity
Δt = T / m  # interval between two exercise dates
n_stocks = 3

# simulates one step of the stock price evolution
def step(S, rng):
    ϵ = jax.random.normal(rng, S.shape)
    dB = jnp.sqrt(Δt) * ϵ
    S = S + r * S * Δt + σ * S * dB
    return S, S  # returns (carry, output), as jax.lax.scan expects

def payoff_put(S):
    return jnp.maximum(K - S, 0.)

def model(Si):
    # a simple nn model
    Si = jnp.column_stack([Si])
    out = (Si - 36.) / 5
    out = hk.Linear(64)(out)
    out = jax.nn.relu(out)
    out = hk.Linear(64)(out)
    out = jax.nn.relu(out)
    out = hk.Linear(1)(out)
    out = jnp.squeeze(out)
    return out

init, model = hk.without_apply_rng(hk.transform(model))

rng = jax.random.PRNGKey(0)
Θ = init(rng, jnp.ones([batch_size, n_stocks]))
# Θ = init(rng, jnp.ones((batch_size, 1)))  # single-stock variant; it would overwrite the line above

def stack(Θ):
    return jnp.stack([Θ] * (m - 1))  # one copy per exercise date before maturity (49)

Θ = jax.tree_util.tree_map(stack, Θ)  # jax.tree_map in older JAX versions
opt_state = optimizer(lr).init(Θ)

# LSMC algorithm
def compute_price(Θ, batch_size, rng):
    S = jnp.column_stack([jnp.ones(batch_size) * Spot[i] for i in range(n_stocks)])  # (batch_size, 3)
    rng_vector = jax.random.split(rng, m)     # (m, 2): one key per time step
    _, S = jax.lax.scan(step, S, rng_vector)  # (m, batch_size, 3)

    discount = jnp.exp(-r * Δt)

    # Very last date
    value_if_exercise = payoff_put(S[-1])
    discounted_future_cashflows = value_if_exercise * discount

    def core(state, input):
        discounted_future_cashflows = state
        Si, Θi = input
        Y = discounted_future_cashflows
        value_if_wait = model(Θi, Si)

        mse = jnp.mean((value_if_wait - discounted_future_cashflows)**2)
        value_if_exercise = payoff_put(Si)
        exercise = value_if_exercise >= value_if_wait
        discounted_future_cashflows = discount * jnp.where(
            exercise,
            value_if_exercise,
            discounted_future_cashflows)

        return discounted_future_cashflows, mse

    # Proceed recursively
    S = jnp.flip(S, 0)[1:]
    # inputs is a tuple of a jax ndarray and a nested dictionary; jax supports both
    inputs = S, Θ
    discounted_future_cashflows, mse = jax.lax.scan(core, discounted_future_cashflows, inputs)
    return discounted_future_cashflows.mean(), mse
```

My main question is about setting up the parameters and the optimizer state (the code below). The model itself could be defined with `Lux.Chain` or `Flux.Chain`. (I don't know how to do this in PyTorch either.)

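```python
Θ = init(rng, jnp.ones([batch_size, n_stocks]))
# Θ = init(rng, jnp.ones((batch_size, 1)))  # single-stock variant; it would overwrite the line above

def stack(Θ):
    return jnp.stack([Θ] * 49)  # one copy per exercise date before maturity

Θ = jax.tree_util.tree_map(stack, Θ)  # jax.tree_map in older JAX versions
opt_state = optimizer(lr).init(Θ)
```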

PS: if you need the full code, I could leave a colab link here.

Really appreciate the community!

In Julia, most dense layers are written as `WX+b`, which means if `X` is `k` by `n`, where `n` is the batch size, then it will be batch processed by the neural network and come out the other end. So if you just shape your input so that the batch values are in the 2nd dimension of your input, you should be able to achieve the same effect. E.g.

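```julia
Spot = Float32[38, 36, 35]          # spot prices from the original post
batch_size = 512
S = zeros(Float32, 3, batch_size)   # features × batch
for i in 1:3
    S[i, :] .= Spot[i]              # row i holds stock i for every path in the batch
end
```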

Yeah, I see. But the problem is that I cannot do one training step in just one matrix multiplication.
Here is the Python equivalent of the `jax.lax.scan` function:

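```python
import numpy as np

def scan(f, init, xs, length=None):
    if xs is None:
        xs = [None] * length
    carry = init          # state threaded from one step to the next
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)      # per-step outputs are collected and stacked
    return carry, np.stack(ys)
```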

The `discounted_future_cashflows` value is carried to the next exercise date, and a different Θ should be trained for each date. In the end, the first linear layer's weights in Θ should have shape (49, 3, 64), where 49 is the number of exercise dates that need their own parameters (m - 1).
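To make the carry-plus-per-date-parameters idea concrete, here is a minimal pure-Python sketch. It is not the JAX implementation: `scan_pairs`, `toy_core`, the one-parameter "model" `theta * s`, the single price path, and the exaggerated discount of 0.5 are all made up for illustration. It only mimics how `jax.lax.scan` slices both `S` and the stacked Θ along their leading axis and hands one `(Si, Θi)` pair to each step.

```python
K = 40.0         # strike price
DISCOUNT = 0.5   # exaggerated per-step discount so the numbers stay readable

def scan_pairs(f, init, xs_a, xs_b):
    """Stand-in for scanning over a tuple of two stacked inputs."""
    carry = init
    ys = []
    for a, b in zip(xs_a, xs_b):   # one slice of each input per step
        carry, y = f(carry, (a, b))
        ys.append(y)
    return carry, ys

def toy_core(cashflow, inputs):
    s_i, theta_i = inputs
    value_if_wait = theta_i * s_i              # hypothetical one-parameter model
    sq_err = (value_if_wait - cashflow) ** 2   # regression target is the carried cashflow
    value_if_exercise = max(K - s_i, 0.0)
    chosen = value_if_exercise if value_if_exercise >= value_if_wait else cashflow
    return DISCOUNT * chosen, sq_err

prices = [36.0, 30.0]     # one path, latest exercise date first
thetas = [0.25, 0.125]    # one parameter per exercise date
terminal = DISCOUNT * max(K - 38.0, 0.0)   # discounted payoff at maturity

final_cashflow, sq_errs = scan_pairs(toy_core, terminal, prices, thetas)
print(final_cashflow, sq_errs)
```

Each date's parameter only ever sees its own slice of prices, while the cashflow is the one thing passed along from step to step.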

Stacking can be done in pure Julia, and I think the equivalent of `stack` in Julia is:

```julia
function stack(a)
    return mapreduce(vcat, a) do a_i
        reshape(a_i, 1, size(a_i)...)
    end
end
```

I’m not certain that autograd will work through this transformation but it’s worth checking.

It’s rather difficult to follow what’s happening in `compute_price`, but to me it looks like you’re instantiating 49 copies of the model parameters (1 for each date?). The equivalent in Flux would be to create 49 copies of the same model and pass them in. Instead of `scan`, you could calculate all the `value_if_wait`s in one go with `map` or an array comprehension, use `accumulate` to calculate the `discounted_future_cashflows`, and then a final `map`/comprehension/broadcast to calculate the MSEs.
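As a rough illustration of that decomposition, here is a pure-Python sketch using a list comprehension and `itertools.accumulate`. The one-parameter `toy_model`, the single price path, and the discount of 0.5 are assumptions made up for the example; a real version would call one Flux or Lux model per date. The split works because `value_if_wait` depends only on that date's parameters and prices, not on the carried cashflow.

```python
from itertools import accumulate

K = 40.0         # strike price
DISCOUNT = 0.5   # exaggerated per-step discount so the numbers stay readable

def payoff_put(s):
    return max(K - s, 0.0)

def toy_model(theta, s):
    # Hypothetical one-parameter stand-in for the neural network.
    return theta * s

def update_cashflow(cashflow, step_inputs):
    s, value_if_wait = step_inputs
    value_if_exercise = payoff_put(s)
    chosen = value_if_exercise if value_if_exercise >= value_if_wait else cashflow
    return DISCOUNT * chosen

prices = [36.0, 30.0]     # one path, latest exercise date first
thetas = [0.25, 0.125]    # one parameter per exercise date
terminal = DISCOUNT * payoff_put(38.0)   # discounted payoff at maturity

# 1. All continuation values in one go (a map over the dates).
waits = [toy_model(theta, s) for theta, s in zip(thetas, prices)]

# 2. Running cashflows; the first entry is the terminal value.
cashflows = list(accumulate(zip(prices, waits), update_cashflow, initial=terminal))

# 3. Squared errors against the cashflow each date started from
#    (zip stops before the final cashflow, which no model is fitted to).
sq_errs = [(w - c) ** 2 for w, c in zip(waits, cashflows)]
print(cashflows[-1], sq_errs)
```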

Yeah, that is the only way I could come up with. I just thought since the design of `Lux` is similar to `jax` maybe Lux could handle it in a similar way.

Anyway, thanks for the help, I will try to implement it!

Lux is somewhere in between: you define 49 copies of the parameters and just one of the model, then call the model on each set of parameters in the inner loop. The important thing to note here is that you do not have to stack in either Julia or Python libraries. JAX/Optax should be smart enough to traverse the nested list of param copies without stacking, and Optimisers.jl (which works with both Flux and Lux) can do the same.
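To show what "traversing the nested list of param copies" means, here is a hand-rolled pure-Python sketch. It is not the Optax or Optimisers.jl API: `tree_map` below is a hypothetical minimal helper, the parameters are plain floats, and the gradients are invented numbers. It applies a plain gradient-descent update to a list of parameter copies, leaf by leaf, without ever stacking them.

```python
import copy

def tree_map(fn, *trees):
    """Apply fn to matching leaves of nested dicts/lists (illustrative helper)."""
    first = trees[0]
    if isinstance(first, dict):
        return {key: tree_map(fn, *(t[key] for t in trees)) for key in first}
    if isinstance(first, list):
        return [tree_map(fn, *children) for children in zip(*trees)]
    return fn(*trees)   # leaf: plain numbers here, arrays in a real model

base = {"linear": {"w": 1.0, "b": 0.5}}
params = [copy.deepcopy(base) for _ in range(3)]   # 3 copies here; 49 in the thread

# Made-up gradients with the same nested structure as params.
grads = [{"linear": {"w": float(i), "b": 1.0}} for i in range(3)]

lr = 0.5
new_params = tree_map(lambda p, g: p - lr * g, params, grads)
print(new_params)
```

The list of copies keeps its structure after the update, so copy `i` can still be handed to the model for date `i` in the inner loop.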
