I have been trying for a while to get Conditional Logistic Regression (mathematically equivalent to the Cox Proportional Hazards model) working with MLPs in Flux, and have found a way, though it is clumsy. Does anyone know of (or could suggest) a better way? [I would also like to request an additional Flux feature that would make this straightforward and efficient.]
Otherwise, readers interested in this may find my approach helpful.
Typically, a single-output MLP in Flux produces a row output: for a batch of N observations stored as columns, the output is a 1 × N matrix.
A Conditional Logistic Regression (Conditional Logit) MLP requires these output forecasts to be grouped into strata, with the softmax applied across the columns belonging to each group. [Notably, these groups typically vary in size; in fact, this would always be the case in Cox regression for survival analysis.]
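For concreteness, write $z_i = f(x_i)$ for the MLP output of observation $i$, $G_s$ for the set of observations in stratum $s$, and $y_i \in \{0, 1\}$ for a target that is one-hot within each stratum (this notation is introduced here, not taken from Flux). The stratum-wise softmax and the corresponding conditional-logit (crossentropy) loss are then:

$$
p_i = \frac{\exp(z_i)}{\sum_{j \in G_{s(i)}} \exp(z_j)},
\qquad
\mathcal{L} = -\sum_{s=1}^{S} \sum_{i \in G_s} y_i \log p_i
$$

where $s(i)$ is the stratum containing observation $i$ and $S$ is the number of strata.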
I have managed (with considerable difficulty, given the somewhat ‘temperamental’ nature of Zygote with regard to user-supplied code) to find a way to achieve this, which I describe in notional terms below.
Assuming each mini-batch consists of multiple strata (with each stratum contained in a single batch), a square (batch size × batch size) ‘renormalisation’ matrix can be appended to the input (‘x’) matrix as an extra set of rows. This matrix consists solely of 0s and 1s and is designed so that multiplying the exponentiated MLP outputs by it produces, element-wise, the group-wise divisor required by the softmax. [One needs to take care that these supplementary rows are excluded from the computation of the MLP Chain.]
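A minimal NumPy sketch of the renormalisation arithmetic, assuming the MLP outputs for one mini-batch are already available as a 1 × N row `z` and the 0/1 matrix `M` has been passed in alongside them. The function name `group_softmax` is hypothetical, and NumPy stands in for Flux/Zygote purely to show the arithmetic:

```python
import numpy as np

def group_softmax(z, M):
    """Stratum-wise softmax of a 1 x N row of MLP outputs (hypothetical helper).

    z : array of shape (1, N), one MLP output per observation (column).
    M : (N, N) 0/1 matrix with M[i, j] = 1 when observations i and j
        share a stratum (the 'renormalisation' matrix).
    """
    # Subtracting one constant from every entry leaves each stratum's
    # ratios unchanged; it only guards against overflow in exp.
    e = np.exp(z - z.max())
    # (e @ M)[0, j] = sum of e[0, i] over all i in the same stratum as j,
    # i.e. the softmax divisor for column j.
    divisor = e @ M
    return e / divisor

# Example: strata [1, 1, 2, 2, 2], with M written out by hand.
M = np.array([[1, 1, 0, 0, 0],
              [1, 1, 0, 0, 0],
              [0, 0, 1, 1, 1],
              [0, 0, 1, 1, 1],
              [0, 0, 1, 1, 1]], dtype=np.int32)
z = np.array([[0.5, -1.0, 2.0, 0.0, 1.0]])
p = group_softmax(z, M)
print(p.round(3))
```

Because `M` is symmetric, multiplying from the left (`M @ e.T`) would give the same divisors as a column.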
This gives the stratum-wise softmax forecast, which can then be fed to crossentropy or any other loss function you choose.
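Continuing in the same hypothetical NumPy setting, the crossentropy on the stratum-wise softmax can be computed in log space for numerical stability. This sketch assumes exactly one positive (one-hot) target per stratum; dividing by the number of strata is one possible normalisation, not the only one, and `conditional_logit_loss` is an illustrative name:

```python
import numpy as np

def conditional_logit_loss(z, M, y):
    """Mean negative conditional log-likelihood over the strata in a batch.

    z : (1, N) MLP outputs; M : (N, N) 0/1 same-stratum matrix;
    y : (1, N) 0/1 targets, one-hot within each stratum.
    """
    shift = z.max()
    # log of the stratum-wise softmax divisor for every column
    log_divisor = np.log(np.exp(z - shift) @ M) + shift
    log_p = z - log_divisor          # stratum-wise log-softmax
    n_strata = y.sum()               # assumes one positive per stratum
    return -(y * log_p).sum() / n_strata

# Two strata of size 2 with equal scores: each member has p = 0.5.
M = np.array([[1, 1, 0, 0],
              [1, 1, 0, 0],
              [0, 0, 1, 1],
              [0, 0, 1, 1]], dtype=np.int32)
z = np.zeros((1, 4))
y = np.array([[1, 0, 0, 1]])
loss = conditional_logit_loss(z, M, y)
print(loss)  # -> log(2), about 0.693
```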
A few notes on the block-diagonal renormalisation matrices: if the strata are of the form [1;1;1;1;2;2;2;3;3;4;4;4;4;4…] etc. (i.e. rows sorted by stratum), the renormalisation matrix will be block-diagonal, with each block consisting of ones. While ideally such matrices would be represented as sparse matrices, these are incompatible with Zygote, so I used Int32 matrices. (I had also tried Boolean matrices, because they convert automatically under multiplication, but Zygote really did not appreciate this.) As to why these matrices need to be precomputed and passed in as variables: apart from speed, there seemed to be no way of computing them dynamically that did not trigger ‘mutability’ errors from Zygote.
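For reference, the same-stratum matrix can be built from the label vector with a single broadcast comparison, so no element is ever written in place. The NumPy sketch below (hypothetical helper `renormalisation_matrix`) reproduces the block-diagonal structure for the example strata. In Julia the analogous expression would presumably be `Int32.(s .== permutedims(s))`; whether Zygote accepts that inside the loss has not been checked here, so precomputing it, as described above, remains the safe option:

```python
import numpy as np

def renormalisation_matrix(strata):
    """0/1 matrix with entry (i, j) = 1 when observations i and j share a stratum.

    Built with one broadcast comparison (no in-place writes). For labels
    sorted by stratum the result is block-diagonal; unsorted labels still
    give a correct, just not block-diagonal, matrix.
    """
    s = np.asarray(strata)
    return (s[:, None] == s[None, :]).astype(np.int32)

strata = [1, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 4, 4, 4]
M = renormalisation_matrix(strata)
# Each row sums to the size of that observation's stratum.
print(M.sum(axis=1))  # -> [4 4 4 4 3 3 3 2 2 5 5 5 5 5]
```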
One simpler alternative was to make each batch a single stratum, so that the renormalisation is just an ordinary softmax across each batch. This is technically correct, but it limits training and proved undesirable in the cases I looked at.
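The middle ground between one stratum per batch and arbitrary batches is to pack several whole strata into each mini-batch, which is what the ‘each stratum contained in a single batch’ assumption needs. A hypothetical helper (`pack_strata` is an illustrative name, and greedy packing is just one simple choice) might look like:

```python
import random

def pack_strata(strata, max_batch_size, rng=None):
    """Group observation indices into mini-batches without splitting any stratum.

    Strata are added whole, in order, until the next one would push the batch
    past max_batch_size. A stratum larger than max_batch_size ends up in a
    batch of its own. Pass an rng (e.g. random.Random(seed)) to shuffle the
    stratum order, for example once per epoch.
    """
    members = {}
    for i, s in enumerate(strata):
        members.setdefault(s, []).append(i)
    order = list(members)
    if rng is not None:
        rng.shuffle(order)
    batches, current = [], []
    for s in order:
        idx = members[s]
        if current and len(current) + len(idx) > max_batch_size:
            batches.append(current)
            current = []
        current = current + idx
    if current:
        batches.append(current)
    return batches

strata = [1, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 4, 4, 4]
print(pack_strata(strata, 7))
# -> [[0, 1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12, 13]]
```

Each returned index list can then be used to slice the inputs and to build that batch's renormalisation matrix.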
I would be grateful for any ideas.
Thanks
PS Another approach I tried, which is apparently possible in PyTorch, is to treat each stratum as an individual input observation, where each observation's input data is a rectangular matrix and the output target is a ‘one-hot’ column vector. However, I could not see how to apply an MLP chain to an unflattened matrix.
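On the PS: Flux's `Dense` layers (and hence a `Chain` of them) already accept a matrix and treat each column as a separate observation, so a stratum stored as a features × members matrix can be passed through the MLP unflattened, giving one score per member to softmax over (in Flux, something like `softmax(model(Xs); dims = 2)`). The NumPy sketch below mimics that column-wise convention with a hypothetical one-hidden-layer MLP and random weights, looping over strata of different sizes:

```python
import numpy as np

rng = np.random.default_rng(0)
n_features, n_hidden = 3, 8
# Hypothetical MLP weights; columns of X are observations, as in Flux.
W1 = rng.normal(size=(n_hidden, n_features))
b1 = np.zeros((n_hidden, 1))
W2 = rng.normal(size=(1, n_hidden))
b2 = np.zeros((1, 1))

def mlp(X):
    """Apply a one-hidden-layer MLP column-wise: (n_features, n) -> (1, n)."""
    H = np.maximum(W1 @ X + b1, 0.0)  # ReLU hidden layer
    return W2 @ H + b2

def stratum_loss(X, y):
    """Conditional-logit loss for one stratum stored as a features x members matrix.

    y is a one-hot vector of length n marking the 'case' within the stratum.
    """
    z = mlp(X)[0]                     # one score per member
    z = z - z.max()                   # stabilise exp; ratios unchanged
    log_p = z - np.log(np.exp(z).sum())
    return -(y * log_p).sum()

# Strata of different sizes, each a rectangular (n_features, n_s) matrix.
sizes = (4, 3, 5)
strata = [rng.normal(size=(n_features, n)) for n in sizes]
targets = [np.eye(n)[0] for n in sizes]
loss = np.mean([stratum_loss(X, y) for X, y in zip(strata, targets)])
print(loss)
```

Looping over strata one at a time avoids the N × N matrix entirely, but each stratum becomes its own small forward pass, so it may be slower than the batched matrix approach.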