How to make a Custom Layer work on GPU?

Hi. I’m trying to implement the funnel activation (FReLU) from this GitHub repo for fun:
https://github.com/megvii-model/FunnelAct
A PyTorch implementation looks like this.

import torch
import torch.nn as nn

class FReLU(nn.Module):
    def __init__(self, in_c, k=3, s=1, p=1):
        super().__init__()
        # depthwise convolution that computes the spatial condition T(x)
        self.f_cond = nn.Conv2d(in_c, in_c, kernel_size=k, stride=s, padding=p, groups=in_c)
        self.bn = nn.BatchNorm2d(in_c)

    def forward(self, x):
        tx = self.bn(self.f_cond(x))
        out = torch.max(x, tx)  # FReLU: max(x, T(x))
        return out

I thought it would not be too hard, so I implemented it in Flux like this.

struct FReLU
    c::DepthwiseConv
end
FReLU(k, ch; stride=1, pad=1) = FReLU(DepthwiseConv(k, ch => ch; stride=stride, pad=pad))
(m::FReLU)(x::AbstractArray) = max.(x, m.c(x))  # FReLU: max(x, T(x))
Flux.@functor FReLU

This implementation omits the batch normalization, but I just wanted to test the basic idea first.
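For completeness, a version that also applies the batch norm could look roughly like the sketch below. It is untested; FReLUBN and its field names are just placeholders I made up, with Flux’s BatchNorm standing in for nn.BatchNorm2d.

struct FReLUBN
    c::DepthwiseConv
    bn::BatchNorm
end
FReLUBN(k, ch; stride=1, pad=1) =
    FReLUBN(DepthwiseConv(k, ch => ch; stride=stride, pad=pad), BatchNorm(ch))
(m::FReLUBN)(x::AbstractArray) = max.(x, m.bn(m.c(x)))  # max(x, BN(T(x)))
Flux.@functor FReLUBN

The examples below use the simpler version without the batch norm.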
When I run a small example like the one below, it works.

julia> randn(Float32,5,5,1,1) |> Chain(Conv((3,3),1=>2,pad=1),FReLU((3,3),2),Conv((3,3),2=>1,pad=1))
 5×5×1×1 Array{Float32,4}:
[:, :, 1, 1] =
 -0.355185   0.0040399   0.0695772   0.186931  -0.072636
 -0.541725   0.140419   -0.111783   -0.526837   0.0547392
  0.259774   0.0427947  -0.356856   -0.308217  -0.0379689
  0.332452  -0.482085    0.117652    0.113239   0.515701
 -0.278544   0.153428    0.860948    0.845449   0.441734

But when I try to use the GPU as below, it fails.

julia> gpu(randn(Float32,5,5,1,1)) |> gpu(Chain(Conv((3,3),1=>2,pad=1),FReLU((3,3),2),Conv((3,3),2=>1,pad=1)))
ERROR: TaskFailedException:
scalar getindex is disallowed
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] assertscalar(::String) at /home/shogo/.julia/packages/GPUArrays/eVYIC/src/host/indexing.jl:41
 [3] getindex at /home/shogo/.julia/packages/GPUArrays/eVYIC/src/host/indexing.jl:96 [inlined]
 [4] _getindex at ./abstractarray.jl:1083 [inlined]
 [5] getindex at ./abstractarray.jl:1060 [inlined]
 [6] im2col!(::CuArray{Float32,2}, ::CuArray{Float32,4}, ::DenseConvDims{3,(3, 3, 1),2,2,(1, 1, 1),(1, 1, 1, 1, 0, 0),(1, 1, 1),false}) at /home/shogo/.julia/packages/NNlib/PI8Xh/src/impl/conv_im2col.jl:231
 [7] macro expansion at /home/shogo/.julia/packages/NNlib/PI8Xh/src/impl/depthwiseconv_im2col.jl:34 [inlined]
 [8] (::NNlib.var"#427#threadsfor_fun#172"{CuArray{Float32,3},Float32,Float32,CuArray{Float32,5},CuArray{Float32,5},CuArray{Float32,5},DepthwiseConvDims{3,(3, 3, 1),2,1,(1, 1, 1),(1, 1, 1, 1, 0, 0),(1, 1, 1),false},Int64,Int64,Int64,DenseConvDims{3,(3, 3, 1),2,2,(1, 1, 1),(1, 1, 1, 1, 0, 0),(1, 1, 1),false},UnitRange{Int64}})(::Bool) at ./threadingconstructs.jl:81
 [9] (::NNlib.var"#427#threadsfor_fun#172"{CuArray{Float32,3},Float32,Float32,CuArray{Float32,5},CuArray{Float32,5},CuArray{Float32,5},DepthwiseConvDims{3,(3, 3, 1),2,1,(1, 1, 1),(1, 1, 1, 1, 0, 0),(1, 1, 1),false},Int64,Int64,Int64,DenseConvDims{3,(3, 3, 1),2,2,(1, 1, 1),(1, 1, 1, 1, 0, 0),(1, 1, 1),false},UnitRange{Int64}})() at ./threadingconstructs.jl:48
Stacktrace:
 [1] wait at ./task.jl:267 [inlined]
 [2] threading_run(::Function) at ./threadingconstructs.jl:34
 [3] macro expansion at ./threadingconstructs.jl:93 [inlined]
 [4] depthwiseconv_im2col!(::CuArray{Float32,5}, ::CuArray{Float32,5}, ::CuArray{Float32,5}, ::DepthwiseConvDims{3,(3, 3, 1),2,1,(1, 1, 1),(1, 1, 1, 1, 0, 0),(1, 1, 1),false}; col::CuArray{Float32,3}, alpha::Float32, beta::Float32) at /home/shogo/.julia/packages/NNlib/PI8Xh/src/impl/depthwiseconv_im2col.jl:30
 [5] depthwiseconv_im2col! at /home/shogo/.julia/packages/NNlib/PI8Xh/src/impl/depthwiseconv_im2col.jl:18 [inlined]
 [6] #depthwiseconv!#106 at /home/shogo/.julia/packages/NNlib/PI8Xh/src/conv.jl:191 [inlined]
 [7] depthwiseconv!(::CuArray{Float32,5}, ::CuArray{Float32,5}, ::CuArray{Float32,5}, ::DepthwiseConvDims{3,(3, 3, 1),2,1,(1, 1, 1),(1, 1, 1, 1, 0, 0),(1, 1, 1),false}) at /home/shogo/.julia/packages/NNlib/PI8Xh/src/conv.jl:191
 [8] depthwiseconv!(::CuArray{Float32,4}, ::CuArray{Float32,4}, ::CuArray{Float32,4}, ::DepthwiseConvDims{2,(3, 3),2,1,(1, 1),(1, 1, 1, 1),(1, 1),false}; kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /home/shogo/.julia/packages/NNlib/PI8Xh/src/conv.jl:148
 [9] depthwiseconv! at /home/shogo/.julia/packages/NNlib/PI8Xh/src/conv.jl:148 [inlined]
 [10] depthwiseconv(::CuArray{Float32,4}, ::CuArray{Float32,4}, ::DepthwiseConvDims{2,(3, 3),2,1,(1, 1),(1, 1, 1, 1),(1, 1),false}; kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /home/shogo/.julia/packages/NNlib/PI8Xh/src/conv.jl:91
 [11] depthwiseconv(::CuArray{Float32,4}, ::CuArray{Float32,4}, ::DepthwiseConvDims{2,(3, 3),2,1,(1, 1),(1, 1, 1, 1),(1, 1),false}) at /home/shogo/.julia/packages/NNlib/PI8Xh/src/conv.jl:89
 [12] (::DepthwiseConv{2,4,typeof(identity),CuArray{Float32,4},CuArray{Float32,1}})(::CuArray{Float32,4}) at /home/shogo/.julia/packages/Flux/05b38/src/layers/conv.jl:390
 [13] (::FReLU)(::CuArray{Float32,4}) at /home/shogo/tmp/model-zoo/vision/mnist/conv.jl:58
 [14] applychain(::Tuple{FReLU,Conv{2,4,typeof(identity),CuArray{Float32,4},CuArray{Float32,1}}}, ::CuArray{Float32,4}) at /home/shogo/.julia/packages/Flux/05b38/src/layers/basic.jl:36 (repeats 2 times)
 [15] Chain at /home/shogo/.julia/packages/Flux/05b38/src/layers/basic.jl:38 [inlined]
 [16] |>(::CuArray{Float32,4}, ::Chain{Tuple{Conv{2,4,typeof(identity),CuArray{Float32,4},CuArray{Float32,1}},FReLU,Conv{2,4,typeof(identity),CuArray{Float32,4},CuArray{Float32,1}}}}) at ./operators.jl:834
 [17] top-level scope at REPL[11]:1
 [18] include_string(::Function, ::Module, ::String, ::String) at ./loading.jl:1088

Does anyone know how to make this work on the GPU?

This looks like a problem in the NNlib implementation of depthwiseconv. I get the same error with

using Flux, CUDA
CUDA.allowscalar(false)
gpu(rand(5,5,2,1)) |> gpu(DepthwiseConv((3,3), 2=>2))

This issue makes it look like a GPU version of depthwiseconv has not yet been implemented in NNlib.
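In the meantime, one possible workaround is to build the depthwise convolution as a grouped Conv instead of a DepthwiseConv, so the forward pass goes through the ordinary convolution path. This is just a sketch and assumes a Flux version whose Conv constructor accepts the groups keyword (newer releases do); FReLU2 is a placeholder name, not anything from the FunnelAct repo.

using Flux

# Sketch only: a grouped convolution with groups == channels applies one filter
# per channel, which is the same operation DepthwiseConv performs, but it
# dispatches through the regular conv kernels rather than depthwiseconv.
struct FReLU2
    c::Conv
end
FReLU2(k, ch; stride=1, pad=1) =
    FReLU2(Conv(k, ch => ch; stride=stride, pad=pad, groups=ch))
(m::FReLU2)(x::AbstractArray) = max.(x, m.c(x))
Flux.@functor FReLU2

On recent NNlib/CUDA versions grouped convolutions should be handled by cuDNN, so this should avoid the scalar-indexing fallback, though I haven’t benchmarked it against a native depthwise kernel.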


Following some links from that issue, https://github.com/JuliaGPU/CuArrays.jl/pull/523 is the PR to follow.


@contradict
Oh, I see. It makes sense now. Thanks!

@ToucheSir
I’ll look at the PR. Thanks for the link!