Function return type depends on an input flag: Bad?

I have a function that calculates ‘sky’ loss and optionally its gradient.

function _sky_loss_gradient(x, loss_only::Bool)
    # calculation
    loss_only && return loss
    # more calculation
    return loss, grad
end

sky_loss(x) = _sky_loss_gradient(x, true)

I was wondering whether this is good practice, since the compiler might not be able to infer the output type at compile time. This is my use case: I am defining a reverse-mode autodiff chain rule for backpropagating loss gradients.

import ChainRulesCore: rrule, DoesNotExist, NO_FIELDS

function rrule(::typeof(sky_loss), x)
    loss, grad = _sky_loss_gradient(x, false)
    _pullback(x̄) = (NO_FIELDS, x̄ .* grad)
    loss, _pullback
end

Will the compiler be able to figure out that sky_loss returns a single value (not a tuple)?

Or should I be using Val(true)? Something like…

function _sky_loss_gradient(x, ::Val{loss_only}) where loss_only
    # calculation
    loss_only && return loss
    # more calculation
    return loss, grad
end

sky_loss(x) = _sky_loss_gradient(x, Val(true))

Is there any other simpler way?

Edit: Users will never call _sky_loss_gradient. They will only use sky_loss; the gradient is needed only for autodiff via Zygote or Flux.

Yes, that is generally bad - it’s the definition of type instability (output type depending on the input values, not types) and moreover, users of your function won’t even know how many objects your function will return in the general case!

As your code is now, users will basically always have to do this dance:

res = sky_loss_gradient(inp, some_bool)
if res isa Tuple
   # some code
   # that
   # may be very long
else
   # ok we don't have a gradient..?
end

If you only ever pass boolean literals into sky_loss_gradient, constant propagation should take care of this, yes. That being said, it’s probably best to handle those cases explicitly in rrule and separate loss and gradient into two functions.

Thank you.
Users will never call _sky_loss_gradient (I added an underscore 🙂). The gradient is obtained only via Zygote or Flux. So can the compiler now figure out that sky_loss does not return a tuple? _sky_loss_gradient is called explicitly in only two places, once with true and once with false.

I cannot write two functions because the calculations are highly intertwined, and editing two functions in parallel is error-prone.
What do you think of the Val(true) approach?

I think boolean flags are generally bad. I’d just have two functions. They don’t look intertwined to me, one can call the other and modify the result.

skyloss(x) = ...

function skyloss_and_grad(x)
    loss = skyloss(x)
    grad = ...
    loss, grad
end
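
If the gradient needs intermediate values from the loss computation, one possible arrangement is a private helper that returns the loss together with those intermediates. This is only a sketch on a toy quadratic loss; `_forward`, the residual `r`, and the loss formula are made up for illustration and are not the original sky loss.

```julia
# Toy stand-in for the real loss: 0.5 * sum((x .- 1).^2).
# All names here are hypothetical.

# Shared forward pass: returns the loss plus the intermediate the gradient reuses.
function _forward(x)
    r = x .- 1                  # residual, needed by both loss and gradient
    loss = 0.5 * sum(abs2, r)
    return loss, r
end

# Loss only: always returns a scalar, never computes the gradient.
skyloss(x) = first(_forward(x))

# Loss and gradient: always returns a 2-tuple.
function skyloss_and_grad(x)
    loss, r = _forward(x)
    grad = r                    # derivative of 0.5 * sum((x .- 1).^2) w.r.t. x
    return loss, grad
end
```

Each function has a single return type, and the shared part lives in one place, so there is nothing to keep in sync by hand.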

To add to this, if they are intertwined, return both, and let the user do loss, _ = ... if they want to.

@gustaphe, here "more calculation" depends on "calculation". I am trying to skip "more calculation" when only the loss is needed.

Let’s get at this from a different angle. In which situations will you need the gradient, and in which will you only need the loss? Maybe there’s some identifying characteristic that you can use dispatch with to get rid of the instability?

Well, as long as your “flag” is simple and can be resolved as a compile-time constant, you can rely on the constant prop’ performed by type inference:

julia> function foo(a, flag)
           r = a > 0
           if flag
               return a, r
           else
               r
           end
       end
foo (generic function with 1 method)

julia> code_typed((Int,)) do n
           foo(n, true) # compiler should fully understand the return type here
       end |> first
CodeInfo(
1 ─ %1 = Base.slt_int(0, n)::Bool
│   %2 = Core.tuple(n, %1)::Tuple{Int64, Bool}
└──      return %2
) => Tuple{Int64, Bool}

It’s not well documented which data can be constant-folded, but you can definitely use Int, Bool, and Symbol values, for example (N.B. the value of a String won’t be constant-folded).
And IMHO there aren’t many use cases for the Val type, especially when it wraps a value that can already be a compile-time constant.

Having said that, if you can refactor your function so that its return type doesn’t depend on an argument value, even one that can be constant-folded, then that’s definitely better. Julia’s inference may give up on constant prop’ when it encounters a mutually recursive call cycle, for example, and debugging inference issues in such cases can be tricky.
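
One way to check whether constant prop’ actually kicked in at a given call site is `Test.@inferred` from the standard library, which throws if the inferred return type does not match the actual one. A small sketch reusing foo from above; the wrappers `foo_both` and `foo_flag` are hypothetical names added for the check:

```julia
using Test

function foo(a, flag)
    r = a > 0
    if flag
        return a, r
    else
        return r
    end
end

# Wrappers that pass the flag as a literal, as sky_loss does.
foo_both(a) = foo(a, true)
foo_flag(a) = foo(a, false)

# These throw an error if the return type is not inferred concretely.
@inferred foo_both(1)
@inferred foo_flag(1)

# With a runtime Bool there is no constant to propagate,
# so inference can only report a Union of both shapes.
rt = first(Base.return_types(foo, (Int, Bool)))
```

Putting such checks in the test suite should flag the case where a later refactor silently defeats constant prop’.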

Note that this is especially true when using AD. Zygote tends to generate code with a lot of closures, which really makes it difficult for the compiler to do constant propagation. That’s why I would generally recommend against relying on constant propagation too heavily when dealing with AD.

Is this useful:

function _sky_loss_gradient(x, loss_only::Bool)
    # calculation
    loss_only && return loss, nothing
    # more calculation
    return loss, grad
end

The return value is always a tuple, and the boolean determines whether the second item in the tuple contains a useful value.

It doesn’t change the fact that the type is unstable - you’ve just moved the instability from typeof(loss) vs. Tuple{typeof(loss), typeof(grad)} to Tuple{typeof(loss), Nothing} vs. Tuple{typeof(loss), typeof(grad)}. I guess the advantage is that there’s no question how many elements there are, so destructuring is an option now.
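
To see this concretely, you can ask inference what it makes of such a function when the flag is only known to be a Bool. A sketch on a toy loss; `lossgrad_tuple` is a hypothetical name, and the gradient of 0.5 * sum(abs2, x) is just x:

```julia
# Always returns a 2-tuple, but the type of the second element depends on the flag.
function lossgrad_tuple(x, loss_only::Bool)
    loss = 0.5 * sum(abs2, x)
    loss_only && return (loss, nothing)
    return (loss, x)
end

# Inferred return type for a runtime Bool flag.
rt = first(Base.return_types(lossgrad_tuple, (Vector{Float64}, Bool)))

# Destructuring works for either value of the flag.
loss, grad = lossgrad_tuple([1.0, 2.0], true)
```

Here `rt` is some tuple type, but not a concrete one, which is the instability described above.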

Julia provides great tools for exploring type stability yourself:

function lossgrad(x, loss_only::Bool)
    loss = 0.5sum(abs2, x)
    loss_only && return loss
    loss, x
end
function lossgrad(x, ::Val{loss_only}) where {loss_only}
    loss = 0.5sum(abs2, x)
    loss_only && return loss
    loss, x
end
x = randn(10);
@code_warntype lossgrad(x, true)
@code_warntype lossgrad(x, Val(true))
@code_warntype lossgrad(x, Val(false))

Yep - that only works if the ::Val{true} is written in statically or derived from type information though, as far as I know. So having that Val be dynamic will (as far as I understand) still lead to type instability.
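
A quick way to compare the two situations, as far as I can tell, is to look at the inferred return types. This sketch reuses the Val method of lossgrad from above; the wrappers `static_loss` and `dynamic` are hypothetical names:

```julia
function lossgrad(x, ::Val{loss_only}) where {loss_only}
    loss = 0.5 * sum(abs2, x)
    loss_only && return loss
    return loss, x
end

# Val written statically: the flag is part of the argument type.
static_loss(x) = lossgrad(x, Val(true))

# Val built from a runtime Bool: the concrete Val type is not known to inference.
dynamic(x, flag::Bool) = lossgrad(x, Val(flag))

rt_static  = first(Base.return_types(static_loss, (Vector{Float64},)))
rt_dynamic = first(Base.return_types(dynamic, (Vector{Float64}, Bool)))
```

The static wrapper should infer a plain Float64, while the dynamic one cannot be narrowed to a single concrete type.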

Yes, but it avoids relying on const prop, which was the concern over:

function rrule(::typeof(sky_loss), x)
    loss, grad = _sky_loss_gradient(x, false)
    _pullback(x̄) = (NO_FIELDS, x̄ .* grad)
    loss, _pullback
end

I think boolean flags and type instability are almost always bad design, regardless of their performance consequences. Make simple functions with clear interfaces. If you find that makes it faster too, that’s just gravy – the thing to get right first is design.

Great. This is what I ended up doing.