Julia Equivalent to PyTorch's `torch.nn.functional.affine_grid`

I asked on Slack but didn't get much help. Is there a Julia equivalent to PyTorch's `affine_grid` function?
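For reference, here is a summary of what `affine_grid` computes for a 5-D `(N, C, D, H, W)` size. It is based on PyTorch's documented `align_corners` semantics and the NumPy code below, so treat it as a summary rather than a spec. An index $i$ along an axis of length $S > 1$ maps to a normalized coordinate $u_S(i)$, and each output entry is $\theta_n$ applied to the homogeneous point $(x, y, z, 1)$:

$$
u_S(i) = \begin{cases}
-1 + \dfrac{2i}{S-1}, & \texttt{align\_corners=True} \\[6pt]
\dfrac{2i+1}{S} - 1, & \texttt{align\_corners=False}
\end{cases}
\qquad i = 0, \dots, S-1,
$$

$$
\text{grid}[n, d, h, w, :] = \theta_n \begin{bmatrix} u_W(w) \\ u_H(h) \\ u_D(d) \\ 1 \end{bmatrix},
\qquad \theta_n \in \mathbb{R}^{3 \times 4}.
$$

With `align_corners=True`, $-1$ and $1$ land on the centers of the corner voxels. With `align_corners=False`, they land on the outer edges, so the voxel centers sit half a voxel inside.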

This NumPy version would work if I ported it to Julia, but I'd rather find a Julia package that provides a more robust version, similar to PyTorch's. I haven't found one yet.

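```python
import numpy as np

def affine_grid_3d(theta, shape, align_corners=False):
    """NumPy port of torch.nn.functional.affine_grid for 3-D volumes.

    theta: (N, 3, 4) affine matrices; shape: (D, H, W).
    Returns an (N, D, H, W, 3) grid whose last axis is (x, y, z),
    with x running along W and z along D.
    """
    # align_corners=True: -1/1 are the centers of the corner voxels.
    # align_corners=False: -1/1 are the outer edges, so centers sit half a voxel inside.
    corners = np.ones(3) if align_corners else 1 - 1 / np.array(shape)
    grid_bins = [np.linspace(-c, c, s, dtype=theta.dtype) for c, s in zip(corners, shape)]

    # 'ij' indexing keeps every coordinate array in (D, H, W) order.
    d_grid, h_grid, w_grid = np.meshgrid(*grid_bins, indexing='ij', copy=False)

    # Homogeneous coordinates (x, y, z, 1), flattened to (4, D*H*W).
    grid = np.stack([w_grid, h_grid, d_grid, np.ones_like(w_grid)]).reshape(4, -1)
    grid = theta @ grid  # (N, 3, 4) @ (4, D*H*W) broadcasts to (N, 3, D*H*W)
    grid = grid.reshape(len(theta), 3, *shape)
    return grid.transpose(0, 2, 3, 4, 1)
```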
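PyTorch's `affine_grid` also accepts 4-D `(N, C, H, W)` sizes with a `(N, 2, 3)` theta. A 2-D counterpart might look like the following minimal sketch. It assumes the same conventions as the 3-D port above: the last axis is ordered (x, y) and x runs along W. The size-1 special case is meant to mirror PyTorch's C++ implementation, which, as far as I can tell, returns 0 for a single sample. `affine_grid_2d` is a hypothetical name, not an existing API.

```python
import numpy as np


def affine_grid_2d(theta, shape, align_corners=False):
    """Hypothetical 2-D counterpart of affine_grid_3d (a sketch).

    theta: (N, 2, 3) affine matrices; shape: (H, W).
    Returns an (N, H, W, 2) array whose last axis is (x, y),
    with x running along W and y along H.
    """
    theta = np.asarray(theta)
    if not np.issubdtype(theta.dtype, np.floating):
        theta = theta.astype(np.float64)  # avoid an integer-typed linspace

    def bins(size):
        # Normalized sample positions along one axis of length `size`.
        if size == 1:
            return np.zeros(1, dtype=theta.dtype)  # a single sample sits at 0
        c = 1.0 if align_corners else 1.0 - 1.0 / size
        return np.linspace(-c, c, size, dtype=theta.dtype)

    h, w = shape
    y, x = np.meshgrid(bins(h), bins(w), indexing="ij")  # both (H, W)
    base = np.stack([x, y, np.ones_like(x)]).reshape(3, -1)  # (3, H*W)
    grid = theta @ base  # (N, 2, 3) @ (3, H*W) -> (N, 2, H*W)
    return grid.reshape(len(theta), 2, h, w).transpose(0, 2, 3, 1)
```

With an identity `theta` and `align_corners=True`, the x coordinates across the first row of a `(2, 3)` grid are `[-1, 0, 1]`. With `align_corners=False` they shrink to `[-2/3, 0, 2/3]`.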