How to rewrite a JAX model with stacked parameters

Hi, I have been using Julia for a while but am new to the Julia ML ecosystem. I want to rewrite a least squares Monte Carlo model with neural nets using Lux or Flux.

Here is the main code, written in Python with JAX. Basically, it generates the paths for an American option, calculates the loss for each exercise day, and finally finds the best Θ that predicts the price well on every exercise day:

import jax.numpy as jnp
import jax
import haiku as hk
import optax

optimizer = optax.adam
lr = 1e-4

Spot = jnp.array([38, 36, 35])   # stock price
σ = jnp.array([0.2, .25, .3])     # stock volatility
K = 40      # strike price
r = 0.06    # risk free rate
n = 100000  # number of simulated paths
batch_size = 512
m = 50      # number of exercise dates
T = 1       # maturity
Δt = T / m  # interval between two exercise dates
n_stocks = 3
# simulates one step of the stock price evolution
def step(S, rng):
    ϵ = jax.random.normal(rng, S.shape)
    dB = jnp.sqrt(Δt) * ϵ
    S = S + r * S * Δt + σ * S * dB
    return S, S  # two return values (carry, output) so it can be used with jax.lax.scan

def payoff_put(S):
    return jnp.maximum(K - S, 0.)

def model(Si):
  # a simple nn model
  Si = jnp.column_stack([Si])
  out = (Si - 36.) / 5
  out = hk.Linear(64)(out)
  out = jax.nn.relu(out)
  out = hk.Linear(64)(out)
  out = jax.nn.relu(out)
  out = hk.Linear(1)(out)
  out = jnp.squeeze(out)
  return out

init, model = hk.without_apply_rng(hk.transform(model))

rng = jax.random.PRNGKey(0)
Θ = init(rng, jnp.ones([batch_size, n_stocks]))  # dummy input only fixes the layer shapes
def stack(Θ):
  return jnp.stack([Θ] * 49)  # one copy per exercise date before maturity (m - 1)
Θ = jax.tree_util.tree_map(stack, Θ)
opt_state = optimizer(lr).init(Θ)

# LSMC algorithm
def compute_price(Θ, batch_size, rng):
    S = jnp.column_stack([jnp.ones(batch_size) * Spot[i] for i in range(3)]) # (batch_size, 3)
    rng_vector = jax.random.split(rng, m) # (m, 2): one key per time step
    _, S = jax.lax.scan(step, S, rng_vector) # (m, batch_size, 3)

    discount = jnp.exp(-r * Δt)

    # Very last date
    value_if_exercise = payoff_put(S[-1])
    discounted_future_cashflows = value_if_exercise * discount

    def core(state, input):
        discounted_future_cashflows = state
        Si, Θi = input
        Y = discounted_future_cashflows
        value_if_wait = model(Θi, Si)

        mse = jnp.mean((value_if_wait - discounted_future_cashflows)**2)
        value_if_exercise = payoff_put(Si)
        exercise = value_if_exercise >= value_if_wait
        discounted_future_cashflows = discount * jnp.where(
            exercise, value_if_exercise, discounted_future_cashflows)

        return discounted_future_cashflows, mse

    # Proceed recursively
    S = jnp.flip(S, 0)[1:]
    inputs = S, Θ
    # inputs is a tuple of a JAX array and a nested dict of parameters; scan accepts both as pytrees
    discounted_future_cashflows, mse = jax.lax.scan(core, discounted_future_cashflows, inputs)
    return discounted_future_cashflows.mean(), mse

My main question is about setting up the parameters and the optimizer state (the code below). The model itself could be defined with Lux.Chain or Flux.Chain. (I don't know how to do this in PyTorch either.)

Θ = init(rng, jnp.ones([batch_size, n_stocks]))  # dummy input only fixes the layer shapes
def stack(Θ):
  return jnp.stack([Θ] * 49)  # one copy per exercise date before maturity (m - 1)
Θ = jax.tree_util.tree_map(stack, Θ)
opt_state = optimizer(lr).init(Θ)

PS: if you need the full code, I could leave a colab link here.

Really appreciate the community!

In Julia, most dense layers are written as WX+b, which means if X is k by n, where n is the batch size, then it will be batch processed by the neural network and come out the other end. So if you just shape your input so that the batch values are in the 2nd dimension of your input, you should be able to achieve the same effect. E.g.

S = zeros(Float32, 3, batch_size)
for i in 1:3
    S[i, :] .= Spot[i]  # row i holds stock i; columns index the batch
end
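As a rough illustration of that layout (plain Python with nested lists rather than Julia; the helper name `dense` and the toy weights are made up for this sketch), a dense layer computing W*X + b treats every column of X as one sample:

```python
def dense(W, X, b):
    """Compute W @ X + b on nested lists; X has shape (features, batch)."""
    n_out, n_in, n_batch = len(W), len(X), len(X[0])
    return [
        [sum(W[o][k] * X[k][j] for k in range(n_in)) + b[o] for j in range(n_batch)]
        for o in range(n_out)
    ]

Spot = [38.0, 36.0, 35.0]
batch_size = 4
# Row i holds stock i, columns index the batch: shape (3, batch_size).
S = [[Spot[i]] * batch_size for i in range(3)]

W = [[1.0, 0.0, 0.0], [0.0, 0.5, 0.5]]  # hypothetical 2x3 weight matrix
b = [0.0, 1.0]
out = dense(W, S, b)                     # shape (2, batch_size)
```

Each column of `out` depends only on the matching column of `S`, so the whole batch goes through the layer in one call.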

Yeah, I see. But the problem is that I cannot do the training in just one matrix multiplication.
Here is a Python equivalent of the jax.lax.scan function:

import numpy as np

def scan(f, init, xs, length=None):
  if xs is None:
    xs = [None] * length
  carry = init
  ys = []
  for x in xs:
    carry, y = f(carry, x)
    ys.append(y)  # collect the per-step outputs
  return carry, np.stack(ys)

The discounted_future_cashflows value is carried over to the next exercise day, and a different Θ should be trained for each day. In the end, the shape of Θ for the first linear layer should be (49, 3, 64), where 49 is the number of exercise days that need their own parameters (m - 1).
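To make that target layout concrete, here is a small sketch in plain Python (nested lists instead of JAX arrays; `shape` and `stack` are illustrative helpers written for this example, not library functions). It stacks one first-layer weight matrix 49 times and then picks out the slice a single scan step would see:

```python
import copy

def shape(x):
    """Return the shape of a rectangular nested list."""
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0]
    return tuple(dims)

def stack(w, n):
    """Make n independent copies of w along a new leading axis."""
    return [copy.deepcopy(w) for _ in range(n)]

n_stocks, hidden, m = 3, 64, 50
W1 = [[0.0] * hidden for _ in range(n_stocks)]  # one (3, 64) weight matrix
W1_stacked = stack(W1, m - 1)                   # one copy per exercise date

# Scanning over the leading axis hands step i its own (3, 64) matrix.
W1_step0 = W1_stacked[0]
W1_stacked[0][0][0] = 1.0  # changing date 0 leaves the other dates untouched
```

The copies start out identical but are independent, so training can move each date's parameters separately.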

Stacking can be done in pure Julia, and I think the equivalent of stack in Julia is:

function stack(a)
    return mapreduce(vcat, a) do a_i
        reshape(a_i, 1, size(a_i)...)
    end
end

I’m not certain that autograd will work through this transformation but it’s worth checking.

It’s rather difficult to follow what’s happening in compute_price, but to me it looks like you’re instantiating 49 copies of the model parameters (1 for each date?). The equivalent in Flux would be to create 49 copies of the same model and pass them in. Instead of scan, you could calculate all the value_if_waits in one go with map or an array comprehension, use accumulate to calculate the discounted_future_cashflows, and then a final map/comprehension/broadcast to calculate the MSEs.
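A minimal sketch of that map / accumulate / map structure, written in plain Python so it can be compared with the JAX code above. It assumes a single stock, a toy affine `model` standing in for the network, and made-up prices and parameters; `update` and `mean_sq_err` are hypothetical helper names. The split works here because value_if_wait depends only on the prices and parameters of its own date, not on the carried cashflows:

```python
from itertools import accumulate
from math import exp

K, r, dt = 40.0, 0.06, 1 / 50
discount = exp(-r * dt)

def payoff_put(prices):
    return [max(K - s, 0.0) for s in prices]

def model(theta, prices):
    """Hypothetical stand-in for the network: an affine map of the price."""
    a, b = theta
    return [a + b * s for s in prices]

def update(cashflows, step_inputs):
    """One backward step: exercise where that beats waiting, then discount."""
    prices, value_if_wait = step_inputs
    value_if_exercise = payoff_put(prices)
    return [
        discount * (e if e >= w else c)
        for e, w, c in zip(value_if_exercise, value_if_wait, cashflows)
    ]

def mean_sq_err(pred, target):
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)

# Toy data: 3 dates (latest first, maturity excluded), 2 paths, one theta per date.
S_rev = [[36.0, 41.0], [37.0, 39.0], [38.0, 38.0]]
thetas = [(0.0, 0.0), (3.0, 0.0), (0.0, 0.1)]
terminal = [discount * v for v in payoff_put([35.0, 42.0])]

waits = [model(th, s) for th, s in zip(thetas, S_rev)]                   # map
carries = list(accumulate(zip(S_rev, waits), update, initial=terminal))  # carry per date
mses = [mean_sq_err(w, c) for w, c in zip(waits, carries)]               # zip drops the final carry
price = sum(carries[-1]) / len(carries[-1])
```

`carries[i]` is the cashflow vector seen at date i, so pairing it with `waits[i]` reproduces the per-date loss that the scan version returns.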

Yeah, that is the only way I could come up with. I just thought that since the design of Lux is similar to JAX, maybe Lux could handle it in a similar way.

Anyway, thanks for the help, I would try to implement it! :grinning:

Lux is somewhere in between: you define 49 copies of the parameters and just one of the model, then call the model on each set of parameters in the inner loop. The important thing to note here is that you do not have to stack in either Julia or Python libraries. JAX/Optax should be smart enough to traverse the nested list of param copies without stacking, and Optimisers.jl (which works with both Flux and Lux) can do the same.
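The idea of traversing unstacked parameter copies can be sketched in plain Python. This is only an illustration under stated assumptions: `tree_map` and `sgd_step` are hand-written helpers, the update is plain gradient descent rather than Adam, and none of this is the actual Optax or Optimisers.jl API.

```python
def tree_map(f, *trees):
    """Apply f leaf-wise to nested lists/dicts of floats with identical structure."""
    first = trees[0]
    if isinstance(first, dict):
        return {k: tree_map(f, *(t[k] for t in trees)) for k in first}
    if isinstance(first, list):
        return [tree_map(f, *children) for children in zip(*trees)]
    return f(*trees)

def sgd_step(params, grads, lr):
    """Plain gradient descent applied to every leaf, whatever the nesting."""
    return tree_map(lambda p, g: p - lr * g, params, grads)

n_dates = 49
# One parameter set per exercise date, kept as an ordinary list (no stacking).
params = [{"w": [0.5, -0.25], "b": 0.0} for _ in range(n_dates)]
# Hypothetical gradients with the same nested structure.
grads = [{"w": [1.0, 1.0], "b": float(i)} for i in range(n_dates)]

params = sgd_step(params, grads, lr=0.5)
```

Because the update only needs matching structure between parameters and gradients, the 49 copies can stay in a list and each one is updated independently.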
