Change loss function depending on the iteration count within Optim.jl

Hi,

Let’s say I have a loss function that I want to change depending on the iteration count.

For example:

```julia
function l(x, target)
    # `iteration_count` is not available here; obtaining it is the question
    target_warp = f_warp(target, iteration_count)
    sum(abs2, x .- target_warp)
end
```

Is there a way to do that? I know that the callback function has access to the iteration count, so I could do some trick with a Ref, but I'm not sure how elegant that is.

Best,

Felix

PS: Discourse reminded me that I joined this community exactly 5 years ago today, haha

Hi @roflmaostc,

Why do you want to do this?

p.s., :partying_face: :birthday_cake:


For Gumbel-Softmax, there is a temperature parameter that is gradually reduced during training.

But in general, my loss function needs to be changed during training to push my solution towards a favorable state.
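For reference, a Gumbel-Softmax sample divides the Gumbel-perturbed logits by the temperature τ before the softmax, so lowering τ pushes samples toward one-hot vectors. Below is a minimal sketch, assuming an exponential-decay schedule with a floor. The names `anneal` and `gumbel_softmax` and all parameter values are illustrative choices, not taken from the thread or from any package.

```julia
using Random

# Hypothetical annealing schedule: τ_k = max(τmin, τ0 * exp(-rate * k)).
anneal(k; τ0 = 1.0, rate = 1e-3, τmin = 0.5) = max(τmin, τ0 * exp(-rate * k))

# Draw one relaxed one-hot sample from the categorical distribution given by `logits`.
function gumbel_softmax(rng::AbstractRNG, logits::AbstractVector{<:Real}, τ::Real)
    # Keep the uniforms away from 0 and 1 so the double log stays finite.
    u = clamp.(rand(rng, length(logits)), eps(), 1 - eps())
    g = -log.(-log.(u))              # standard Gumbel(0, 1) noise
    z = (logits .+ g) ./ τ           # temperature-scaled perturbed logits
    z = z .- maximum(z)              # shift for a numerically stable softmax
    e = exp.(z)
    return e ./ sum(e)
end

rng = MersenneTwister(42)
logits = log.([0.2, 0.3, 0.5])
y_warm = gumbel_softmax(rng, logits, anneal(0))       # τ = 1.0: smooth sample
y_cold = gumbel_softmax(rng, logits, anneal(10_000))  # τ at its floor of 0.5: sharper sample
```

In an Optim.jl run, `anneal` could be driven by an iteration counter that is only updated between iterations, so that every evaluation within one iteration sees the same temperature.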

Whether this is a good idea depends on the algorithm you are using.

You could use your Ref trick, but note that some algorithms may not have a direct one-to-one relationship between “evaluating l” and “the iteration number”. For example, they may do a line search, or evaluate l at multiple points before taking a step.
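The mismatch between loss evaluations and iterations is easy to observe by counting both. Below is a minimal sketch, assuming Optim.jl 1.x, an arbitrary quadratic objective, and illustrative names (`ncalls`, `seen_iterations`). Nelder-Mead is used so that no gradient is required.

```julia
using Optim

const target = [1.0, 2.0]
const ncalls = Ref(0)          # how many times the objective has been evaluated
const seen_iterations = Int[]  # iteration numbers passed to the callback

function loss(x)
    ncalls[] += 1
    return sum(abs2, x .- target)
end

function cb(state)
    push!(seen_iterations, state.iteration)  # OptimizationState has an `iteration` field
    return false                             # returning `true` would stop the run
end

res = optimize(loss, [0.0, 0.0], NelderMead(), Optim.Options(callback = cb))

println("iterations reported by Optim: ", Optim.iterations(res))
println("objective evaluations:        ", ncalls[])
```

Nelder-Mead evaluates all n + 1 vertices of its initial simplex before the first iteration and at least one trial point per iteration, so the evaluation count ends up larger than the iteration count. Methods with a line search show the same effect.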

I could use the callback, which has access to the iteration number, to change it.

But yes, the number of loss evaluations is not the same as the iteration number.
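Below is a minimal sketch of the callback-plus-Ref approach, written in plain Julia. `temperature`, `f_warp`, and `make_loss_and_callback` are hypothetical names, and the decay schedule is an arbitrary choice. The only assumptions about Optim.jl are that the callback receives a state object with an `iteration` field and that returning `false` lets the run continue.

```julia
# Hypothetical temperature schedule and warp, standing in for the real `f_warp`.
temperature(k; τ0 = 1.0, decay = 0.9, τmin = 0.1) = max(τmin, τ0 * decay^k)
f_warp(target, k) = target ./ temperature(k)

# Build a loss closure and an Optim-style callback that share one iteration counter.
function make_loss_and_callback(target)
    iter = Ref(0)  # last iteration number reported by the callback
    loss(x) = sum(abs2, x .- f_warp(target, iter[]))
    function callback(state)
        iter[] = state.iteration  # Optim's OptimizationState carries an `iteration` field
        return false              # returning `true` would stop the optimization
    end
    return loss, callback, iter
end
```

With Optim.jl, this could be used along the lines of `loss, cb, _ = make_loss_and_callback(target)` followed by `optimize(loss, x0, NelderMead(), Optim.Options(callback = cb))`. Because `iter[]` only changes when the callback runs, all evaluations inside one iteration (line-search steps, simplex points) see the same warped target. However, information the optimizer has already stored, such as the Nelder-Mead vertex values or the L-BFGS history, was computed with the old target, so the method may behave oddly right after a change.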