Hi, I have been using Julia for a while, but I am new to the Julia ML ecosystem. I want to rewrite a least-squares Monte Carlo (LSMC) model that uses neural nets in Julia.
Here is the main code, written in Python with JAX. It generates the paths for an American option, computes the loss at each exercise date, and finally finds the parameters Θ that best predict the price at every exercise date:
```python
import jax
import jax.numpy as jnp
import haiku as hk
import optax

optimizer = optax.adam
lr = 1e-4
Spot = jnp.array([38., 36., 35.])  # initial stock prices
σ = jnp.array([0.2, .25, .3])      # stock volatilities
K = 40            # strike price
r = 0.06          # risk-free rate
n = 100000        # number of simulated paths
batch_size = 512
m = 50            # number of exercise dates
T = 1             # maturity
Δt = T / m        # interval between two exercise dates
n_stocks = 3

# simulates one step of the stock price evolution
def step(S, rng):
    ϵ = jax.random.normal(rng, S.shape)
    dB = jnp.sqrt(Δt) * ϵ
    S = S + r * S * Δt + σ * S * dB
    return S, S  # (carry, output): the pair jax.lax.scan expects

def payoff_put(S):
    return jnp.maximum(K - S, 0.)

def forward(Si):  # a simple NN model
    # assumption: the 3 stocks are 3 independent puts, so each price is fed
    # to the net on its own: (batch_size, 3) -> (batch_size * 3, 1)
    out = (Si.reshape(-1, 1) - 36.) / 5
    out = hk.Linear(64)(out)
    out = jax.nn.relu(out)
    out = hk.Linear(64)(out)
    out = jax.nn.relu(out)
    out = hk.Linear(1)(out)
    return out.reshape(Si.shape)  # back to (batch_size, 3)

init, model = hk.without_apply_rng(hk.transform(forward))
rng = jax.random.PRNGKey(0)
Θ = init(rng, jnp.ones((batch_size, 1)))  # one scalar input per price

def stack(Θ):
    # one copy of the weights per exercise date, except the last one
    return jnp.stack([Θ] * (m - 1))

Θ = jax.tree_util.tree_map(stack, Θ)
opt_state = optimizer(lr).init(Θ)

# LSMC algorithm
def compute_price(Θ, batch_size, rng):
    S = jnp.column_stack([jnp.ones(batch_size) * Spot[i] for i in range(n_stocks)])  # (batch_size, 3)
    rng_vector = jax.random.split(rng, m)  # (m, 2): one PRNG key per date
    _, S = jax.lax.scan(step, S, rng_vector)  # (m, batch_size, 3)
    discount = jnp.exp(-r * Δt)
    # Very last date
    value_if_exercise = payoff_put(S[-1])
    discounted_future_cashflows = value_if_exercise * discount

    def core(state, input):
        discounted_future_cashflows = state
        Si, Θi = input
        value_if_wait = model(Θi, Si)
        mse = jnp.mean((value_if_wait - discounted_future_cashflows) ** 2)
        value_if_exercise = payoff_put(Si)
        exercise = value_if_exercise >= value_if_wait
        discounted_future_cashflows = discount * jnp.where(
            exercise, value_if_exercise, discounted_future_cashflows)
        return discounted_future_cashflows, mse

    # Proceed recursively, from date m - 1 back to date 1
    S = jnp.flip(S, 0)[1:]
    # inputs is a tuple of a jax array and a nested dict; lax.scan slices both along axis 0
    inputs = S, Θ
    discounted_future_cashflows, mse = jax.lax.scan(core, discounted_future_cashflows, inputs)
    return discounted_future_cashflows.mean(axis=0), mse  # one price per stock, per-date losses
```
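The code above stops at `compute_price`, so it may help to see how the stacked Θ and the per-date losses fit into one training step. Below is a minimal, dependency-free sketch under several assumptions that are not in the original post. It replaces Haiku with a hypothetical plain-JAX dict MLP (`init_mlp`, `mlp`), treats the three stocks as three independent puts, sums the 49 per-date MSEs into one loss, and uses a plain gradient step instead of `optax.adam`.

```python
import jax
import jax.numpy as jnp

# Constants from the post
Spot = jnp.array([38.0, 36.0, 35.0])
sigma = jnp.array([0.2, 0.25, 0.3])
K, r, T, m = 40.0, 0.06, 1.0, 50
dt = T / m
discount = jnp.exp(-r * dt)
BATCH, LR = 512, 1e-4  # assumption: plain gradient step, not Adam

def payoff_put(S):
    return jnp.maximum(K - S, 0.0)

def init_mlp(key, n_hidden=64):
    # Hypothetical plain-JAX stand-in for the Haiku 1-64-64-1 MLP
    k1, k2, k3 = jax.random.split(key, 3)
    return {
        "w1": jax.random.normal(k1, (1, n_hidden)), "b1": jnp.zeros(n_hidden),
        "w2": jax.random.normal(k2, (n_hidden, n_hidden)) / jnp.sqrt(n_hidden),
        "b2": jnp.zeros(n_hidden),
        "w3": jax.random.normal(k3, (n_hidden, 1)) / jnp.sqrt(n_hidden),
        "b3": jnp.zeros(1),
    }

def mlp(theta, S):
    # Each price goes through the net on its own: (batch, 3) -> (batch, 3)
    x = (S.reshape(-1, 1) - 36.0) / 5.0
    x = jax.nn.relu(x @ theta["w1"] + theta["b1"])
    x = jax.nn.relu(x @ theta["w2"] + theta["b2"])
    return (x @ theta["w3"] + theta["b3"]).reshape(S.shape)

def simulate(key, batch_size):
    # Euler steps of dS = r S dt + sigma S dB, one PRNG key per date
    def step(S, k):
        S = S + r * S * dt + sigma * S * jnp.sqrt(dt) * jax.random.normal(k, S.shape)
        return S, S
    S0 = jnp.broadcast_to(Spot, (batch_size, Spot.shape[0]))
    _, S = jax.lax.scan(step, S0, jax.random.split(key, m))
    return S  # (m, batch_size, 3)

def loss_and_price(theta, key, batch_size):
    S = simulate(key, batch_size)
    cash = payoff_put(S[-1]) * discount  # last date, discounted one step

    def core(cash, inputs):
        Si, theta_i = inputs  # theta_i: one date's slice of every leaf
        wait = mlp(theta_i, Si)
        mse = jnp.mean((wait - cash) ** 2)
        exercise_value = payoff_put(Si)
        cash = discount * jnp.where(exercise_value >= wait, exercise_value, cash)
        return cash, mse

    # Dates m-1 down to 1, paired with the 49 stacked parameter sets
    cash, mse = jax.lax.scan(core, cash, (jnp.flip(S, 0)[1:], theta))
    return mse.sum(), cash.mean(axis=0)

@jax.jit
def train_step(theta, key):
    (loss, price), grads = jax.value_and_grad(loss_and_price, has_aux=True)(
        theta, key, BATCH)
    theta = jax.tree_util.tree_map(lambda p, g: p - LR * g, theta, grads)
    return theta, loss, price

key = jax.random.PRNGKey(0)
key, init_key = jax.random.split(key)
theta = jax.vmap(init_mlp)(jax.random.split(init_key, m - 1))
for _ in range(3):
    key, sub = jax.random.split(key)
    theta, loss, price = train_step(theta, sub)
```

Θ enters the cash flows only through the boolean exercise decision, so the regression targets carry no gradient. Each `theta_i` is trained only by its own date's MSE. With optax, the update line would instead be `updates, opt_state = optimizer(lr).update(grads, opt_state)` followed by `theta = optax.apply_updates(theta, updates)`.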
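My main question is about setting up the parameters and the optimizer state (the code below). I know the network itself could be defined with `Flux.Chain`, but I don't know how to reproduce the stacked parameters (one copy of the weights per exercise date) and the matching optimizer state. (I don't know how to do it in PyTorch either.)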
```python
Θ = init(rng, jnp.ones((batch_size, 1)))

def stack(Θ):
    return jnp.stack([Θ] * (m - 1))  # m - 1 = 49 copies, one per exercise date

Θ = jax.tree_util.tree_map(stack, Θ)
opt_state = optimizer(lr).init(Θ)
```
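For porting, it may help to see exactly what this snippet builds. The sketch below uses a hypothetical plain-JAX dict MLP (`init_mlp`) as a stand-in for the Haiku `init`. Θ is a pytree, and `stack` gives every leaf a leading axis of length m - 1 = 49, all copies equal to the same initialization. The `jax.vmap` variant, which gives each date its own random initialization, is an alternative the post does not use. An Adam-style optimizer state is just more pytrees with the same shapes as Θ plus a step counter; this is assumed to mirror optax's `ScaleByAdamState` fields `count`, `mu` and `nu`.

```python
import jax
import jax.numpy as jnp

m = 50  # number of exercise dates, as in the post

def init_mlp(key, n_hidden=64):
    # Hypothetical plain-JAX stand-in for the Haiku init of the 1-64-64-1 MLP
    k1, k2, k3 = jax.random.split(key, 3)
    return {
        "w1": jax.random.normal(k1, (1, n_hidden)),
        "b1": jnp.zeros(n_hidden),
        "w2": jax.random.normal(k2, (n_hidden, n_hidden)) / jnp.sqrt(n_hidden),
        "b2": jnp.zeros(n_hidden),
        "w3": jax.random.normal(k3, (n_hidden, 1)) / jnp.sqrt(n_hidden),
        "b3": jnp.zeros(1),
    }

key = jax.random.PRNGKey(0)

# (a) What the post does: initialise once, then copy along a new leading axis
theta_copied = jax.tree_util.tree_map(
    lambda p: jnp.stack([p] * (m - 1)), init_mlp(key))

# (b) Alternative: m - 1 independently initialised networks, one per date
theta_indep = jax.vmap(init_mlp)(jax.random.split(key, m - 1))

# Adam-style state: first/second moment buffers shaped like the parameters
opt_state = {
    "mu": jax.tree_util.tree_map(jnp.zeros_like, theta_indep),
    "nu": jax.tree_util.tree_map(jnp.zeros_like, theta_indep),
    "count": jnp.zeros((), jnp.int32),
}
```

Either way, the port needs the same two pieces: parameters that can be indexed by exercise date, and optimizer buffers that match them leaf for leaf.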
PS: If you need the full code, I can share a Colab link here.
Really appreciate the community!