Solve Raytrace ODE with events

Hi community:

Forgive my boldness. I am new to Julia, and the reason I am using it is its speed. I come from Python, where I have a raytracing function (below). I need to migrate it to Julia, but I'm not quite sure how to add the events (the stopping conditions at the model boundaries) to the ODE solver. Thanks in advance.

Raytracing can be accomplished by solving the following ODE system:
\begin{cases} \frac{d\overrightarrow{x}}{dl} = \frac{\overrightarrow{p}}{s(\overrightarrow{x})}\\ \frac{d\overrightarrow{p}}{dl} = \nabla s(\overrightarrow{x})\\ \frac{dT}{dl} = s(\overrightarrow{x}) \end{cases} \qquad \overrightarrow{x}(l=0)=\overrightarrow{x_0}, \overrightarrow{p}(l=0)=\overrightarrow{p_0}

where \overrightarrow{x}=(x,z) is the position vector, \overrightarrow{p}=(p_x,p_z) is the ray-parameter vector, s is the slowness, l is the length of the ray, and T is the traveltime.

import pandas as pd
import numpy as np
from scipy.integrate import solve_ivp

These are the right-hand-side function and the boundary events:

def rhsf(l, r, slowness, dsdx, dsdz, xaxis, zaxis, dx, dz):
    """RHS of the raytracing ODE system.

    Evaluates dx/dl = px/s, dz/dl = pz/s, dpx/dl = ds/dx, dpz/dl = ds/dz and
    dT/dl = s. The slowness s and its derivatives are read from the grid cell
    that contains the current ray position (piecewise-constant lookup).

    Parameters
    ----------
    l : float
        Independent variable (ray length); unused, required by solve_ivp
    r : array_like
        Dependent variable containing (x, z, px, pz, t)
    slowness : np.ndarray
        Slowness 2D model, indexed as ``slowness[iz, ix]``
    dsdx : np.ndarray
        Horizontal derivative of the slowness 2D model (same shape)
    dsdz : np.ndarray
        Vertical derivative of the slowness 2D model (same shape)
    xaxis : np.ndarray
        Horizontal axis
    zaxis : np.ndarray
        Vertical axis
    dx : float
        Horizontal spacing
    dz : float
        Vertical spacing

    Returns
    -------
    drdt : np.ndarray
        RHS evaluation, same length as ``r``
    """
    m, n = slowness.shape
    # extract the different terms of the solution
    x, z, px, pz = r[0], r[1], r[2], r[3]

    # Index of the grid cell containing the ray, clamped to the model so that a
    # ray sitting on (or stepping just past) the boundary still gets a value.
    # Valid 0-based indices are 0..n-1 horizontally and 0..m-1 vertically.
    ix = int(min(max((x - xaxis[0]) // dx, 0), n - 1))
    iz = int(min(max((z - zaxis[0]) // dz, 0), m - 1))

    s = slowness[iz, ix]

    # evaluate RHS
    drdt = np.zeros(len(r))
    drdt[0] = px / s
    drdt[1] = pz / s
    drdt[2] = dsdx[iz, ix]
    drdt[3] = dsdz[iz, ix]
    drdt[4] = s
    return drdt
def event_left(l, r, slowness, dsdx, dsdz, xaxis, zaxis, dx, dz):
    """Event function for the left edge of the model.

    Returns ``x - xaxis[0]``, which changes sign when the ray crosses the
    line ``x = xaxis[0]``. Only ``r`` and ``xaxis`` are used; the remaining
    arguments mirror the signature of ``rhsf`` because solve_ivp passes the
    same ``args`` to every event function.
    """
    left_edge = xaxis[0]
    return r[0] - left_edge
def event_right(l, r, slowness, dsdx, dsdz, xaxis, zaxis, dx, dz):
    """Event function for the right edge of the model.

    Returns ``xaxis[-1] - x``, which changes sign when the ray crosses the
    line ``x = xaxis[-1]``. Only ``r`` and ``xaxis`` are used; the remaining
    arguments mirror the signature of ``rhsf`` because solve_ivp passes the
    same ``args`` to every event function.
    """
    right_edge = xaxis[-1]
    return right_edge - r[0]
def event_top(l, r, slowness, dsdx, dsdz, xaxis, zaxis, dx, dz):
    """Event function for the top edge of the model.

    Returns ``z - zaxis[0]``, which changes sign when the ray crosses the
    line ``z = zaxis[0]``. Only ``r`` and ``zaxis`` are used; the remaining
    arguments mirror the signature of ``rhsf`` because solve_ivp passes the
    same ``args`` to every event function.
    """
    top_edge = zaxis[0]
    return r[1] - top_edge
def event_bottom(l, r, slowness, dsdx, dsdz, xaxis, zaxis, dx, dz):
    """Event function for the bottom edge of the model.

    Returns ``zaxis[-1] - z``, which changes sign when the ray crosses the
    line ``z = zaxis[-1]``. Only ``r`` and ``zaxis`` are used; the remaining
    arguments mirror the signature of ``rhsf`` because solve_ivp passes the
    same ``args`` to every event function.
    """
    bottom_edge = zaxis[-1]
    return bottom_edge - r[1]

# solve_ivp reads these attributes from each event function:
#   terminal = True  -> stop the integration at the first detected zero crossing
#   direction = -1   -> only count crossings from positive to negative values,
#                       i.e. the ray moving out of the model through that edge
for _event in (event_left, event_right, event_top, event_bottom):
    _event.terminal = True
    _event.direction = -1
del _event

This is the raytrace function:

def raytrace(vel, xaxis, zaxis, dx, dz, lstep, source, thetas):
    """Raytracing for multiple rays defined by the initial conditions (source, thetas)

    Parameters
    ----------
    vel : np.ndarray
        2D Velocity model (nz x nx)
    xaxis : np.ndarray
        Horizontal axis
    zaxis : np.ndarray
        Vertical axis
    dx : float
        Horizontal spacing
    dz : float
        Vertical spacing
    lstep : np.ndarray
        Ray length axis; the ODE is integrated from ``lstep[0]`` to
        ``lstep[-1]`` and the solution is sampled at these values
    source : tuple
        Source location ``(x, z)``
    thetas : tuple
        Take-off angles in degrees, measured from the vertical (z) axis

    Returns
    -------
    df : pd.DataFrame
        One row per ray sample with columns ``Source``, ``Theta``, ``rx`` and
        ``rz``; ``rx``/``rz`` are the ray coordinates divided by 1000.
    """
    vel = np.asarray(vel)
    nz, nx = vel.shape

    # Slowness and its spatial derivatives (axis 0 is z, axis 1 is x)
    slowness = 1. / vel
    dsdz, dsdx = np.gradient(slowness, dz, dx)

    # Grid cell containing the source (floor division, clamped to the model);
    # its velocity sets the magnitude of the initial ray parameter, |p0| = 1/v.
    ixs = int(min(max((source[0] - xaxis[0]) // dx, 0), nx - 1))
    izs = int(min(max((source[1] - zaxis[0]) // dz, 0), nz - 1))
    vsrc = vel[izs, ixs]

    rows = []
    for theta in thetas:
        theta_rad = theta * np.pi / 180
        # Initial condition: source position, ray parameter pointing along the
        # take-off angle, and zero traveltime.
        r0 = [source[0], source[1],
              np.sin(theta_rad) / vsrc,
              np.cos(theta_rad) / vsrc,
              0.]

        # Solve ODE; the boundary events are terminal, so integration stops as
        # soon as the ray leaves the model and sol.y only holds the samples of
        # lstep reached before the exit point.
        sol = solve_ivp(rhsf, [lstep[0], lstep[-1]], r0, t_eval=lstep,
                        args=(slowness, dsdx, dsdz, xaxis, zaxis, dx, dz),
                        events=[event_right, event_left,
                                event_top, event_bottom])
        r = sol.y.T

        # Ray coordinates, scaled by 1/1000 (presumably m -> km; confirm units)
        rx, rz = r[:, 0] / 1000, r[:, 1] / 1000
        rows.extend({'Source': source, 'Theta': theta, 'rx': xi, 'rz': zi}
                    for xi, zi in zip(rx, rz))

    # Build the table once instead of concatenating row by row (quadratic).
    return pd.DataFrame(rows, columns=['Source', 'Theta', 'rx', 'rz'])

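I think what you are looking for is the callback functionality available in the DifferentialEquations.jl ecosystem. A `ContinuousCallback` fires when its condition function crosses zero, and calling `terminate!(integrator)` in its affect function stops the solve. This mirrors SciPy's `terminal = True` events. Passing `nothing` as the upcrossing affect and the terminating function as `affect_neg!` reproduces `direction = -1`, and the four boundary callbacks can be combined with `CallbackSet`: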

https://diffeq.sciml.ai/stable/features/callback_functions/