Thanks!
Pretty neat that the call to gradient checks whether forward mode will be beneficial! I guess we had implicitly assumed that the gradient function would always use reverse mode, which is not really justified.
BTW: I was a bit confused about why ForwardDiff is used instead of ChainRulesCore.frule. I found an answer to this here.