Conditional Logit MLP in Flux

I have been trying for a while to get Conditional Logistic Regression (mathematically equivalent to the Cox Proportional Hazards model) working with MLPs in Flux, and have found a way, though it is clumsy. I am wondering if anyone knows of (or might suggest) a better way? [I would also like to request an additional feature for Flux to make this straightforward and efficient.]
Otherwise, readers interested in this may find my approach helpful.

Typically, a single-output MLP in Flux produces a row vector of forecasts, one per column of the batch.
A Conditional Logistic Regression (Conditional Logit) MLP requires these forecasts to be grouped into strata, with the softmax applied across the columns within each group. [It is noteworthy that such groups typically vary in size; in fact, this is always the case in Cox Regression for Survival Analysis.]

I have managed (with considerable difficulty, given the somewhat ‘temperamental’ nature of Zygote with regard to user-supplied code) to find a way to achieve this, which I describe in notional terms below.

Assuming mini-batches each consist of multiple strata (with each stratum contained entirely within a single batch), a (square) ‘renormalisation’ matrix may be appended as an extra set of rows to the input (‘x’) matrix. This extra matrix consists solely of 0s and 1s and is designed so that matrix multiplication by it produces the group-wise (element-wise) divisor required by the softmax. [One needs to take care that these supplementary rows are excluded from the computation of the MLP Chain.]
This yields the softmax forecast, which can then be fed to crossentropy or whatever other loss function you choose.
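To make this concrete, here is a minimal sketch of the idea (the names and the slicing convention are mine, for illustration, not the exact code):

```julia
using Flux

# Minimal sketch of the approach above (illustrative names).
# `xM` is a batch with the N×N 0/1 renormalisation matrix appended
# below the `d` feature rows; `m` is a single-output MLP Chain.
function strata_softmax(m, xM, d)
    x = xM[1:d, :]            # feature rows only; helper rows never enter the Chain
    M = xM[(d + 1):end, :]    # block-diagonal 0/1 matrix, one block per stratum
    z = exp.(m(x))            # 1×N row of exponentiated scores
    return z ./ (z * M)       # each score divided by its stratum's sum
end
```

The resulting row `p` can then go into a loss such as `-sum(y .* log.(p))`, with `y` a one-hot row marking the case in each stratum.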

A few notes on the block-diagonal renormalisation matrices: if the strata are of the form [1;1;1;1;2;2;2;3;3;4;4;4;4;4;…], then the renormalisation matrix will be block-diagonal, with each block consisting entirely of ones. While ideally such matrices would be specified as sparse matrices, these are incompatible with Zygote, so I used Int32 matrices instead. (I had also tried Bool (binary) matrices, because they convert automatically under multiplication, but Zygote really did not appreciate this.) As to why these matrices need to be pre-computed and passed in as variables: apart from speed, there seems to be no way of computing them dynamically that does not throw ‘mutation’ errors from Zygote.
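For what it is worth, one compact way to pre-compute such a matrix from a vector of stratum labels (done outside the loss, so Zygote never sees any mutation) might be:

```julia
# Pre-compute the block-diagonal 0/1 matrix from (sorted) stratum
# labels, e.g. strata = [1, 1, 1, 1, 2, 2, 2, 3, 3].
strata_matrix(strata) = Int32.(strata .== permutedims(strata))
```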

Another, simpler approach was to have each batch be a single stratum, so that the renormalisation is performed once across the whole batch. This is technically correct, but it constrains how batches can be formed and proved undesirable in the cases I looked at.
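For reference, in that one-stratum-per-batch setting the renormalisation reduces to a single softmax over the batch. Assuming m(x) is a 1×N row, something like:

```julia
using Flux  # softmax comes from NNlib, re-exported by Flux

p = softmax(m(x); dims = 2)   # normalise across the batch's columns
```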

I would be grateful for any ideas.

Thanks

PS: Another approach I tried, apparently possible in PyTorch, is to organise strata as individual input observations: for each observation, the input is a rectangular data matrix and the target is a ‘one-hot’ column vector. I could not see how to apply an MLP chain to an unflattened matrix, though.

Hi,
could you provide a simple PyTorch example of what you are trying to achieve?

Hi, thanks for having a look.

Unfortunately, since my friend wrote the PyTorch version and I don’t understand it well enough to isolate the relevant block, I would have to post the entire program (space doesn’t permit this, and I don’t think my friend would want it).

However, I think I can describe what would be required for a solution analogous to the PyTorch one.

For the PyTorch-style solution, the following would have to be possible:

Is there any way in Flux for the input data of a standard MLP to be 3-dimensional, so that the input for each observation is a matrix and the output is a column? [Each entry of the column would be the result of evaluating the same MLP on a different input vector (column) of that matrix.]

Alternatively, is there any way to sub-divide batches into sub-batches, and to perform operations on the sub-batches?

Thanks again

PS I have made some progress:

If I define an MLP chain as m(x) and make my batches consist of lists of input arrays (each input matrix corresponding to a single stratum), then by defining

m2(x) = m.(x)'

I get a list of output vectors of the right dimensions when I evaluate a batch. I am having trouble taking it from there.
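In case it helps, here is one hedged sketch of ‘taking it from there’: Flux.logitcrossentropy applies the softmax internally, which is exactly the per-stratum renormalisation, so the loss can be mapped over the list as well (assuming xs is the list of per-stratum input matrices and ys a matching list of one-hot target vectors):

```julia
using Flux

# Map the per-stratum loss over the batch and sum; `vec` turns the
# 1×k row output of `m` into a length-k vector for the loss.
loss(xs, ys) = sum(map((x, y) -> Flux.logitcrossentropy(vec(m(x)), y), xs, ys))
```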

If you have a grasp on how the indexing and contractions for this N-D dense operation should go, Tullio.jl (https://github.com/mcabbott/Tullio.jl) is worth a try.

Edit: looking at the implementations of Flux.Dense and torch.nn.Linear, there doesn’t seem to be a reason why you couldn’t pass a higher-dimensional input to the former as well. Have you tried this? The only hurdle seems to be row- vs column-wise dimensions, though it’s hard to tell, since PyTorch may be broadcasting the matmul implicitly. I know you don’t have an easy MWE, but one straightforward way to test this would be:
a) to look at the input and output dims before/after/in-between layers in your PyTorch model,
b) replicate and test those layers and inputs (i.e. same shapes and params) in isolation, then
c) repeat with equivalent layers and inputs in Flux (accounting for column-major and batch-last).
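For instance, a hedged sketch of step (c), assuming a recent Flux version (Flux.Dense reshapes N-dimensional input internally, treating trailing dimensions as batch dimensions, but this is worth verifying on your version):

```julia
using Flux

# 3-dimensional input: features × observations-per-stratum × strata.
m = Chain(Dense(5, 8, relu), Dense(8, 1))
x = rand(Float32, 5, 7, 4)

y = m(x)       # expected: a 1×7×4 array, one score per column
size(y)        # of each stratum's input matrix
```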

Thank you for that. I think understanding that is beyond my level of expertise in this area. However, one thing I do know is that Zygote does not allow you to define an array and then fill it in, as this is considered a form of ‘mutation’. [You can’t even do this and then rename the result, although, now that I think about it, I didn’t try ‘copy’.]
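For anyone hitting the same wall, here is a minimal illustration of the error, plus a non-mutating alternative (a generic sketch, not specific to the model above):

```julia
using Zygote

# Filling in a pre-allocated array is rejected inside gradients:
f_bad(x) = (out = zeros(length(x)); out .= x .^ 2; sum(out))
# Zygote.gradient(f_bad, [1.0, 2.0])  # "Mutating arrays is not supported"

# The non-mutating equivalent differentiates fine:
f_ok(x) = sum(x .^ 2)
Zygote.gradient(f_ok, [1.0, 2.0])     # ([2.0, 4.0],)
```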

My way of passing a ‘helper’ array as part of the data seems to work OK.

Thanks for your thoughts.