diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml
index 353819673..b41c1f867 100644
--- a/.github/workflows/CI.yml
+++ b/.github/workflows/CI.yml
@@ -40,6 +40,7 @@ jobs:
           - '[r"Res2Net", r"Res2NeXt"]'
           - '"Inception"'
           - '"DenseNet"'
+          - '"UNet"'
           - '["ConvNeXt", "ConvMixer"]'
           - 'r"Mixers"'
           - 'r"ViTs"'
diff --git a/docs/src/index.md b/docs/src/index.md
index 920cb8ba0..f65893a7a 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -36,6 +36,7 @@ julia> ]add Metalhead
 | [ViT](https://arxiv.org/abs/2010.11929) | [`ViT`](@ref) | N |
 | [ConvNeXt](https://arxiv.org/abs/2201.03545) | [`ConvNeXt`](@ref) | N |
 | [ConvMixer](https://arxiv.org/abs/2201.09792) | [`ConvMixer`](@ref) | N |
+| [UNet](https://arxiv.org/abs/1505.04597) | [`UNet`](@ref) | N |
 
 To contribute new models, see our [contributing docs](@ref Contributing-to-Metalhead.jl).
diff --git a/src/Metalhead.jl b/src/Metalhead.jl
index b67bf1faf..8da7230d2 100644
--- a/src/Metalhead.jl
+++ b/src/Metalhead.jl
@@ -53,6 +53,7 @@ include("convnets/densenet.jl")
 include("convnets/squeezenet.jl")
 include("convnets/convnext.jl")
 include("convnets/convmixer.jl")
+include("convnets/unet.jl")
 
 # Mixers
 include("mixers/core.jl")
@@ -73,7 +74,7 @@ export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19,
        GoogLeNet, Inception3, Inceptionv3, Inceptionv4, InceptionResNetv2, Xception,
        SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3, MNASNet,
        EfficientNet, EfficientNetv2, ConvMixer, ConvNeXt,
-       MLPMixer, ResMLP, gMLP, ViT
+       MLPMixer, ResMLP, gMLP, ViT, UNet
 
 # use Flux._big_show to pretty print large models
 for T in (:AlexNet, :VGG, :SqueezeNet, :ResNet, :WideResNet, :ResNeXt,
@@ -81,7 +82,7 @@ for T in (:AlexNet, :VGG, :SqueezeNet, :ResNet, :WideResNet, :ResNeXt,
           :Inceptionv3, :Inceptionv4, :InceptionResNetv2, :Xception,
           :MobileNetv1, :MobileNetv2, :MobileNetv3, :MNASNet,
           :EfficientNet, :EfficientNetv2, :ConvMixer, :ConvNeXt,
-          :MLPMixer, :ResMLP, :gMLP, :ViT)
+          :MLPMixer, :ResMLP, :gMLP, :ViT, :UNet)
     @eval Base.show(io::IO, ::MIME"text/plain", model::$T) = _maybe_big_show(io, model)
 end
diff --git a/src/convnets/unet.jl b/src/convnets/unet.jl
new file mode 100644
index 000000000..91200c0ce
--- /dev/null
+++ b/src/convnets/unet.jl
@@ -0,0 +1,125 @@
+function pixel_shuffle_icnr(inplanes, outplanes; r = 2)
+    return Chain(Chain(basic_conv_bn((1, 1), inplanes, outplanes * (r^2))...),
+                 Flux.PixelShuffle(r))
+end
+
+function unet_combine_layer(inplanes, outplanes)
+    return Chain(Chain(basic_conv_bn((3, 3), inplanes, outplanes; pad = 1)...),
+                 Chain(basic_conv_bn((3, 3), outplanes, outplanes; pad = 1)...))
+end
+
+function unet_middle_block(inplanes)
+    return Chain(Chain(basic_conv_bn((3, 3), inplanes, 2 * inplanes; pad = 1)...),
+                 Chain(basic_conv_bn((3, 3), 2 * inplanes, inplanes; pad = 1)...))
+end
+
+function unet_final_block(inplanes, outplanes)
+    return Chain(basicblock(inplanes, inplanes; reduction_factor = 1),
+                 Chain(basic_conv_bn((1, 1), inplanes, outplanes)...))
+end
+
+function unet_block(m_child, inplanes, midplanes, outplanes = 2 * inplanes)
+    return Chain(SkipConnection(Chain(m_child,
+                                      pixel_shuffle_icnr(midplanes, midplanes)),
+                                Parallel(cat_channels, identity, BatchNorm(inplanes))),
+                 relu,
+                 unet_combine_layer(inplanes + midplanes, outplanes))
+end
+
+function unetlayers(layers, sz; outplanes = nothing, skip_upscale = 0,
+                    m_middle = _ -> (identity,))
+    isempty(layers) && return m_middle(sz[end - 1])
+
+    layer, layers = layers[1], layers[2:end]
+    outsz = Flux.outputsize(layer, sz)
+    does_downscale = sz[1] ÷ 2 == outsz[1]
+
+    if !does_downscale
+        return Chain(layer, unetlayers(layers, outsz; outplanes, skip_upscale)...)
+    elseif does_downscale && skip_upscale > 0
+        return Chain(layer,
+                     unetlayers(layers, outsz; skip_upscale = skip_upscale - 1,
+                                outplanes)...)
+    else
+        childunet = Chain(unetlayers(layers, outsz; skip_upscale)...)
+        outsz = Flux.outputsize(childunet, outsz)
+
+        inplanes = sz[end - 1]
+        midplanes = outsz[end - 1]
+        outplanes = isnothing(outplanes) ? inplanes : outplanes
+
+        return unet_block(Chain(layer, childunet),
+                          inplanes, midplanes, outplanes)
+    end
+end
+
+"""
+    unet(encoder_backbone, imgdims, outplanes::Integer, final::Any = unet_final_block,
+         fdownscale::Integer = 0)
+
+Creates a UNet model with a specified convolutional backbone. The backbone of any
+Metalhead ResNet-like model can be used as the encoder
+([reference](https://arxiv.org/abs/1505.04597)).
+
+# Arguments
+
+  - `encoder_backbone`: the backbone layers of the specified model to be used as the
+    encoder. For example, `Metalhead.backbone(Metalhead.ResNet(18))` can be passed
+    to instantiate a UNet with the layers of a ResNet-18 as the encoder.
+  - `imgdims`: size of the input image
+  - `outplanes`: number of output feature planes
+  - `final`: the final block, as described in the original paper
+  - `fdownscale`: downscale factor
+"""
+function unet(encoder_backbone, imgdims, outplanes::Integer,
+              final::Any = unet_final_block, fdownscale::Integer = 0)
+    backbonelayers = collect(flatten_chains(encoder_backbone))
+    layers = unetlayers(backbonelayers, imgdims; m_middle = unet_middle_block,
+                        skip_upscale = fdownscale)
+
+    outsz = Flux.outputsize(layers, imgdims)
+    layers = Chain(layers, final(outsz[end - 1], outplanes))
+
+    return layers
+end
+
+"""
+    UNet(imsize::Dims{2} = (256, 256), inchannels::Integer = 3, outplanes::Integer = 3,
+         encoder_backbone = Metalhead.backbone(DenseNet(121)); pretrain::Bool = false)
+
+Creates a UNet model with an encoder built from the specified backbone
+([reference](https://arxiv.org/abs/1505.04597)). By default it uses a
+[`DenseNet`](@ref) backbone, but any ResNet-like Metalhead model can be used
+for the encoder.
+
+# Arguments
+
+  - `imsize`: size of the input image
+  - `inchannels`: number of channels in the input image
+  - `outplanes`: number of output feature planes
+  - `encoder_backbone`: the backbone layers of the specified model to be used as the
+    encoder. For example, `Metalhead.backbone(Metalhead.ResNet(18))` can be passed
+    to instantiate a UNet with the layers of a ResNet-18 as the encoder.
+  - `pretrain`: whether to load the pre-trained weights for ImageNet
+
+!!! warning
+
+    `UNet` does not currently support pretrained weights.
+
+See also [`Metalhead.unet`](@ref).
+"""
+struct UNet
+    layers::Any
+end
+@functor UNet
+
+function UNet(imsize::Dims{2} = (256, 256), inchannels::Integer = 3, outplanes::Integer = 3,
+              encoder_backbone = Metalhead.backbone(DenseNet(121)); pretrain::Bool = false)
+    layers = unet(encoder_backbone, (imsize..., inchannels, 1), outplanes)
+
+    if pretrain
+        loadpretrain!(layers, "UNet")
+    end
+    return UNet(layers)
+end
+
+(m::UNet)(x::AbstractArray) = m.layers(x)
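A quick usage sketch of the new model for reviewers (it mirrors the testset added below and is not part of the diff):

```julia
using Metalhead

# Build a UNet over a ResNet-18 encoder: 256×256 RGB input, 10 output planes.
encoder = Metalhead.backbone(Metalhead.ResNet(18))
model = UNet((256, 256), 3, 10, encoder)

x = rand(Float32, 256, 256, 3, 1)  # one image in WHCN layout
size(model(x))                     # (256, 256, 10, 1)
```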
diff --git a/src/utilities.jl b/src/utilities.jl
index 9f54107fd..316d884c6 100644
--- a/src/utilities.jl
+++ b/src/utilities.jl
@@ -79,3 +79,14 @@ linear_scheduler(drop_prob::Nothing; depth::Integer) = fill(drop_prob, depth)
 function _checkconfig(config, configs)
     @assert config in configs "Invalid configuration. Must be one of $(sort(collect(configs)))."
 end
+
+"""
+    flatten_chains(m::Chain)
+    flatten_chains(m)
+
+Convenience function for traversing the nested layers of a `Chain` and flattening
+them into a single iterator.
+"""
+flatten_chains(m::Chain) = Iterators.flatten(flatten_chains(l) for l in m.layers)
+flatten_chains(m) = (m,)
+
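A small illustration of what `flatten_chains` yields on a nested `Chain` (the `Dense` sizes here are arbitrary, chosen only for the example):

```julia
using Flux, Metalhead

m = Chain(Dense(2 => 3), Chain(Dense(3 => 4), Chain(Dense(4 => 5))))
collect(Metalhead.flatten_chains(m))
# => [Dense(2 => 3), Dense(3 => 4), Dense(4 => 5)]
```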
diff --git a/test/convnets.jl b/test/convnets.jl
index af15ff413..dede5bdb3 100644
--- a/test/convnets.jl
+++ b/test/convnets.jl
@@ -347,3 +347,14 @@ end
         _gc()
     end
 end
+
+@testset "UNet" begin
+    encoder = Metalhead.backbone(ResNet(18))
+    model = UNet((256, 256), 3, 10, encoder)
+    @test size(model(x_256)) == (256, 256, 10, 1)
+    @test gradtest(model, x_256)
+
+    model = UNet()
+    @test size(model(x_256)) == (256, 256, 3, 1)
+    _gc()
+end
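Beyond the shape and gradient checks above, here is a hedged sketch of what a training step could look like with the default model; the optimiser, loss, and random target are placeholders for illustration, not part of this PR:

```julia
using Flux, Metalhead

model = UNet()                     # DenseNet-121 encoder, 3 output planes
x = rand(Float32, 256, 256, 3, 1)  # input batch
y = rand(Float32, 256, 256, 3, 1)  # placeholder target

# One gradient step with an arbitrary optimiser and loss.
opt_state = Flux.setup(Adam(1.0f-3), model)
loss, grads = Flux.withgradient(m -> Flux.mse(m(x), y), model)
Flux.update!(opt_state, model, grads[1])
```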