How to rewrite a JAX model with stacked parameters

It’s rather difficult to follow what’s happening in compute_price, but it looks to me like you’re instantiating 49 copies of the model parameters (one per date?). The Flux equivalent would be to create 49 copies of the same model and pass them in. Instead of scan, you could calculate all the value_if_waits in one go with map or an array comprehension, use accumulate to calculate the discounted_future_cashflows (presumably the only step where each date depends on the next), and then use a final map, comprehension, or broadcast to calculate the MSEs.
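To make that restructuring concrete, here’s a minimal sketch in plain Python (3.8+), with list comprehensions and itertools.accumulate standing in for Julia’s map/comprehensions and accumulate. It assumes compute_price is a Longstaff-Schwartz-style early-exercise pricer with one small linear model per date. The data, the model function, DISCOUNT, the exercise rule, and the per-date MSE are all hypothetical and not taken from your code.

```python
from itertools import accumulate

DISCOUNT = 0.95  # hypothetical one-period discount factor

# Hypothetical data: 3 exercise dates (outer list) x 2 simulated paths (inner list).
exercise = [[1.0, 0.0], [2.0, 0.5], [1.5, 1.0]]  # payoff if exercised at that date
features = [[1.0, 0.5], [2.0, 1.0], [1.5, 2.0]]  # regression input at that date
# One (weight, bias) pair per date: the "stacked" parameters, one model per date.
params = [(0.5, 0.1), (0.8, 0.0), (0.0, 0.0)]


def model(p, xs):
    """Hypothetical per-date linear model, applied to every path at that date."""
    w, b = p
    return [w * x + b for x in xs]


# Step 1: every date's value_if_wait in one go; no date depends on another.
value_if_wait = [model(p, xs) for p, xs in zip(params, features)]


# Step 2: the sequential part, done as a backward accumulate that starts
# from zero cashflow after maturity.
def step(future_cf, date):
    ex, wait = date
    # Exercise when the immediate payoff beats the estimated continuation value;
    # otherwise carry the next date's cashflow back one period, discounted.
    return [e if e > w else DISCOUNT * f for e, w, f in zip(ex, wait, future_cf)]


n_paths = len(exercise[0])
backward = list(
    accumulate(
        zip(reversed(exercise), reversed(value_if_wait)),
        step,
        initial=[0.0] * n_paths,
    )
)
# backward = [after-maturity zeros, last date, ..., first date]:
# flip to date order and drop the after-maturity entry.
discounted_future_cashflows = backward[:0:-1]

# Step 3: one MSE per date, comparing value_if_wait with the discounted
# cashflow realised from the following date (zero after the last date).
next_cashflows = discounted_future_cashflows[1:] + [[0.0] * n_paths]
mses = [
    sum((w - DISCOUNT * c) ** 2 for w, c in zip(wait, nxt)) / n_paths
    for wait, nxt in zip(value_if_wait, next_cashflows)
]
print(mses)
```

The accumulate walks the dates in reverse because each date’s cashflow depends on the one after it; steps 1 and 3 touch each date independently, so they can stay as plain maps.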