How does DaggerFlux work?

Dear All,

@dhairyagandhi96 Many times in the past I have wanted to use FluxML/DaggerFlux.jl ("Distributed computation of differentiation pipelines to use multiple workers, devices, GPU, etc.") in my machine learning projects, as if Julia weren't fast enough already: my models naturally form DAGs, and the concurrency offered by Dagger seems nice. We have also recently touched on this problem on Slack, where we discussed distributing a large language model across multiple GPUs, and Dagger seems like a good fit for that. But I am naturally wary of using a project I do not understand internally. I do not need complete knowledge, but a good understanding would be nice to have.

I would therefore like to ask if someone can explain to me how ChainRules fits together with Dagger. My understanding is that, as the model is executed, it needs to add tasks to the computational graph used by Dagger. That would mean that Dagger can work with a dynamic computational graph. Or does it work differently? Some sort of explanation would be nice.

Thanks,
Tomas


Hi,

The package is written so that the forward pass also generates a delayed thunk containing the pullback for every layer in the Chain. It currently only parallelises over the Chain, but the behaviour (for example for Parallel, where there is an opportunity for further parallelisation) can be customised using daglayer.
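To make "a delayed thunk containing the pullback at every layer" concrete, here is a minimal sketch in Python. It is not DaggerFlux or Dagger code (those are Julia and use Dagger's own thunks); `Thunk`, `dense_forward`, and `chain_forward` are hypothetical names, and the hand-written pullback stands in for what a ChainRules rule would supply. Building the chain only records deferred work; nothing runs until a result is forced.

```python
class Thunk:
    """Hypothetical deferred computation: f(*deps) runs only when compute() is called (memoised)."""
    def __init__(self, f, *deps):
        self.f, self.deps = f, deps
        self._done, self._value = False, None

    def compute(self):
        if not self._done:
            # Force dependencies first, then run this node exactly once.
            args = [d.compute() if isinstance(d, Thunk) else d for d in self.deps]
            self._value = self.f(*args)
            self._done = True
        return self._value


def dense_forward(w, b, x):
    """Scalar stand-in for a Dense layer: y = w*x + b, plus its pullback closure."""
    def pullback(dy):
        return w * dy, x * dy, dy  # (dx, dw, db)
    return w * x + b, pullback


def chain_forward(params, x):
    """Delayed forward pass: per layer, one thunk for the output and one for the pullback."""
    y, pullbacks = x, []
    for w, b in params:
        rec = Thunk(lambda xin, w=w, b=b: dense_forward(w, b, xin), y)
        y = Thunk(lambda r: r[0], rec)                # this layer's output, still unevaluated
        pullbacks.append(Thunk(lambda r: r[1], rec))  # this layer's pullback, still unevaluated
    return y, pullbacks


y, pullbacks = chain_forward([(2.0, 1.0), (3.0, -1.0)], 1.0)
print(y._done)                      # False: building the graph ran nothing
print(y.compute())                  # 8.0  (2*1+1 = 3, then 3*3-1 = 8)
print(pullbacks[1].compute()(1.0))  # (3.0, 3.0, 1.0)
```

In a real scheduler the independent thunks (for instance the branches of a Parallel layer) could be dispatched to different workers; this sketch forces them sequentially.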

The delayed pullback call is effectively what invokes the relevant ChainRules rules. Since the forward pass can be generated dynamically, so can the backward graph, because we use the same mechanism to construct the backward pass as the forward pass. @jpsamaroo can also talk about the specifics of execution here.
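The point that the backward graph is as dynamic as the forward one can be sketched the same way. This Python example is an illustration under assumptions, not DaggerFlux's API: `Thunk`, `dense_forward`, `forward`, and `backward` are hypothetical names, and the pullbacks are hand-written rather than ChainRules rules. The backward thunks are built from whatever pullback thunks the forward pass produced, so when the depth of the chain is an ordinary runtime value, the backward graph automatically has the matching shape.

```python
class Thunk:
    """Hypothetical deferred computation: f(*deps) runs only when compute() is called (memoised)."""
    def __init__(self, f, *deps):
        self.f, self.deps = f, deps
        self._done, self._value = False, None

    def compute(self):
        if not self._done:
            args = [d.compute() if isinstance(d, Thunk) else d for d in self.deps]
            self._value = self.f(*args)
            self._done = True
        return self._value


def dense_forward(w, b, x):
    """Scalar 'Dense' layer y = w*x + b with a hand-written pullback."""
    def pullback(dy):
        return w * dy, x * dy, dy  # (dx, dw, db)
    return w * x + b, pullback


def forward(params, x):
    """Delayed forward pass; returns the output thunk and one pullback thunk per layer."""
    y, pullbacks = x, []
    for w, b in params:
        rec = Thunk(lambda xin, w=w, b=b: dense_forward(w, b, xin), y)
        y = Thunk(lambda r: r[0], rec)
        pullbacks.append(Thunk(lambda r: r[1], rec))
    return y, pullbacks


def backward(pullbacks, dy):
    """Delayed backward pass built with the same Thunk mechanism, walking the layers in reverse."""
    dx, grads = dy, []
    for pb in reversed(pullbacks):
        res = Thunk(lambda f, d: f(d), pb, dx)     # call this layer's pullback on the incoming cotangent
        dx = Thunk(lambda r: r[0], res)            # cotangent flowing to the previous layer
        grads.append(Thunk(lambda r: r[1:], res))  # (dw, db) for this layer
    grads.reverse()
    return dx, grads


# The depth is decided while the program runs; the backward graph follows it.
depth = 3
params = [(2.0, 1.0)] * depth
y, pullbacks = forward(params, 1.0)
dx, grads = backward(pullbacks, 1.0)
print(dx.compute())                   # 8.0, i.e. 2**3; forcing dx also forces the forward thunks
print([g.compute() for g in grads])   # [(4.0, 4.0), (6.0, 2.0), (7.0, 1.0)]
```

Forcing is sequential here; a scheduler such as Dagger would instead decide where and when each thunk runs.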

It would be excellent to get some traction, and I'd be happy to help make it work for parallelising LLMs, although I anticipate that would need a few upgrades.

Best,
Dhairya
