From ba54cf0219cc247a732c1710925a4eef7a2c70c7 Mon Sep 17 00:00:00 2001 From: Anshuman Mishra <51750587+shivance@users.noreply.github.com> Date: Tue, 27 Dec 2022 20:30:25 +0530 Subject: [PATCH 01/29] model implemented --- src/convnets/unet.jl | 67 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 src/convnets/unet.jl diff --git a/src/convnets/unet.jl b/src/convnets/unet.jl new file mode 100644 index 000000000..ccf49d4d0 --- /dev/null +++ b/src/convnets/unet.jl @@ -0,0 +1,67 @@ +function unet_block(in_chs::Int, out_chs::Int, kernel = (3, 3)) + Chain(conv1 = Conv(kernel, in_chs => out_chs, pad = (1, 1); init = _random_normal), + norm1 = BatchNorm(out_chs, relu), + conv2 = Conv(kernel, out_chs => out_chs, pad = (1, 1); init = _random_normal), + norm2 = BatchNorm(out_chs, relu)) +end + +function UpConvBlock(in_chs::Int, out_chs::Int, kernel = (2, 2)) + Chain(convtranspose = ConvTranspose(kernel, in_chs => out_chs, stride = (2, 2); init = _random_normal), + norm = BatchNorm(out_chs, relu)) +end + +struct Unet + encoder::Any + decoder::Any + upconv::Any + pool::Any + bottleneck::Any + final_conv::Any +end + +@functor Unet + +function Unet(inplanes::Int = 3, outplanes::Int = 1, init_features::Int = 32) + + features = init_features + + encoder_layers = [] + append!(encoder_layers, [unet_block(inplanes, features)]) + append!(encoder_layers, [unet_block(features * 2^i, features * 2^(i + 1)) for i ∈ 0:2]) + + encoder = Chain(encoder_layers) + + decoder = Chain([unet_block(features * 2^(i + 1), features * 2^i) for i ∈ 0:3]) + + pool = Chain([MaxPool((2, 2), stride = (2, 2)) for _ ∈ 1:4]) + + upconv = Chain([UpConvBlock(features * 2^(i + 1), features * 2^i) for i ∈ 3:-1:0]) + + bottleneck = _block(features * 8, features * 16) + + final_conv = Conv((1, 1), features => outplanes) + + Unet(encoder, decoder, upconv, pool, bottleneck, final_conv) +end + +function (u::Unet)(x::AbstractArray) + enc_out = [] + + out = x + for i ∈ 1:4 + out = u.encoder[i](out) + push!(enc_out, out) + + out = u.pool[i](out) + end + + out = u.bottleneck(out) + + for i ∈ 4:-1:1 + out = u.upconv[5-i](out) + out = cat(out, enc_out[i], dims = 3) + out = u.decoder[i](out) + end + + σ(u.final_conv(out)) +end From 11c50d969e67e39313e1e5f3aec3b01de0a319c3 Mon Sep 17 00:00:00 2001 From: Anshuman Mishra <51750587+shivance@users.noreply.github.com> Date: Tue, 27 Dec 2022 20:43:03 +0530 Subject: [PATCH 02/29] adding documentation --- docs/src/index.md | 1 + src/convnets/unet.jl | 30 ++++++++++++++++++++++-------- 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/docs/src/index.md b/docs/src/index.md index 920cb8ba0..f65893a7a 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -36,6 +36,7 @@ julia> ]add Metalhead | [ViT](https://arxiv.org/abs/2010.11929) | [`ViT`](@ref) | N | | [ConvNeXt](https://arxiv.org/abs/2201.03545) | [`ConvNeXt`](@ref) | N | | [ConvMixer](https://arxiv.org/abs/2201.09792) | [`ConvMixer`](@ref) | N | +| [UNet](https://arxiv.org/abs/1505.04597v1) | [`UNet`](@ref) | N | To contribute new models, see our [contributing docs](@ref Contributing-to-Metalhead.jl). 
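As a quick sanity check of the model implemented above, before the documentation work continues: a minimal sketch, assuming `Flux` is loaded, that the `_random_normal` initializer the constructor references is in scope (it only lands in `src/utilities.jl` in PATCH 03), and that the bottleneck's `_block` call resolves to `unet_block`. Input width and height must be divisible by 2^4 so the four pool/upconv pairs line up, and on recent NNlib versions the final activation needs broadcasting (`σ.` rather than `σ`).

```julia
using Flux  # Conv, BatchNorm, MaxPool, ConvTranspose, σ

m = Unet(3, 1, 32)                 # inplanes = 3, outplanes = 1, init_features = 32
x = rand(Float32, 96, 96, 3, 1)    # WHCN layout; 96 = 6 * 2^4, so pooling round-trips
y = m(x)                           # encoder -> bottleneck -> decoder with skip `cat`s
@assert size(y) == (96, 96, 1, 1)  # one sigmoid output channel per pixel
```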
diff --git a/src/convnets/unet.jl b/src/convnets/unet.jl index ccf49d4d0..cf0261cef 100644 --- a/src/convnets/unet.jl +++ b/src/convnets/unet.jl @@ -10,7 +10,22 @@ function UpConvBlock(in_chs::Int, out_chs::Int, kernel = (2, 2)) norm = BatchNorm(out_chs, relu)) end -struct Unet +""" + UNet(inplanes::Integer = 3, outplanes::Integer = 1, init_features::Integer = 32) + + Create a UNet model + ([reference](https://arxiv.org/abs/1505.04597v1)) + + # Arguments + - `in_channels`: The number of input channels + - `inplanes`: The number of input planes to the network + - `outplanes`: The number of output features + +!!! warning + + `UNet` does not currently support pretrained weights. +""" +struct UNet encoder::Any decoder::Any upconv::Any @@ -18,15 +33,14 @@ struct Unet bottleneck::Any final_conv::Any end +@functor UNet -@functor Unet +function UNet(in_channels::Integer = 3, inplanes::Integer = 32, outplanes::Integer = 1) -function Unet(inplanes::Int = 3, outplanes::Int = 1, init_features::Int = 32) - - features = init_features + features = inplanes encoder_layers = [] - append!(encoder_layers, [unet_block(inplanes, features)]) + append!(encoder_layers, [unet_block(in_channels, features)]) append!(encoder_layers, [unet_block(features * 2^i, features * 2^(i + 1)) for i ∈ 0:2]) encoder = Chain(encoder_layers) @@ -41,10 +55,10 @@ function Unet(inplanes::Int = 3, outplanes::Int = 1, init_features::Int = 32) final_conv = Conv((1, 1), features => outplanes) - Unet(encoder, decoder, upconv, pool, bottleneck, final_conv) + UNet(encoder, decoder, upconv, pool, bottleneck, final_conv) end -function (u::Unet)(x::AbstractArray) +function (u::UNet)(x::AbstractArray) enc_out = [] out = x From ca73586da0da468d244ac768d4aec3b400a04061 Mon Sep 17 00:00:00 2001 From: Anshuman Mishra <51750587+shivance@users.noreply.github.com> Date: Wed, 28 Dec 2022 10:05:04 +0530 Subject: [PATCH 03/29] ran juliaformatter --- docs/make.jl | 42 ++- src/Metalhead.jl | 9 +- src/convnets/alexnet.jl | 2 +- src/convnets/builders/resnet.jl | 2 +- src/convnets/convnext.jl | 2 +- src/convnets/densenet.jl | 2 +- src/convnets/efficientnets/efficientnet.jl | 2 +- src/convnets/efficientnets/efficientnetv2.jl | 2 +- src/convnets/inceptions/googlenet.jl | 2 +- src/convnets/inceptions/inceptionresnetv2.jl | 2 +- src/convnets/inceptions/inceptionv3.jl | 2 +- src/convnets/inceptions/inceptionv4.jl | 2 +- src/convnets/inceptions/xception.jl | 2 +- src/convnets/mobilenets/mnasnet.jl | 2 +- src/convnets/mobilenets/mobilenetv1.jl | 2 +- src/convnets/mobilenets/mobilenetv2.jl | 2 +- src/convnets/mobilenets/mobilenetv3.jl | 2 +- src/convnets/resnets/core.jl | 4 +- src/convnets/resnets/res2net.jl | 4 +- src/convnets/resnets/resnext.jl | 2 +- src/convnets/resnets/seresnet.jl | 4 +- src/convnets/unet.jl | 100 ++--- src/utilities.jl | 6 + test/convnets.jl | 375 +++++++++---------- test/runtests.jl | 19 +- test/vits.jl | 14 +- 26 files changed, 305 insertions(+), 304 deletions(-) diff --git a/docs/make.jl b/docs/make.jl index 72fb486c9..7074db5fd 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,26 +1,34 @@ -using Documenter, Metalhead, Artifacts, LazyArtifacts, Images, OneHotArrays, DataAugmentation, Flux +using Documenter, Metalhead, Artifacts, LazyArtifacts, Images, OneHotArrays, + DataAugmentation, Flux DocMeta.setdocmeta!(Metalhead, :DocTestSetup, :(using Metalhead); recursive = true) -makedocs(modules = [Metalhead, Artifacts, LazyArtifacts, Images, OneHotArrays, DataAugmentation, Flux], +makedocs(; + modules = [ + Metalhead, + Artifacts, + 
LazyArtifacts, + Images, + OneHotArrays, + DataAugmentation, + Flux, + ], sitename = "Metalhead.jl", doctest = false, pages = ["Home" => "index.md", - "Tutorials" => [ - "tutorials/quickstart.md", - ], - "Developer guide" => "contributing.md", - "API reference" => [ - "api/reference.md", - ], - ], - format = Documenter.HTML( - canonical = "https://fluxml.ai/Metalhead.jl/stable/", - # analytics = "UA-36890222-9", - assets = ["assets/flux.css"], - prettyurls = get(ENV, "CI", nothing) == "true"), - ) + "Tutorials" => [ + "tutorials/quickstart.md", + ], + "Developer guide" => "contributing.md", + "API reference" => [ + "api/reference.md", + ], + ], + format = Documenter.HTML(; canonical = "https://fluxml.ai/Metalhead.jl/stable/", + # analytics = "UA-36890222-9", + assets = ["assets/flux.css"], + prettyurls = get(ENV, "CI", nothing) == "true")) -deploydocs(repo = "github.com/FluxML/Metalhead.jl.git", +deploydocs(; repo = "github.com/FluxML/Metalhead.jl.git", target = "build", push_preview = true) diff --git a/src/Metalhead.jl b/src/Metalhead.jl index b67bf1faf..80da39bb8 100644 --- a/src/Metalhead.jl +++ b/src/Metalhead.jl @@ -2,6 +2,7 @@ module Metalhead using Flux using Flux: Zygote, outputsize +using Distributions: Normal using Functors using BSON using Artifacts, LazyArtifacts @@ -10,7 +11,7 @@ using MLUtils using PartialFunctions using Random -import Functors +using Functors: Functors include("utilities.jl") @@ -28,6 +29,8 @@ include("convnets/builders/stages.jl") ## AlexNet and VGG include("convnets/alexnet.jl") include("convnets/vgg.jl") +## Unet +include("convnets/unet.jl") ## ResNets include("convnets/resnets/core.jl") include("convnets/resnets/res2net.jl") @@ -66,7 +69,7 @@ include("vit-based/vit.jl") # Load pretrained weights include("pretrain.jl") -export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19, +export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19, UNet, ResNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152, WideResNet, ResNeXt, SEResNet, SEResNeXt, Res2Net, Res2NeXt, DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201, @@ -76,7 +79,7 @@ export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19, MLPMixer, ResMLP, gMLP, ViT # use Flux._big_show to pretty print large models -for T in (:AlexNet, :VGG, :SqueezeNet, :ResNet, :WideResNet, :ResNeXt, +for T in (:AlexNet, :VGG, :UNet, :SqueezeNet, :ResNet, :WideResNet, :ResNeXt, :SEResNet, :SEResNeXt, :Res2Net, :Res2NeXt, :GoogLeNet, :DenseNet, :Inceptionv3, :Inceptionv4, :InceptionResNetv2, :Xception, :MobileNetv1, :MobileNetv2, :MobileNetv3, :MNASNet, diff --git a/src/convnets/alexnet.jl b/src/convnets/alexnet.jl index fba6749e4..d05919f24 100644 --- a/src/convnets/alexnet.jl +++ b/src/convnets/alexnet.jl @@ -42,7 +42,7 @@ Create a `AlexNet`. - `nclasses`: the number of output classes !!! warning - + `AlexNet` does not currently support pretrained weights. See also [`alexnet`](@ref). diff --git a/src/convnets/builders/resnet.jl b/src/convnets/builders/resnet.jl index 580baaa34..231b9ef30 100644 --- a/src/convnets/builders/resnet.jl +++ b/src/convnets/builders/resnet.jl @@ -5,7 +5,7 @@ Creates a generic ResNet-like model. !!! info - + This is a very generic, flexible but low level function that can be used to create any of the ResNet variants. For a more user friendly function, see [`Metalhead.resnet`](@ref). diff --git a/src/convnets/convnext.jl b/src/convnets/convnext.jl index a38840d0b..2c76d6a30 100644 --- a/src/convnets/convnext.jl +++ b/src/convnets/convnext.jl @@ -97,7 +97,7 @@ Creates a ConvNeXt model. 
- `nclasses`: number of output classes !!! warning - + `ConvNeXt` does not currently support pretrained weights. See also [`Metalhead.convnext`](@ref). diff --git a/src/convnets/densenet.jl b/src/convnets/densenet.jl index 4799e22a7..5b340a8ca 100644 --- a/src/convnets/densenet.jl +++ b/src/convnets/densenet.jl @@ -122,7 +122,7 @@ Create a DenseNet model with specified configuration. Currently supported values Set `pretrain = true` to load the model with pre-trained weights for ImageNet. !!! warning - + `DenseNet` does not currently support pretrained weights. See also [`Metalhead.densenet`](@ref). diff --git a/src/convnets/efficientnets/efficientnet.jl b/src/convnets/efficientnets/efficientnet.jl index 2657a3884..53f5c3569 100644 --- a/src/convnets/efficientnets/efficientnet.jl +++ b/src/convnets/efficientnets/efficientnet.jl @@ -73,7 +73,7 @@ Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)). - `nclasses`: number of output classes. !!! warning - + EfficientNet does not currently support pretrained weights. See also [`Metalhead.efficientnet`](@ref). diff --git a/src/convnets/efficientnets/efficientnetv2.jl b/src/convnets/efficientnets/efficientnetv2.jl index dab159e68..6b89eb0ea 100644 --- a/src/convnets/efficientnets/efficientnetv2.jl +++ b/src/convnets/efficientnets/efficientnetv2.jl @@ -76,7 +76,7 @@ Create an EfficientNetv2 model ([reference](https://arxiv.org/abs/2104.00298)). - `nclasses`: number of output classes !!! warning - + `EfficientNetv2` does not currently support pretrained weights. See also [`efficientnet`](#). diff --git a/src/convnets/inceptions/googlenet.jl b/src/convnets/inceptions/googlenet.jl index a84af56c4..afa8fa52b 100644 --- a/src/convnets/inceptions/googlenet.jl +++ b/src/convnets/inceptions/googlenet.jl @@ -86,7 +86,7 @@ Create an Inception-v1 model (commonly referred to as `GoogLeNet`) - `bias`: set to `true` to use bias in the convolution layers !!! warning - + `GoogLeNet` does not currently support pretrained weights. See also [`Metalhead.googlenet`](@ref). diff --git a/src/convnets/inceptions/inceptionresnetv2.jl b/src/convnets/inceptions/inceptionresnetv2.jl index 98d686062..f95e9d062 100644 --- a/src/convnets/inceptions/inceptionresnetv2.jl +++ b/src/convnets/inceptions/inceptionresnetv2.jl @@ -109,7 +109,7 @@ Creates an InceptionResNetv2 model. - `nclasses`: the number of output classes. !!! warning - + `InceptionResNetv2` does not currently support pretrained weights. See also [`Metalhead.inceptionresnetv2`](@ref). diff --git a/src/convnets/inceptions/inceptionv3.jl b/src/convnets/inceptions/inceptionv3.jl index 41d7ae18e..577a90dd2 100644 --- a/src/convnets/inceptions/inceptionv3.jl +++ b/src/convnets/inceptions/inceptionv3.jl @@ -170,7 +170,7 @@ Create an Inception-v3 model ([reference](https://arxiv.org/abs/1512.00567v3)). - `nclasses`: the number of output classes !!! warning - + `Inceptionv3` does not currently support pretrained weights. See also [`Metalhead.inceptionv3`](@ref). diff --git a/src/convnets/inceptions/inceptionv4.jl b/src/convnets/inceptions/inceptionv4.jl index 964afc362..54b9b7bb1 100644 --- a/src/convnets/inceptions/inceptionv4.jl +++ b/src/convnets/inceptions/inceptionv4.jl @@ -124,7 +124,7 @@ Creates an Inceptionv4 model. - `nclasses`: the number of output classes. !!! warning - + `Inceptionv4` does not currently support pretrained weights. See also [`Metalhead.inceptionv4`](@ref). 
diff --git a/src/convnets/inceptions/xception.jl b/src/convnets/inceptions/xception.jl index 9dfd73f86..52bd221a6 100644 --- a/src/convnets/inceptions/xception.jl +++ b/src/convnets/inceptions/xception.jl @@ -80,7 +80,7 @@ Creates an Xception model. - `nclasses`: the number of output classes. !!! warning - + `Xception` does not currently support pretrained weights. See also [`Metalhead.xception`](@ref). diff --git a/src/convnets/mobilenets/mnasnet.jl b/src/convnets/mobilenets/mnasnet.jl index 98cd9d759..d0be90a8c 100644 --- a/src/convnets/mobilenets/mnasnet.jl +++ b/src/convnets/mobilenets/mnasnet.jl @@ -86,7 +86,7 @@ Creates a MNASNet model with the specified configuration. - `nclasses`: The number of output classes !!! warning - + `MNASNet` does not currently support pretrained weights. See also [`Metalhead.mnasnet`](@ref). diff --git a/src/convnets/mobilenets/mobilenetv1.jl b/src/convnets/mobilenets/mobilenetv1.jl index ca17dc4ac..881a1e533 100644 --- a/src/convnets/mobilenets/mobilenetv1.jl +++ b/src/convnets/mobilenets/mobilenetv1.jl @@ -52,7 +52,7 @@ Create a MobileNetv1 model with the baseline configuration - `nclasses`: The number of output classes !!! warning - + `MobileNetv1` does not currently support pretrained weights. See also [`Metalhead.mobilenetv1`](@ref). diff --git a/src/convnets/mobilenets/mobilenetv2.jl b/src/convnets/mobilenets/mobilenetv2.jl index 6d0973130..c028fc50a 100644 --- a/src/convnets/mobilenets/mobilenetv2.jl +++ b/src/convnets/mobilenets/mobilenetv2.jl @@ -58,7 +58,7 @@ Create a MobileNetv2 model with the specified configuration. - `nclasses`: The number of output classes !!! warning - + `MobileNetv2` does not currently support pretrained weights. See also [`Metalhead.mobilenetv2`](@ref). diff --git a/src/convnets/mobilenets/mobilenetv3.jl b/src/convnets/mobilenets/mobilenetv3.jl index 07e4501eb..2fbe71532 100644 --- a/src/convnets/mobilenets/mobilenetv3.jl +++ b/src/convnets/mobilenets/mobilenetv3.jl @@ -78,7 +78,7 @@ Set `pretrain = true` to load the model with pre-trained weights for ImageNet. - `nclasses`: the number of output classes !!! warning - + `MobileNetv3` does not currently support pretrained weights. See also [`Metalhead.mobilenetv3`](@ref). diff --git a/src/convnets/resnets/core.jl b/src/convnets/resnets/core.jl index c5b2f1c7d..829d15f39 100644 --- a/src/convnets/resnets/core.jl +++ b/src/convnets/resnets/core.jl @@ -191,7 +191,7 @@ If `outplanes > inplanes`, it maps the input to `outplanes` channels using a 1x1 layer and zero padding. !!! warning - + This does not currently support the scenario where `inplanes > outplanes`. # Arguments @@ -237,7 +237,7 @@ on how to use this function. # Arguments - `stem_type`: The type of stem to be built. One of `[:default, :deep, :deep_tiered]`. - + + `:default`: Builds a stem based on the default ResNet stem, which consists of a single 7x7 convolution with stride 2 and a normalisation layer followed by a 3x3 max pooling layer with stride 2. diff --git a/src/convnets/resnets/res2net.jl b/src/convnets/resnets/res2net.jl index 33b9fb961..43518addc 100644 --- a/src/convnets/resnets/res2net.jl +++ b/src/convnets/resnets/res2net.jl @@ -17,7 +17,7 @@ Creates a Res2Net model with the specified depth, scale, and base width. - `nclasses`: the number of output classes !!! warning - + `Res2Net` does not currently support pretrained weights. Advanced users who want more configuration options will be better served by using [`resnet`](@ref). 
@@ -64,7 +64,7 @@ Creates a Res2NeXt model with the specified depth, scale, base width and cardina - `nclasses`: the number of output classes !!! warning - + `Res2NeXt` does not currently support pretrained weights. Advanced users who want more configuration options will be better served by using [`resnet`](@ref). diff --git a/src/convnets/resnets/resnext.jl b/src/convnets/resnets/resnext.jl index 6d65708d0..5712a7079 100644 --- a/src/convnets/resnets/resnext.jl +++ b/src/convnets/resnets/resnext.jl @@ -11,7 +11,7 @@ Creates a ResNeXt model with the specified depth, cardinality, and base width. - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet. Supported configurations are: - + + depth 50, cardinality of 32 and base width of 4. + depth 101, cardinality of 32 and base width of 8. + depth 101, cardinality of 64 and base width of 4. diff --git a/src/convnets/resnets/seresnet.jl b/src/convnets/resnets/seresnet.jl index 44e32083d..0eb7e7128 100644 --- a/src/convnets/resnets/seresnet.jl +++ b/src/convnets/resnets/seresnet.jl @@ -12,7 +12,7 @@ Creates a SEResNet model with the specified depth. - `nclasses`: the number of output classes !!! warning - + `SEResNet` does not currently support pretrained weights. Advanced users who want more configuration options will be better served by using [`resnet`](@ref). @@ -55,7 +55,7 @@ Creates a SEResNeXt model with the specified depth, cardinality, and base width. - `nclasses`: the number of output classes !!! warning - + `SEResNeXt` does not currently support pretrained weights. Advanced users who want more configuration options will be better served by using [`resnet`](@ref). diff --git a/src/convnets/unet.jl b/src/convnets/unet.jl index cf0261cef..ad412531f 100644 --- a/src/convnets/unet.jl +++ b/src/convnets/unet.jl @@ -1,81 +1,83 @@ function unet_block(in_chs::Int, out_chs::Int, kernel = (3, 3)) - Chain(conv1 = Conv(kernel, in_chs => out_chs, pad = (1, 1); init = _random_normal), - norm1 = BatchNorm(out_chs, relu), - conv2 = Conv(kernel, out_chs => out_chs, pad = (1, 1); init = _random_normal), - norm2 = BatchNorm(out_chs, relu)) + return Chain(; + conv1 = Conv(kernel, in_chs => out_chs; pad = (1, 1), + init = _random_normal), + norm1 = BatchNorm(out_chs, relu), + conv2 = Conv(kernel, out_chs => out_chs; pad = (1, 1), + init = _random_normal), + norm2 = BatchNorm(out_chs, relu)) end -function UpConvBlock(in_chs::Int, out_chs::Int, kernel = (2, 2)) - Chain(convtranspose = ConvTranspose(kernel, in_chs => out_chs, stride = (2, 2); init = _random_normal), - norm = BatchNorm(out_chs, relu)) +function upconv_block(in_chs::Int, out_chs::Int, kernel = (2, 2)) + return ConvTranspose(kernel, in_chs => out_chs; stride = (2, 2), init = _random_normal) end """ - UNet(inplanes::Integer = 3, outplanes::Integer = 1, init_features::Integer = 32) + UNet(inplanes::Integer = 3, outplanes::Integer = 1, init_features::Integer = 32) - Create a UNet model - ([reference](https://arxiv.org/abs/1505.04597v1)) + Create a UNet model + ([reference](https://arxiv.org/abs/1505.04597v1)) - # Arguments - - `in_channels`: The number of input channels - - `inplanes`: The number of input planes to the network - - `outplanes`: The number of output features + # Arguments + - `in_channels`: The number of input channels + - `inplanes`: The number of input planes to the network + - `outplanes`: The number of output features !!! warning - - `UNet` does not currently support pretrained weights. + + `UNet` does not currently support pretrained weights. 
""" struct UNet - encoder::Any - decoder::Any - upconv::Any - pool::Any - bottleneck::Any - final_conv::Any + encoder::Any + decoder::Any + upconv::Any + pool::Any + bottleneck::Any + final_conv::Any end @functor UNet -function UNet(in_channels::Integer = 3, inplanes::Integer = 32, outplanes::Integer = 1) - - features = inplanes +function UNet(in_channels::Integer = 3, inplanes::Integer = 32, + outplanes::Integer = inplanes) + features = inplanes - encoder_layers = [] - append!(encoder_layers, [unet_block(in_channels, features)]) - append!(encoder_layers, [unet_block(features * 2^i, features * 2^(i + 1)) for i ∈ 0:2]) + encoder_layers = [] + append!(encoder_layers, [unet_block(in_channels, features)]) + append!(encoder_layers, [unet_block(features * 2^i, features * 2^(i + 1)) for i in 0:2]) - encoder = Chain(encoder_layers) + encoder = Chain(encoder_layers) - decoder = Chain([unet_block(features * 2^(i + 1), features * 2^i) for i ∈ 0:3]) + decoder = Chain([unet_block(features * 2^(i + 1), features * 2^i) for i in 0:3]) - pool = Chain([MaxPool((2, 2), stride = (2, 2)) for _ ∈ 1:4]) + pool = Chain([MaxPool((2, 2); stride = (2, 2)) for _ in 1:4]) - upconv = Chain([UpConvBlock(features * 2^(i + 1), features * 2^i) for i ∈ 3:-1:0]) + upconv = Chain([upconv_block(features * 2^(i + 1), features * 2^i) for i in 3:-1:0]) - bottleneck = _block(features * 8, features * 16) + bottleneck = unet_block(features * 8, features * 16) - final_conv = Conv((1, 1), features => outplanes) + final_conv = Conv((1, 1), features => outplanes) - UNet(encoder, decoder, upconv, pool, bottleneck, final_conv) + return UNet(encoder, decoder, upconv, pool, bottleneck, final_conv) end function (u::UNet)(x::AbstractArray) - enc_out = [] + enc_out = [] - out = x - for i ∈ 1:4 - out = u.encoder[i](out) - push!(enc_out, out) + out = x + for i in 1:4 + out = u.encoder[i](out) + push!(enc_out, out) - out = u.pool[i](out) - end + out = u.pool[i](out) + end - out = u.bottleneck(out) + out = u.bottleneck(out) - for i ∈ 4:-1:1 - out = u.upconv[5-i](out) - out = cat(out, enc_out[i], dims = 3) - out = u.decoder[i](out) - end + for i in 4:-1:1 + out = u.upconv[5 - i](out) + out = cat(out, enc_out[i]; dims = 3) + out = u.decoder[i](out) + end - σ(u.final_conv(out)) + return σ(u.final_conv(out)) end diff --git a/src/utilities.jl b/src/utilities.jl index 9f54107fd..0c35758e5 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -79,3 +79,9 @@ linear_scheduler(drop_prob::Nothing; depth::Integer) = fill(drop_prob, depth) function _checkconfig(config, configs) @assert config in configs "Invalid configuration. Must be one of $(sort(collect(configs)))." end + +# Utility function to be used for generating random normal distribution used for +# initializing layer parameters +function _random_normal(shape...) 
+ return Float64.(rand(Normal(0.0f0, 0.02f0), shape...)) +end diff --git a/test/convnets.jl b/test/convnets.jl index af15ff413..13806288b 100644 --- a/test/convnets.jl +++ b/test/convnets.jl @@ -6,19 +6,19 @@ _gc() end -@testset "VGG" begin - @testset "VGG($sz, batchnorm=$bn)" for sz in [11, 13, 16, 19], bn in [true, false] - m = VGG(sz, batchnorm = bn) - @test size(m(x_224)) == (1000, 1) - if (VGG, sz, bn) in PRETRAINED_MODELS - @test acctest(VGG(sz, batchnorm = bn, pretrain = true)) - else - @test_throws ArgumentError VGG(sz, batchnorm = bn, pretrain = true) - end - @test gradtest(m, x_224) - _gc() +@testset "VGG" begin @testset "VGG($sz, batchnorm=$bn)" for sz in [11, 13, 16, 19], + bn in [true, false] + + m = VGG(sz; batchnorm = bn) + @test size(m(x_224)) == (1000, 1) + if (VGG, sz, bn) in PRETRAINED_MODELS + @test acctest(VGG(sz; batchnorm = bn, pretrain = true)) + else + @test_throws ArgumentError VGG(sz, batchnorm = bn, pretrain = true) end -end + @test gradtest(m, x_224) + _gc() +end end @testset "ResNet" begin # Tests for pretrained ResNets @@ -26,100 +26,94 @@ end m = ResNet(sz) @test size(m(x_224)) == (1000, 1) if (ResNet, sz) in PRETRAINED_MODELS - @test acctest(ResNet(sz, pretrain = true)) + @test acctest(ResNet(sz; pretrain = true)) else @test_throws ArgumentError ResNet(sz, pretrain = true) end end - @testset "resnet" begin - @testset for block_fn in [Metalhead.basicblock, Metalhead.bottleneck] - layer_list = [ - [2, 2, 2, 2], - [3, 4, 6, 3], - [3, 4, 23, 3], - [3, 8, 36, 3] + @testset "resnet" begin @testset for block_fn in [ + Metalhead.basicblock, + Metalhead.bottleneck, + ] + layer_list = [ + [2, 2, 2, 2], + [3, 4, 6, 3], + [3, 4, 23, 3], + [3, 8, 36, 3], + ] + @testset for layers in layer_list + drop_list = [ + (dropout_prob = 0.1, stochastic_depth_prob = 0.1, dropblock_prob = 0.1), + (dropout_prob = 0.5, stochastic_depth_prob = 0.5, dropblock_prob = 0.5), + (dropout_prob = 0.8, stochastic_depth_prob = 0.8, dropblock_prob = 0.8), ] - @testset for layers in layer_list - drop_list = [ - (dropout_prob = 0.1, stochastic_depth_prob = 0.1, dropblock_prob = 0.1), - (dropout_prob = 0.5, stochastic_depth_prob = 0.5, dropblock_prob = 0.5), - (dropout_prob = 0.8, stochastic_depth_prob = 0.8, dropblock_prob = 0.8), - ] - @testset for drop_probs in drop_list - m = Metalhead.resnet(block_fn, layers; drop_probs...) - @test size(m(x_224)) == (1000, 1) - @test gradtest(m, x_224) - _gc() - end + @testset for drop_probs in drop_list + m = Metalhead.resnet(block_fn, layers; drop_probs...) 
+ @test size(m(x_224)) == (1000, 1) + @test gradtest(m, x_224) + _gc() end end - end + end end - @testset "WideResNet" begin - @testset "WideResNet($sz)" for sz in [50, 101] - m = WideResNet(sz) - @test size(m(x_224)) == (1000, 1) - @test gradtest(m, x_224) - _gc() - if (WideResNet, sz) in PRETRAINED_MODELS - @test acctest(ResNet(sz, pretrain = true)) - else - @test_throws ArgumentError WideResNet(sz, pretrain = true) - end + @testset "WideResNet" begin @testset "WideResNet($sz)" for sz in [50, 101] + m = WideResNet(sz) + @test size(m(x_224)) == (1000, 1) + @test gradtest(m, x_224) + _gc() + if (WideResNet, sz) in PRETRAINED_MODELS + @test acctest(ResNet(sz; pretrain = true)) + else + @test_throws ArgumentError WideResNet(sz, pretrain = true) end - end + end end end -@testset "ResNeXt" begin - @testset for depth in [50, 101, 152] - @testset for cardinality in [32, 64] - @testset for base_width in [4, 8] - m = ResNeXt(depth; cardinality, base_width) - @test size(m(x_224)) == (1000, 1) - if (ResNeXt, depth, cardinality, base_width) in PRETRAINED_MODELS - @test acctest(ResNeXt(depth; cardinality, base_width, pretrain = true)) - else - @test_throws ArgumentError ResNeXt(depth; cardinality, base_width, pretrain = true) - end - @test gradtest(m, x_224) - _gc() +@testset "ResNeXt" begin @testset for depth in [50, 101, 152] + @testset for cardinality in [32, 64] + @testset for base_width in [4, 8] + m = ResNeXt(depth; cardinality, base_width) + @test size(m(x_224)) == (1000, 1) + if (ResNeXt, depth, cardinality, base_width) in PRETRAINED_MODELS + @test acctest(ResNeXt(depth; cardinality, base_width, pretrain = true)) + else + @test_throws ArgumentError ResNeXt(depth; cardinality, base_width, + pretrain = true) end + @test gradtest(m, x_224) + _gc() end end -end +end end -@testset "SEResNet" begin - @testset for depth in [18, 34, 50, 101, 152] - m = SEResNet(depth) - @test size(m(x_224)) == (1000, 1) - if (SEResNet, depth) in PRETRAINED_MODELS - @test acctest(SEResNet(depth, pretrain = true)) - else - @test_throws ArgumentError SEResNet(depth, pretrain = true) - end - @test gradtest(m, x_224) - _gc() +@testset "SEResNet" begin @testset for depth in [18, 34, 50, 101, 152] + m = SEResNet(depth) + @test size(m(x_224)) == (1000, 1) + if (SEResNet, depth) in PRETRAINED_MODELS + @test acctest(SEResNet(depth; pretrain = true)) + else + @test_throws ArgumentError SEResNet(depth, pretrain = true) end -end + @test gradtest(m, x_224) + _gc() +end end -@testset "SEResNeXt" begin - @testset for depth in [50, 101, 152] - @testset for cardinality in [32, 64] - @testset for base_width in [4, 8] - m = SEResNeXt(depth; cardinality, base_width) - @test size(m(x_224)) == (1000, 1) - if (SEResNeXt, depth, cardinality, base_width) in PRETRAINED_MODELS - @test acctest(SEResNeXt(depth, pretrain = true)) - else - @test_throws ArgumentError SEResNeXt(depth, pretrain = true) - end - @test gradtest(m, x_224) - _gc() +@testset "SEResNeXt" begin @testset for depth in [50, 101, 152] + @testset for cardinality in [32, 64] + @testset for base_width in [4, 8] + m = SEResNeXt(depth; cardinality, base_width) + @test size(m(x_224)) == (1000, 1) + if (SEResNeXt, depth, cardinality, base_width) in PRETRAINED_MODELS + @test acctest(SEResNeXt(depth; pretrain = true)) + else + @test_throws ArgumentError SEResNeXt(depth, pretrain = true) end + @test gradtest(m, x_224) + _gc() end end -end +end end @testset "Res2Net" begin @testset for (base_width, scale) in [(26, 4), (48, 2), (14, 8), (26, 6), (26, 8)] @@ -146,64 +140,63 @@ end end end 
-@testset "Res2NeXt" begin - @testset for depth in [50, 101] - m = Res2NeXt(depth) - @test size(m(x_224)) == (1000, 1) - if (Res2NeXt, depth) in PRETRAINED_MODELS - @test acctest(Res2NeXt(depth, pretrain = true)) - else - @test_throws ArgumentError Res2NeXt(depth, pretrain = true) - end - @test gradtest(m, x_224) - _gc() +@testset "Res2NeXt" begin @testset for depth in [50, 101] + m = Res2NeXt(depth) + @test size(m(x_224)) == (1000, 1) + if (Res2NeXt, depth) in PRETRAINED_MODELS + @test acctest(Res2NeXt(depth; pretrain = true)) + else + @test_throws ArgumentError Res2NeXt(depth, pretrain = true) end -end + @test gradtest(m, x_224) + _gc() +end end -@testset "EfficientNet" begin - @testset "EfficientNet($config)" for config in [:b0, :b1, :b2, :b3, :b4, :b5,] #:b6, :b7, :b8] - # preferred image resolution scaling - r = Metalhead.EFFICIENTNET_GLOBAL_CONFIGS[config][1] - x = rand(Float32, r, r, 3, 1) - m = EfficientNet(config) - @test size(m(x)) == (1000, 1) - if (EfficientNet, config) in PRETRAINED_MODELS - @test acctest(EfficientNet(config, pretrain = true)) - else - @test_throws ArgumentError EfficientNet(config, pretrain = true) - end - @test gradtest(m, x) - _gc() +@testset "EfficientNet" begin @testset "EfficientNet($config)" for config in [ + :b0, + :b1, + :b2, + :b3, + :b4, + :b5, +] #:b6, :b7, :b8] + # preferred image resolution scaling + r = Metalhead.EFFICIENTNET_GLOBAL_CONFIGS[config][1] + x = rand(Float32, r, r, 3, 1) + m = EfficientNet(config) + @test size(m(x)) == (1000, 1) + if (EfficientNet, config) in PRETRAINED_MODELS + @test acctest(EfficientNet(config; pretrain = true)) + else + @test_throws ArgumentError EfficientNet(config, pretrain = true) end -end + @test gradtest(m, x) + _gc() +end end -@testset "EfficientNetv2" begin - @testset for config in [:small, :medium, :large] # :xlarge] - m = EfficientNetv2(config) - @test size(m(x_224)) == (1000, 1) - if (EfficientNetv2, config) in PRETRAINED_MODELS - @test acctest(EfficientNetv2(config, pretrain = true)) - else - @test_throws ArgumentError EfficientNetv2(config, pretrain = true) - end - @test gradtest(m, x_224) - _gc() +@testset "EfficientNetv2" begin @testset for config in [:small, :medium, :large] # :xlarge] + m = EfficientNetv2(config) + @test size(m(x_224)) == (1000, 1) + if (EfficientNetv2, config) in PRETRAINED_MODELS + @test acctest(EfficientNetv2(config; pretrain = true)) + else + @test_throws ArgumentError EfficientNetv2(config, pretrain = true) end -end + @test gradtest(m, x_224) + _gc() +end end -@testset "GoogLeNet" begin - @testset for bn in [true, false] - m = GoogLeNet(batchnorm = bn) - @test size(m(x_224)) == (1000, 1) - if (GoogLeNet, bn) in PRETRAINED_MODELS - @test acctest(GoogLeNet(batchnorm = bn, pretrain = true)) - else - @test_throws ArgumentError GoogLeNet(batchnorm = bn, pretrain = true) - end - @test gradtest(m, x_224) - _gc() +@testset "GoogLeNet" begin @testset for bn in [true, false] + m = GoogLeNet(; batchnorm = bn) + @test size(m(x_224)) == (1000, 1) + if (GoogLeNet, bn) in PRETRAINED_MODELS + @test acctest(GoogLeNet(; batchnorm = bn, pretrain = true)) + else + @test_throws ArgumentError GoogLeNet(batchnorm = bn, pretrain = true) end -end + @test gradtest(m, x_224) + _gc() +end end @testset "Inception" begin x_299 = rand(Float32, 299, 299, 3, 2) @@ -211,7 +204,7 @@ end m = Inceptionv3() @test size(m(x_299)) == (1000, 2) if Inceptionv3 in PRETRAINED_MODELS - @test acctest(Inceptionv3(pretrain = true)) + @test acctest(Inceptionv3(; pretrain = true)) else @test_throws ArgumentError 
Inceptionv3(pretrain = true) end @@ -222,7 +215,7 @@ end m = Inceptionv4() @test size(m(x_299)) == (1000, 2) if Inceptionv4 in PRETRAINED_MODELS - @test acctest(Inceptionv4(pretrain = true)) + @test acctest(Inceptionv4(; pretrain = true)) else @test_throws ArgumentError Inceptionv4(pretrain = true) end @@ -233,7 +226,7 @@ end m = InceptionResNetv2() @test size(m(x_299)) == (1000, 2) if InceptionResNetv2 in PRETRAINED_MODELS - @test acctest(InceptionResNetv2(pretrain = true)) + @test acctest(InceptionResNetv2(; pretrain = true)) else @test_throws ArgumentError InceptionResNetv2(pretrain = true) end @@ -244,7 +237,7 @@ end m = Xception() @test size(m(x_299)) == (1000, 2) if Xception in PRETRAINED_MODELS - @test acctest(Xception(pretrain = true)) + @test acctest(Xception(; pretrain = true)) else @test_throws ArgumentError Xception(pretrain = true) end @@ -257,7 +250,7 @@ end m = SqueezeNet() @test size(m(x_224)) == (1000, 1) if SqueezeNet in PRETRAINED_MODELS - @test acctest(SqueezeNet(pretrain = true)) + @test acctest(SqueezeNet(; pretrain = true)) else @test_throws ArgumentError SqueezeNet(pretrain = true) end @@ -265,26 +258,24 @@ end _gc() end -@testset "DenseNet" begin - @testset for sz in [121, 161, 169, 201] - m = DenseNet(sz) - @test size(m(x_224)) == (1000, 1) - if (DenseNet, sz) in PRETRAINED_MODELS - @test acctest(DenseNet(sz, pretrain = true)) - else - @test_throws ArgumentError DenseNet(sz, pretrain = true) - end - @test gradtest(m, x_224) - _gc() +@testset "DenseNet" begin @testset for sz in [121, 161, 169, 201] + m = DenseNet(sz) + @test size(m(x_224)) == (1000, 1) + if (DenseNet, sz) in PRETRAINED_MODELS + @test acctest(DenseNet(sz; pretrain = true)) + else + @test_throws ArgumentError DenseNet(sz, pretrain = true) end -end + @test gradtest(m, x_224) + _gc() +end end @testset "MobileNets (width = $width_mult)" for width_mult in [0.5, 0.75, 1, 1.3] @testset "MobileNetv1" begin m = MobileNetv1(width_mult) @test size(m(x_224)) == (1000, 1) if (MobileNetv1, width_mult) in PRETRAINED_MODELS - @test acctest(MobileNetv1(pretrain = true)) + @test acctest(MobileNetv1(; pretrain = true)) else @test_throws ArgumentError MobileNetv1(pretrain = true) end @@ -295,55 +286,53 @@ end m = MobileNetv2(width_mult) @test size(m(x_224)) == (1000, 1) if (MobileNetv2, width_mult) in PRETRAINED_MODELS - @test acctest(MobileNetv2(pretrain = true)) + @test acctest(MobileNetv2(; pretrain = true)) else @test_throws ArgumentError MobileNetv2(pretrain = true) end @test gradtest(m, x_224) end _gc() - @testset "MobileNetv3" verbose = true begin - @testset for config in [:small, :large] - m = MobileNetv3(config; width_mult) - @test size(m(x_224)) == (1000, 1) - if (MobileNetv3, config, width_mult) in PRETRAINED_MODELS - @test acctest(MobileNetv3(config; pretrain = true)) - else - @test_throws ArgumentError MobileNetv3(config; pretrain = true) - end - @test gradtest(m, x_224) - _gc() - end - end - @testset "MNASNet" verbose = true begin - @testset for config in [:A1, :B1] - m = MNASNet(config; width_mult) - @test size(m(x_224)) == (1000, 1) - if (MNASNet, config, width_mult) in PRETRAINED_MODELS - @test acctest(MNASNet(config; pretrain = true)) - else - @test_throws ArgumentError MNASNet(config; pretrain = true) - end - @test gradtest(m, x_224) - _gc() - end - end -end - -@testset "ConvNeXt" verbose = true begin - @testset for config in [:small, :base, :large, :tiny, :xlarge] - m = ConvNeXt(config) + @testset "MobileNetv3" verbose=true begin @testset for config in [:small, :large] + m = MobileNetv3(config; 
width_mult) @test size(m(x_224)) == (1000, 1) + if (MobileNetv3, config, width_mult) in PRETRAINED_MODELS + @test acctest(MobileNetv3(config; pretrain = true)) + else + @test_throws ArgumentError MobileNetv3(config; pretrain = true) + end @test gradtest(m, x_224) _gc() - end -end - -@testset "ConvMixer" verbose = true begin - @testset for config in [:small, :base, :large] - m = ConvMixer(config) + end end + @testset "MNASNet" verbose=true begin @testset for config in [:A1, :B1] + m = MNASNet(config; width_mult) @test size(m(x_224)) == (1000, 1) + if (MNASNet, config, width_mult) in PRETRAINED_MODELS + @test acctest(MNASNet(config; pretrain = true)) + else + @test_throws ArgumentError MNASNet(config; pretrain = true) + end @test gradtest(m, x_224) _gc() - end + end end end + +@testset "ConvNeXt" verbose=true begin @testset for config in [ + :small, + :base, + :large, + :tiny, + :xlarge, +] + m = ConvNeXt(config) + @test size(m(x_224)) == (1000, 1) + @test gradtest(m, x_224) + _gc() +end end + +@testset "ConvMixer" verbose=true begin @testset for config in [:small, :base, :large] + m = ConvMixer(config) + @test size(m(x_224)) == (1000, 1) + @test gradtest(m, x_224) + _gc() +end end diff --git a/test/runtests.jl b/test/runtests.jl index cd5c8ab99..d5af101e7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -18,12 +18,12 @@ const PRETRAINED_MODELS = [ (WideResNet, 101), (ResNeXt, 50, 32, 4), (ResNeXt, 101, 64, 4), - (ResNeXt, 101, 32, 8) + (ResNeXt, 101, 32, 8), ] function _gc() GC.safepoint() - GC.gc(true) + return GC.gc(true) end function gradtest(model, input) @@ -43,7 +43,8 @@ end const TEST_PATH = download("https://cdn.pixabay.com/photo/2015/05/07/11/02/guitar-756326_960_720.jpg") const TEST_IMG = imresize(Images.load(TEST_PATH), (224, 224)) # CHW -> WHC -const TEST_X = permutedims(convert(Array{Float32}, channelview(TEST_IMG)), (3, 2, 1)) |> normalize_imagenet +const TEST_X = permutedims(convert(Array{Float32}, channelview(TEST_IMG)), (3, 2, 1)) |> + normalize_imagenet # ImageNet labels const TEST_LBLS = readlines(download("https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt")) @@ -58,16 +59,10 @@ x_224 = rand(Float32, 224, 224, 3, 1) x_256 = rand(Float32, 256, 256, 3, 1) # CNN tests -@testset verbose = true "ConvNets" begin - include("convnets.jl") -end +@testset verbose=true "ConvNets" begin include("convnets.jl") end # Mixer tests -@testset verbose = true "Mixers" begin - include("mixers.jl") -end +@testset verbose=true "Mixers" begin include("mixers.jl") end # ViT tests -@testset verbose = true "ViTs" begin - include("vits.jl") -end +@testset verbose=true "ViTs" begin include("vits.jl") end diff --git a/test/vits.jl b/test/vits.jl index 7561cfdb5..3c737210c 100644 --- a/test/vits.jl +++ b/test/vits.jl @@ -1,8 +1,6 @@ -@testset "ViT" begin - for config in [:tiny, :small, :base, :large, :huge] # :giant, :gigantic] - m = ViT(config) - @test size(m(x_256)) == (1000, 1) - @test gradtest(m, x_256) - _gc() - end -end +@testset "ViT" begin for config in [:tiny, :small, :base, :large, :huge] # :giant, :gigantic] + m = ViT(config) + @test size(m(x_256)) == (1000, 1) + @test gradtest(m, x_256) + _gc() +end end From 552a8fd13e4e5794cee7f710b4c788fdd3ed9b64 Mon Sep 17 00:00:00 2001 From: Anshuman Mishra <51750587+shivance@users.noreply.github.com> Date: Sun, 1 Jan 2023 21:21:13 +0530 Subject: [PATCH 04/29] removed custom forward pass using Parallel --- src/convnets/unet.jl | 101 +++++++++++++++++++++---------------------- 1 file changed, 50 insertions(+), 51 
deletions(-) diff --git a/src/convnets/unet.jl b/src/convnets/unet.jl index ad412531f..cff8b9ea1 100644 --- a/src/convnets/unet.jl +++ b/src/convnets/unet.jl @@ -1,15 +1,56 @@ function unet_block(in_chs::Int, out_chs::Int, kernel = (3, 3)) return Chain(; - conv1 = Conv(kernel, in_chs => out_chs; pad = (1, 1), - init = _random_normal), + conv1 = Conv(kernel, in_chs => out_chs; pad = (1, 1)), norm1 = BatchNorm(out_chs, relu), - conv2 = Conv(kernel, out_chs => out_chs; pad = (1, 1), - init = _random_normal), + conv2 = Conv(kernel, out_chs => out_chs; pad = (1, 1)), norm2 = BatchNorm(out_chs, relu)) end function upconv_block(in_chs::Int, out_chs::Int, kernel = (2, 2)) - return ConvTranspose(kernel, in_chs => out_chs; stride = (2, 2), init = _random_normal) + return ConvTranspose(kernel, in_chs => out_chs; stride = (2, 2)) +end + +function cat_fn(layers...) + return cat(layers...; dims = 3) +end + +function unet(in_channels::Integer = 3, out_channels::Integer = in_channels, + features::Integer = 32) + encoder_conv_layers = [] + append!(encoder_conv_layers, + [unet_block(in_channels, features)]) + + append!(encoder_conv_layers, + [unet_block(features * 2^i, features * 2^(i + 1)) for i in 0:2]) + + encoder_conv = Chain(encoder_conv_layers) + encoder_pool = [Chain(encoder_conv[i], MaxPool((2, 2); stride = (2, 2))) for i in 1:4] + + bottleneck = unet_block(features * 8, features * 16) + layers = Chain(encoder_conv, bottleneck) + + upconv = Chain([upconv_block(features * 2^(i + 1), features * 2^i) + for i in 0:3]) + + concat_layer = Chain([Parallel(cat_fn, + encoder_pool[i], + upconv[i]) + for i in 1:4]) + + decoder_layer = Chain([unet_block(features * 2^(i + 1), features * 2^i) for i in 3:-1:0]) + + layers = Chain(layers, decoder_layer) + + decoder = Chain([Chain([ + concat_layer[i], + decoder_layer[5 - i]]) + for i in 4:-1:1]) + + final_conv = Conv((1, 1), features => out_channels, σ) + + decoder = Chain(decoder, final_conv) + + return layers end """ @@ -28,56 +69,14 @@ end `UNet` does not currently support pretrained weights. 
""" struct UNet - encoder::Any - decoder::Any - upconv::Any - pool::Any - bottleneck::Any - final_conv::Any + layers::Any end @functor UNet function UNet(in_channels::Integer = 3, inplanes::Integer = 32, outplanes::Integer = inplanes) - features = inplanes - - encoder_layers = [] - append!(encoder_layers, [unet_block(in_channels, features)]) - append!(encoder_layers, [unet_block(features * 2^i, features * 2^(i + 1)) for i in 0:2]) - - encoder = Chain(encoder_layers) - - decoder = Chain([unet_block(features * 2^(i + 1), features * 2^i) for i in 0:3]) - - pool = Chain([MaxPool((2, 2); stride = (2, 2)) for _ in 1:4]) - - upconv = Chain([upconv_block(features * 2^(i + 1), features * 2^i) for i in 3:-1:0]) - - bottleneck = unet_block(features * 8, features * 16) - - final_conv = Conv((1, 1), features => outplanes) - - return UNet(encoder, decoder, upconv, pool, bottleneck, final_conv) + layers = unet(in_channels, inplanes, outplanes) + return UNet(layers) end -function (u::UNet)(x::AbstractArray) - enc_out = [] - - out = x - for i in 1:4 - out = u.encoder[i](out) - push!(enc_out, out) - - out = u.pool[i](out) - end - - out = u.bottleneck(out) - - for i in 4:-1:1 - out = u.upconv[5 - i](out) - out = cat(out, enc_out[i]; dims = 3) - out = u.decoder[i](out) - end - - return σ(u.final_conv(out)) -end +(m::UNet)(x::AbstractArray) = m.layers(x) From c577aed6da0749408edd351efcd3854ae5610c04 Mon Sep 17 00:00:00 2001 From: Anshuman Mishra <51750587+shivance@users.noreply.github.com> Date: Sun, 1 Jan 2023 21:30:57 +0530 Subject: [PATCH 05/29] removing _random_normal --- src/utilities.jl | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/utilities.jl b/src/utilities.jl index 0c35758e5..9f54107fd 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -79,9 +79,3 @@ linear_scheduler(drop_prob::Nothing; depth::Integer) = fill(drop_prob, depth) function _checkconfig(config, configs) @assert config in configs "Invalid configuration. Must be one of $(sort(collect(configs)))." end - -# Utility function to be used for generating random normal distribution used for -# initializing layer parameters -function _random_normal(shape...) 
- return Float64.(rand(Normal(0.0f0, 0.02f0), shape...)) -end From fb642c49a17c837b73555240a63e9e87f0af280b Mon Sep 17 00:00:00 2001 From: Anshuman Mishra <51750587+shivance@users.noreply.github.com> Date: Mon, 2 Jan 2023 10:53:29 +0530 Subject: [PATCH 06/29] incorporating suggested changes --- src/Metalhead.jl | 1 - src/convnets/unet.jl | 89 +++++++++++++++++++++----------------------- 2 files changed, 42 insertions(+), 48 deletions(-) diff --git a/src/Metalhead.jl b/src/Metalhead.jl index 80da39bb8..e7c5421e4 100644 --- a/src/Metalhead.jl +++ b/src/Metalhead.jl @@ -2,7 +2,6 @@ module Metalhead using Flux using Flux: Zygote, outputsize -using Distributions: Normal using Functors using BSON using Artifacts, LazyArtifacts diff --git a/src/convnets/unet.jl b/src/convnets/unet.jl index cff8b9ea1..75af1fc0d 100644 --- a/src/convnets/unet.jl +++ b/src/convnets/unet.jl @@ -1,82 +1,77 @@ function unet_block(in_chs::Int, out_chs::Int, kernel = (3, 3)) - return Chain(; - conv1 = Conv(kernel, in_chs => out_chs; pad = (1, 1)), - norm1 = BatchNorm(out_chs, relu), - conv2 = Conv(kernel, out_chs => out_chs; pad = (1, 1)), - norm2 = BatchNorm(out_chs, relu)) + return Chain(; conv1 = Conv(kernel, in_chs => out_chs; pad = (1, 1)), + norm1 = BatchNorm(out_chs, relu), + conv2 = Conv(kernel, out_chs => out_chs; pad = (1, 1)), + norm2 = BatchNorm(out_chs, relu)) end function upconv_block(in_chs::Int, out_chs::Int, kernel = (2, 2)) - return ConvTranspose(kernel, in_chs => out_chs; stride = (2, 2)) -end - -function cat_fn(layers...) - return cat(layers...; dims = 3) + return ConvTranspose(kernel, in_chs => out_chs; stride = (2, 2)) end function unet(in_channels::Integer = 3, out_channels::Integer = in_channels, - features::Integer = 32) - encoder_conv_layers = [] - append!(encoder_conv_layers, - [unet_block(in_channels, features)]) + features::Integer = 32) + encoder_conv = [] + push!(encoder_conv, + unet_block(in_channels, features)) - append!(encoder_conv_layers, - [unet_block(features * 2^i, features * 2^(i + 1)) for i in 0:2]) + append!(encoder_conv, + [unet_block(features * 2^i, features * 2^(i + 1)) for i in 0:2]) - encoder_conv = Chain(encoder_conv_layers) - encoder_pool = [Chain(encoder_conv[i], MaxPool((2, 2); stride = (2, 2))) for i in 1:4] + encoder_conv = Chain(encoder_conv) + encoder_pool = [Chain(encoder_conv[i], MaxPool((2, 2); stride = (2, 2))) for i in 1:4] - bottleneck = unet_block(features * 8, features * 16) - layers = Chain(encoder_conv, bottleneck) + bottleneck = unet_block(features * 8, features * 16) + layers = Chain(encoder_conv, bottleneck) - upconv = Chain([upconv_block(features * 2^(i + 1), features * 2^i) - for i in 0:3]) + upconv = Chain([upconv_block(features * 2^(i + 1), features * 2^i) + for i in 0:3]...) - concat_layer = Chain([Parallel(cat_fn, - encoder_pool[i], - upconv[i]) - for i in 1:4]) + concat_layer = Chain([Parallel(cat_channels, + encoder_pool[i], + upconv[i]) + for i in 1:4]...) - decoder_layer = Chain([unet_block(features * 2^(i + 1), features * 2^i) for i in 3:-1:0]) + decoder_layer = Chain([unet_block(features * 2^(i + 1), features * 2^i) for i in 3:-1:0]...) - layers = Chain(layers, decoder_layer) + layers = Chain(layers, decoder_layer) - decoder = Chain([Chain([ - concat_layer[i], - decoder_layer[5 - i]]) - for i in 4:-1:1]) + decoder = Chain([Chain([ + concat_layer[i], + decoder_layer[5-i]]) + for i in 4:-1:1]...) 
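+    # each decoder stage chains a skip concat (Parallel) with its conv block, deepest first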
- final_conv = Conv((1, 1), features => out_channels, σ) + final_conv = Conv((1, 1), features => out_channels, σ) - decoder = Chain(decoder, final_conv) + decoder = Chain(decoder, final_conv) - return layers + return layers end """ - UNet(inplanes::Integer = 3, outplanes::Integer = 1, init_features::Integer = 32) + UNet(in_channels::Integer = 3, inplanes::Integer = 32, outplanes::Integer = inplanes) - Create a UNet model - ([reference](https://arxiv.org/abs/1505.04597v1)) + Create a UNet model + ([reference](https://arxiv.org/abs/1505.04597v1)) - # Arguments - - `in_channels`: The number of input channels - - `inplanes`: The number of input planes to the network - - `outplanes`: The number of output features + # Arguments + - `in_channels`: The number of input channels + - `inplanes`: The number of input features to the network + - `outplanes`: The number of output features !!! warning - `UNet` does not currently support pretrained weights. + `UNet` does not currently support pretrained weights. """ struct UNet - layers::Any + layers::Any end @functor UNet function UNet(in_channels::Integer = 3, inplanes::Integer = 32, - outplanes::Integer = inplanes) - layers = unet(in_channels, inplanes, outplanes) - return UNet(layers) + outplanes::Integer = inplanes) + layers = unet(in_channels, inplanes, outplanes) + return UNet(layers) end (m::UNet)(x::AbstractArray) = m.layers(x) From 7c7b1eeb6b2d0f4a05bf6a192a3ef6de17157235 Mon Sep 17 00:00:00 2001 From: Anshuman Mishra <51750587+shivance@users.noreply.github.com> Date: Tue, 3 Jan 2023 10:58:39 +0530 Subject: [PATCH 07/29] Revert "ran juliaformatter" This reverts commit ca73586da0da468d244ac768d4aec3b400a04061. --- docs/make.jl | 42 +-- src/Metalhead.jl | 8 +- src/convnets/alexnet.jl | 2 +- src/convnets/builders/resnet.jl | 2 +- src/convnets/convnext.jl | 2 +- src/convnets/densenet.jl | 2 +- src/convnets/efficientnets/efficientnet.jl | 2 +- src/convnets/efficientnets/efficientnetv2.jl | 2 +- src/convnets/inceptions/googlenet.jl | 2 +- src/convnets/inceptions/inceptionresnetv2.jl | 2 +- src/convnets/inceptions/inceptionv3.jl | 2 +- src/convnets/inceptions/inceptionv4.jl | 2 +- src/convnets/inceptions/xception.jl | 2 +- src/convnets/mobilenets/mnasnet.jl | 2 +- src/convnets/mobilenets/mobilenetv1.jl | 2 +- src/convnets/mobilenets/mobilenetv2.jl | 2 +- src/convnets/mobilenets/mobilenetv3.jl | 2 +- src/convnets/resnets/core.jl | 4 +- src/convnets/resnets/res2net.jl | 4 +- src/convnets/resnets/resnext.jl | 2 +- src/convnets/resnets/seresnet.jl | 4 +- src/convnets/unet.jl | 6 +- test/convnets.jl | 375 ++++++++++--------- test/runtests.jl | 19 +- test/vits.jl | 14 +- 25 files changed, 256 insertions(+), 252 deletions(-) diff --git a/docs/make.jl b/docs/make.jl index 7074db5fd..72fb486c9 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,34 +1,26 @@ -using Documenter, Metalhead, Artifacts, LazyArtifacts, Images, OneHotArrays, - DataAugmentation, Flux +using Documenter, Metalhead, Artifacts, LazyArtifacts, Images, OneHotArrays, DataAugmentation, Flux DocMeta.setdocmeta!(Metalhead, :DocTestSetup, :(using Metalhead); recursive = true) -makedocs(; - modules = [ - Metalhead, - Artifacts, - LazyArtifacts, - Images, - OneHotArrays, - DataAugmentation, - Flux, - ], +makedocs(modules = [Metalhead, Artifacts, LazyArtifacts, Images, OneHotArrays, DataAugmentation, Flux], sitename = "Metalhead.jl", doctest = false, pages = ["Home" => "index.md", - "Tutorials" => [ - "tutorials/quickstart.md", - ], - "Developer guide" => "contributing.md", - "API reference" => [ - 
"api/reference.md", - ], - ], - format = Documenter.HTML(; canonical = "https://fluxml.ai/Metalhead.jl/stable/", - # analytics = "UA-36890222-9", - assets = ["assets/flux.css"], - prettyurls = get(ENV, "CI", nothing) == "true")) + "Tutorials" => [ + "tutorials/quickstart.md", + ], + "Developer guide" => "contributing.md", + "API reference" => [ + "api/reference.md", + ], + ], + format = Documenter.HTML( + canonical = "https://fluxml.ai/Metalhead.jl/stable/", + # analytics = "UA-36890222-9", + assets = ["assets/flux.css"], + prettyurls = get(ENV, "CI", nothing) == "true"), + ) -deploydocs(; repo = "github.com/FluxML/Metalhead.jl.git", +deploydocs(repo = "github.com/FluxML/Metalhead.jl.git", target = "build", push_preview = true) diff --git a/src/Metalhead.jl b/src/Metalhead.jl index e7c5421e4..b67bf1faf 100644 --- a/src/Metalhead.jl +++ b/src/Metalhead.jl @@ -10,7 +10,7 @@ using MLUtils using PartialFunctions using Random -using Functors: Functors +import Functors include("utilities.jl") @@ -28,8 +28,6 @@ include("convnets/builders/stages.jl") ## AlexNet and VGG include("convnets/alexnet.jl") include("convnets/vgg.jl") -## Unet -include("convnets/unet.jl") ## ResNets include("convnets/resnets/core.jl") include("convnets/resnets/res2net.jl") @@ -68,7 +66,7 @@ include("vit-based/vit.jl") # Load pretrained weights include("pretrain.jl") -export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19, UNet, +export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19, ResNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152, WideResNet, ResNeXt, SEResNet, SEResNeXt, Res2Net, Res2NeXt, DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201, @@ -78,7 +76,7 @@ export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19, UNet, MLPMixer, ResMLP, gMLP, ViT # use Flux._big_show to pretty print large models -for T in (:AlexNet, :VGG, :UNet, :SqueezeNet, :ResNet, :WideResNet, :ResNeXt, +for T in (:AlexNet, :VGG, :SqueezeNet, :ResNet, :WideResNet, :ResNeXt, :SEResNet, :SEResNeXt, :Res2Net, :Res2NeXt, :GoogLeNet, :DenseNet, :Inceptionv3, :Inceptionv4, :InceptionResNetv2, :Xception, :MobileNetv1, :MobileNetv2, :MobileNetv3, :MNASNet, diff --git a/src/convnets/alexnet.jl b/src/convnets/alexnet.jl index d05919f24..fba6749e4 100644 --- a/src/convnets/alexnet.jl +++ b/src/convnets/alexnet.jl @@ -42,7 +42,7 @@ Create a `AlexNet`. - `nclasses`: the number of output classes !!! warning - + `AlexNet` does not currently support pretrained weights. See also [`alexnet`](@ref). diff --git a/src/convnets/builders/resnet.jl b/src/convnets/builders/resnet.jl index 231b9ef30..580baaa34 100644 --- a/src/convnets/builders/resnet.jl +++ b/src/convnets/builders/resnet.jl @@ -5,7 +5,7 @@ Creates a generic ResNet-like model. !!! info - + This is a very generic, flexible but low level function that can be used to create any of the ResNet variants. For a more user friendly function, see [`Metalhead.resnet`](@ref). diff --git a/src/convnets/convnext.jl b/src/convnets/convnext.jl index 2c76d6a30..a38840d0b 100644 --- a/src/convnets/convnext.jl +++ b/src/convnets/convnext.jl @@ -97,7 +97,7 @@ Creates a ConvNeXt model. - `nclasses`: number of output classes !!! warning - + `ConvNeXt` does not currently support pretrained weights. See also [`Metalhead.convnext`](@ref). diff --git a/src/convnets/densenet.jl b/src/convnets/densenet.jl index 5b340a8ca..4799e22a7 100644 --- a/src/convnets/densenet.jl +++ b/src/convnets/densenet.jl @@ -122,7 +122,7 @@ Create a DenseNet model with specified configuration. 
Currently supported values Set `pretrain = true` to load the model with pre-trained weights for ImageNet. !!! warning - + `DenseNet` does not currently support pretrained weights. See also [`Metalhead.densenet`](@ref). diff --git a/src/convnets/efficientnets/efficientnet.jl b/src/convnets/efficientnets/efficientnet.jl index 53f5c3569..2657a3884 100644 --- a/src/convnets/efficientnets/efficientnet.jl +++ b/src/convnets/efficientnets/efficientnet.jl @@ -73,7 +73,7 @@ Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)). - `nclasses`: number of output classes. !!! warning - + EfficientNet does not currently support pretrained weights. See also [`Metalhead.efficientnet`](@ref). diff --git a/src/convnets/efficientnets/efficientnetv2.jl b/src/convnets/efficientnets/efficientnetv2.jl index 6b89eb0ea..dab159e68 100644 --- a/src/convnets/efficientnets/efficientnetv2.jl +++ b/src/convnets/efficientnets/efficientnetv2.jl @@ -76,7 +76,7 @@ Create an EfficientNetv2 model ([reference](https://arxiv.org/abs/2104.00298)). - `nclasses`: number of output classes !!! warning - + `EfficientNetv2` does not currently support pretrained weights. See also [`efficientnet`](#). diff --git a/src/convnets/inceptions/googlenet.jl b/src/convnets/inceptions/googlenet.jl index afa8fa52b..a84af56c4 100644 --- a/src/convnets/inceptions/googlenet.jl +++ b/src/convnets/inceptions/googlenet.jl @@ -86,7 +86,7 @@ Create an Inception-v1 model (commonly referred to as `GoogLeNet`) - `bias`: set to `true` to use bias in the convolution layers !!! warning - + `GoogLeNet` does not currently support pretrained weights. See also [`Metalhead.googlenet`](@ref). diff --git a/src/convnets/inceptions/inceptionresnetv2.jl b/src/convnets/inceptions/inceptionresnetv2.jl index f95e9d062..98d686062 100644 --- a/src/convnets/inceptions/inceptionresnetv2.jl +++ b/src/convnets/inceptions/inceptionresnetv2.jl @@ -109,7 +109,7 @@ Creates an InceptionResNetv2 model. - `nclasses`: the number of output classes. !!! warning - + `InceptionResNetv2` does not currently support pretrained weights. See also [`Metalhead.inceptionresnetv2`](@ref). diff --git a/src/convnets/inceptions/inceptionv3.jl b/src/convnets/inceptions/inceptionv3.jl index 577a90dd2..41d7ae18e 100644 --- a/src/convnets/inceptions/inceptionv3.jl +++ b/src/convnets/inceptions/inceptionv3.jl @@ -170,7 +170,7 @@ Create an Inception-v3 model ([reference](https://arxiv.org/abs/1512.00567v3)). - `nclasses`: the number of output classes !!! warning - + `Inceptionv3` does not currently support pretrained weights. See also [`Metalhead.inceptionv3`](@ref). diff --git a/src/convnets/inceptions/inceptionv4.jl b/src/convnets/inceptions/inceptionv4.jl index 54b9b7bb1..964afc362 100644 --- a/src/convnets/inceptions/inceptionv4.jl +++ b/src/convnets/inceptions/inceptionv4.jl @@ -124,7 +124,7 @@ Creates an Inceptionv4 model. - `nclasses`: the number of output classes. !!! warning - + `Inceptionv4` does not currently support pretrained weights. See also [`Metalhead.inceptionv4`](@ref). diff --git a/src/convnets/inceptions/xception.jl b/src/convnets/inceptions/xception.jl index 52bd221a6..9dfd73f86 100644 --- a/src/convnets/inceptions/xception.jl +++ b/src/convnets/inceptions/xception.jl @@ -80,7 +80,7 @@ Creates an Xception model. - `nclasses`: the number of output classes. !!! warning - + `Xception` does not currently support pretrained weights. See also [`Metalhead.xception`](@ref). 
diff --git a/src/convnets/mobilenets/mnasnet.jl b/src/convnets/mobilenets/mnasnet.jl index d0be90a8c..98cd9d759 100644 --- a/src/convnets/mobilenets/mnasnet.jl +++ b/src/convnets/mobilenets/mnasnet.jl @@ -86,7 +86,7 @@ Creates a MNASNet model with the specified configuration. - `nclasses`: The number of output classes !!! warning - + `MNASNet` does not currently support pretrained weights. See also [`Metalhead.mnasnet`](@ref). diff --git a/src/convnets/mobilenets/mobilenetv1.jl b/src/convnets/mobilenets/mobilenetv1.jl index 881a1e533..ca17dc4ac 100644 --- a/src/convnets/mobilenets/mobilenetv1.jl +++ b/src/convnets/mobilenets/mobilenetv1.jl @@ -52,7 +52,7 @@ Create a MobileNetv1 model with the baseline configuration - `nclasses`: The number of output classes !!! warning - + `MobileNetv1` does not currently support pretrained weights. See also [`Metalhead.mobilenetv1`](@ref). diff --git a/src/convnets/mobilenets/mobilenetv2.jl b/src/convnets/mobilenets/mobilenetv2.jl index c028fc50a..6d0973130 100644 --- a/src/convnets/mobilenets/mobilenetv2.jl +++ b/src/convnets/mobilenets/mobilenetv2.jl @@ -58,7 +58,7 @@ Create a MobileNetv2 model with the specified configuration. - `nclasses`: The number of output classes !!! warning - + `MobileNetv2` does not currently support pretrained weights. See also [`Metalhead.mobilenetv2`](@ref). diff --git a/src/convnets/mobilenets/mobilenetv3.jl b/src/convnets/mobilenets/mobilenetv3.jl index 2fbe71532..07e4501eb 100644 --- a/src/convnets/mobilenets/mobilenetv3.jl +++ b/src/convnets/mobilenets/mobilenetv3.jl @@ -78,7 +78,7 @@ Set `pretrain = true` to load the model with pre-trained weights for ImageNet. - `nclasses`: the number of output classes !!! warning - + `MobileNetv3` does not currently support pretrained weights. See also [`Metalhead.mobilenetv3`](@ref). diff --git a/src/convnets/resnets/core.jl b/src/convnets/resnets/core.jl index 829d15f39..c5b2f1c7d 100644 --- a/src/convnets/resnets/core.jl +++ b/src/convnets/resnets/core.jl @@ -191,7 +191,7 @@ If `outplanes > inplanes`, it maps the input to `outplanes` channels using a 1x1 layer and zero padding. !!! warning - + This does not currently support the scenario where `inplanes > outplanes`. # Arguments @@ -237,7 +237,7 @@ on how to use this function. # Arguments - `stem_type`: The type of stem to be built. One of `[:default, :deep, :deep_tiered]`. - + + `:default`: Builds a stem based on the default ResNet stem, which consists of a single 7x7 convolution with stride 2 and a normalisation layer followed by a 3x3 max pooling layer with stride 2. diff --git a/src/convnets/resnets/res2net.jl b/src/convnets/resnets/res2net.jl index 43518addc..33b9fb961 100644 --- a/src/convnets/resnets/res2net.jl +++ b/src/convnets/resnets/res2net.jl @@ -17,7 +17,7 @@ Creates a Res2Net model with the specified depth, scale, and base width. - `nclasses`: the number of output classes !!! warning - + `Res2Net` does not currently support pretrained weights. Advanced users who want more configuration options will be better served by using [`resnet`](@ref). @@ -64,7 +64,7 @@ Creates a Res2NeXt model with the specified depth, scale, base width and cardina - `nclasses`: the number of output classes !!! warning - + `Res2NeXt` does not currently support pretrained weights. Advanced users who want more configuration options will be better served by using [`resnet`](@ref). 
diff --git a/src/convnets/resnets/resnext.jl b/src/convnets/resnets/resnext.jl index 5712a7079..6d65708d0 100644 --- a/src/convnets/resnets/resnext.jl +++ b/src/convnets/resnets/resnext.jl @@ -11,7 +11,7 @@ Creates a ResNeXt model with the specified depth, cardinality, and base width. - `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet. Supported configurations are: - + + depth 50, cardinality of 32 and base width of 4. + depth 101, cardinality of 32 and base width of 8. + depth 101, cardinality of 64 and base width of 4. diff --git a/src/convnets/resnets/seresnet.jl b/src/convnets/resnets/seresnet.jl index 0eb7e7128..44e32083d 100644 --- a/src/convnets/resnets/seresnet.jl +++ b/src/convnets/resnets/seresnet.jl @@ -12,7 +12,7 @@ Creates a SEResNet model with the specified depth. - `nclasses`: the number of output classes !!! warning - + `SEResNet` does not currently support pretrained weights. Advanced users who want more configuration options will be better served by using [`resnet`](@ref). @@ -55,7 +55,7 @@ Creates a SEResNeXt model with the specified depth, cardinality, and base width. - `nclasses`: the number of output classes !!! warning - + `SEResNeXt` does not currently support pretrained weights. Advanced users who want more configuration options will be better served by using [`resnet`](@ref). diff --git a/src/convnets/unet.jl b/src/convnets/unet.jl index 75af1fc0d..99e9f4809 100644 --- a/src/convnets/unet.jl +++ b/src/convnets/unet.jl @@ -1,5 +1,5 @@ function unet_block(in_chs::Int, out_chs::Int, kernel = (3, 3)) - return Chain(; conv1 = Conv(kernel, in_chs => out_chs; pad = (1, 1)), + return Chain(conv1 = Conv(kernel, in_chs => out_chs; pad = (1, 1)), norm1 = BatchNorm(out_chs, relu), conv2 = Conv(kernel, out_chs => out_chs; pad = (1, 1)), norm2 = BatchNorm(out_chs, relu)) @@ -50,17 +50,13 @@ end """ UNet(in_channels::Integer = 3, inplanes::Integer = 32, outplanes::Integer = inplanes) - Create a UNet model ([reference](https://arxiv.org/abs/1505.04597v1)) - # Arguments - `in_channels`: The number of input channels - `inplanes`: The number of input features to the network - `outplanes`: The number of output features - !!! warning - `UNet` does not currently support pretrained weights. 
""" struct UNet diff --git a/test/convnets.jl b/test/convnets.jl index 13806288b..af15ff413 100644 --- a/test/convnets.jl +++ b/test/convnets.jl @@ -6,19 +6,19 @@ _gc() end -@testset "VGG" begin @testset "VGG($sz, batchnorm=$bn)" for sz in [11, 13, 16, 19], - bn in [true, false] - - m = VGG(sz; batchnorm = bn) - @test size(m(x_224)) == (1000, 1) - if (VGG, sz, bn) in PRETRAINED_MODELS - @test acctest(VGG(sz; batchnorm = bn, pretrain = true)) - else - @test_throws ArgumentError VGG(sz, batchnorm = bn, pretrain = true) +@testset "VGG" begin + @testset "VGG($sz, batchnorm=$bn)" for sz in [11, 13, 16, 19], bn in [true, false] + m = VGG(sz, batchnorm = bn) + @test size(m(x_224)) == (1000, 1) + if (VGG, sz, bn) in PRETRAINED_MODELS + @test acctest(VGG(sz, batchnorm = bn, pretrain = true)) + else + @test_throws ArgumentError VGG(sz, batchnorm = bn, pretrain = true) + end + @test gradtest(m, x_224) + _gc() end - @test gradtest(m, x_224) - _gc() -end end +end @testset "ResNet" begin # Tests for pretrained ResNets @@ -26,94 +26,100 @@ end end m = ResNet(sz) @test size(m(x_224)) == (1000, 1) if (ResNet, sz) in PRETRAINED_MODELS - @test acctest(ResNet(sz; pretrain = true)) + @test acctest(ResNet(sz, pretrain = true)) else @test_throws ArgumentError ResNet(sz, pretrain = true) end end - @testset "resnet" begin @testset for block_fn in [ - Metalhead.basicblock, - Metalhead.bottleneck, - ] - layer_list = [ - [2, 2, 2, 2], - [3, 4, 6, 3], - [3, 4, 23, 3], - [3, 8, 36, 3], - ] - @testset for layers in layer_list - drop_list = [ - (dropout_prob = 0.1, stochastic_depth_prob = 0.1, dropblock_prob = 0.1), - (dropout_prob = 0.5, stochastic_depth_prob = 0.5, dropblock_prob = 0.5), - (dropout_prob = 0.8, stochastic_depth_prob = 0.8, dropblock_prob = 0.8), + @testset "resnet" begin + @testset for block_fn in [Metalhead.basicblock, Metalhead.bottleneck] + layer_list = [ + [2, 2, 2, 2], + [3, 4, 6, 3], + [3, 4, 23, 3], + [3, 8, 36, 3] ] - @testset for drop_probs in drop_list - m = Metalhead.resnet(block_fn, layers; drop_probs...) - @test size(m(x_224)) == (1000, 1) - @test gradtest(m, x_224) - _gc() + @testset for layers in layer_list + drop_list = [ + (dropout_prob = 0.1, stochastic_depth_prob = 0.1, dropblock_prob = 0.1), + (dropout_prob = 0.5, stochastic_depth_prob = 0.5, dropblock_prob = 0.5), + (dropout_prob = 0.8, stochastic_depth_prob = 0.8, dropblock_prob = 0.8), + ] + @testset for drop_probs in drop_list + m = Metalhead.resnet(block_fn, layers; drop_probs...) 
+ @test size(m(x_224)) == (1000, 1) + @test gradtest(m, x_224) + _gc() + end end end - end end + end - @testset "WideResNet" begin @testset "WideResNet($sz)" for sz in [50, 101] - m = WideResNet(sz) - @test size(m(x_224)) == (1000, 1) - @test gradtest(m, x_224) - _gc() - if (WideResNet, sz) in PRETRAINED_MODELS - @test acctest(ResNet(sz; pretrain = true)) - else - @test_throws ArgumentError WideResNet(sz, pretrain = true) + @testset "WideResNet" begin + @testset "WideResNet($sz)" for sz in [50, 101] + m = WideResNet(sz) + @test size(m(x_224)) == (1000, 1) + @test gradtest(m, x_224) + _gc() + if (WideResNet, sz) in PRETRAINED_MODELS + @test acctest(ResNet(sz, pretrain = true)) + else + @test_throws ArgumentError WideResNet(sz, pretrain = true) + end end - end end + end end -@testset "ResNeXt" begin @testset for depth in [50, 101, 152] - @testset for cardinality in [32, 64] - @testset for base_width in [4, 8] - m = ResNeXt(depth; cardinality, base_width) - @test size(m(x_224)) == (1000, 1) - if (ResNeXt, depth, cardinality, base_width) in PRETRAINED_MODELS - @test acctest(ResNeXt(depth; cardinality, base_width, pretrain = true)) - else - @test_throws ArgumentError ResNeXt(depth; cardinality, base_width, - pretrain = true) +@testset "ResNeXt" begin + @testset for depth in [50, 101, 152] + @testset for cardinality in [32, 64] + @testset for base_width in [4, 8] + m = ResNeXt(depth; cardinality, base_width) + @test size(m(x_224)) == (1000, 1) + if (ResNeXt, depth, cardinality, base_width) in PRETRAINED_MODELS + @test acctest(ResNeXt(depth; cardinality, base_width, pretrain = true)) + else + @test_throws ArgumentError ResNeXt(depth; cardinality, base_width, pretrain = true) + end + @test gradtest(m, x_224) + _gc() end - @test gradtest(m, x_224) - _gc() end end -end end +end -@testset "SEResNet" begin @testset for depth in [18, 34, 50, 101, 152] - m = SEResNet(depth) - @test size(m(x_224)) == (1000, 1) - if (SEResNet, depth) in PRETRAINED_MODELS - @test acctest(SEResNet(depth; pretrain = true)) - else - @test_throws ArgumentError SEResNet(depth, pretrain = true) +@testset "SEResNet" begin + @testset for depth in [18, 34, 50, 101, 152] + m = SEResNet(depth) + @test size(m(x_224)) == (1000, 1) + if (SEResNet, depth) in PRETRAINED_MODELS + @test acctest(SEResNet(depth, pretrain = true)) + else + @test_throws ArgumentError SEResNet(depth, pretrain = true) + end + @test gradtest(m, x_224) + _gc() end - @test gradtest(m, x_224) - _gc() -end end +end -@testset "SEResNeXt" begin @testset for depth in [50, 101, 152] - @testset for cardinality in [32, 64] - @testset for base_width in [4, 8] - m = SEResNeXt(depth; cardinality, base_width) - @test size(m(x_224)) == (1000, 1) - if (SEResNeXt, depth, cardinality, base_width) in PRETRAINED_MODELS - @test acctest(SEResNeXt(depth; pretrain = true)) - else - @test_throws ArgumentError SEResNeXt(depth, pretrain = true) +@testset "SEResNeXt" begin + @testset for depth in [50, 101, 152] + @testset for cardinality in [32, 64] + @testset for base_width in [4, 8] + m = SEResNeXt(depth; cardinality, base_width) + @test size(m(x_224)) == (1000, 1) + if (SEResNeXt, depth, cardinality, base_width) in PRETRAINED_MODELS + @test acctest(SEResNeXt(depth, pretrain = true)) + else + @test_throws ArgumentError SEResNeXt(depth, pretrain = true) + end + @test gradtest(m, x_224) + _gc() end - @test gradtest(m, x_224) - _gc() end end -end end +end @testset "Res2Net" begin @testset for (base_width, scale) in [(26, 4), (48, 2), (14, 8), (26, 6), (26, 8)] @@ -140,63 +146,64 @@ end end 
end end -@testset "Res2NeXt" begin @testset for depth in [50, 101] - m = Res2NeXt(depth) - @test size(m(x_224)) == (1000, 1) - if (Res2NeXt, depth) in PRETRAINED_MODELS - @test acctest(Res2NeXt(depth; pretrain = true)) - else - @test_throws ArgumentError Res2NeXt(depth, pretrain = true) +@testset "Res2NeXt" begin + @testset for depth in [50, 101] + m = Res2NeXt(depth) + @test size(m(x_224)) == (1000, 1) + if (Res2NeXt, depth) in PRETRAINED_MODELS + @test acctest(Res2NeXt(depth, pretrain = true)) + else + @test_throws ArgumentError Res2NeXt(depth, pretrain = true) + end + @test gradtest(m, x_224) + _gc() end - @test gradtest(m, x_224) - _gc() -end end +end -@testset "EfficientNet" begin @testset "EfficientNet($config)" for config in [ - :b0, - :b1, - :b2, - :b3, - :b4, - :b5, -] #:b6, :b7, :b8] - # preferred image resolution scaling - r = Metalhead.EFFICIENTNET_GLOBAL_CONFIGS[config][1] - x = rand(Float32, r, r, 3, 1) - m = EfficientNet(config) - @test size(m(x)) == (1000, 1) - if (EfficientNet, config) in PRETRAINED_MODELS - @test acctest(EfficientNet(config; pretrain = true)) - else - @test_throws ArgumentError EfficientNet(config, pretrain = true) +@testset "EfficientNet" begin + @testset "EfficientNet($config)" for config in [:b0, :b1, :b2, :b3, :b4, :b5,] #:b6, :b7, :b8] + # preferred image resolution scaling + r = Metalhead.EFFICIENTNET_GLOBAL_CONFIGS[config][1] + x = rand(Float32, r, r, 3, 1) + m = EfficientNet(config) + @test size(m(x)) == (1000, 1) + if (EfficientNet, config) in PRETRAINED_MODELS + @test acctest(EfficientNet(config, pretrain = true)) + else + @test_throws ArgumentError EfficientNet(config, pretrain = true) + end + @test gradtest(m, x) + _gc() end - @test gradtest(m, x) - _gc() -end end +end -@testset "EfficientNetv2" begin @testset for config in [:small, :medium, :large] # :xlarge] - m = EfficientNetv2(config) - @test size(m(x_224)) == (1000, 1) - if (EfficientNetv2, config) in PRETRAINED_MODELS - @test acctest(EfficientNetv2(config; pretrain = true)) - else - @test_throws ArgumentError EfficientNetv2(config, pretrain = true) +@testset "EfficientNetv2" begin + @testset for config in [:small, :medium, :large] # :xlarge] + m = EfficientNetv2(config) + @test size(m(x_224)) == (1000, 1) + if (EfficientNetv2, config) in PRETRAINED_MODELS + @test acctest(EfficientNetv2(config, pretrain = true)) + else + @test_throws ArgumentError EfficientNetv2(config, pretrain = true) + end + @test gradtest(m, x_224) + _gc() end - @test gradtest(m, x_224) - _gc() -end end +end -@testset "GoogLeNet" begin @testset for bn in [true, false] - m = GoogLeNet(; batchnorm = bn) - @test size(m(x_224)) == (1000, 1) - if (GoogLeNet, bn) in PRETRAINED_MODELS - @test acctest(GoogLeNet(; batchnorm = bn, pretrain = true)) - else - @test_throws ArgumentError GoogLeNet(batchnorm = bn, pretrain = true) +@testset "GoogLeNet" begin + @testset for bn in [true, false] + m = GoogLeNet(batchnorm = bn) + @test size(m(x_224)) == (1000, 1) + if (GoogLeNet, bn) in PRETRAINED_MODELS + @test acctest(GoogLeNet(batchnorm = bn, pretrain = true)) + else + @test_throws ArgumentError GoogLeNet(batchnorm = bn, pretrain = true) + end + @test gradtest(m, x_224) + _gc() end - @test gradtest(m, x_224) - _gc() -end end +end @testset "Inception" begin x_299 = rand(Float32, 299, 299, 3, 2) @@ -204,7 +211,7 @@ end end m = Inceptionv3() @test size(m(x_299)) == (1000, 2) if Inceptionv3 in PRETRAINED_MODELS - @test acctest(Inceptionv3(; pretrain = true)) + @test acctest(Inceptionv3(pretrain = true)) else @test_throws ArgumentError 
Inceptionv3(pretrain = true) end @@ -215,7 +222,7 @@ end end m = Inceptionv4() @test size(m(x_299)) == (1000, 2) if Inceptionv4 in PRETRAINED_MODELS - @test acctest(Inceptionv4(; pretrain = true)) + @test acctest(Inceptionv4(pretrain = true)) else @test_throws ArgumentError Inceptionv4(pretrain = true) end @@ -226,7 +233,7 @@ end end m = InceptionResNetv2() @test size(m(x_299)) == (1000, 2) if InceptionResNetv2 in PRETRAINED_MODELS - @test acctest(InceptionResNetv2(; pretrain = true)) + @test acctest(InceptionResNetv2(pretrain = true)) else @test_throws ArgumentError InceptionResNetv2(pretrain = true) end @@ -237,7 +244,7 @@ end end m = Xception() @test size(m(x_299)) == (1000, 2) if Xception in PRETRAINED_MODELS - @test acctest(Xception(; pretrain = true)) + @test acctest(Xception(pretrain = true)) else @test_throws ArgumentError Xception(pretrain = true) end @@ -250,7 +257,7 @@ end m = SqueezeNet() @test size(m(x_224)) == (1000, 1) if SqueezeNet in PRETRAINED_MODELS - @test acctest(SqueezeNet(; pretrain = true)) + @test acctest(SqueezeNet(pretrain = true)) else @test_throws ArgumentError SqueezeNet(pretrain = true) end @@ -258,24 +265,26 @@ end _gc() end -@testset "DenseNet" begin @testset for sz in [121, 161, 169, 201] - m = DenseNet(sz) - @test size(m(x_224)) == (1000, 1) - if (DenseNet, sz) in PRETRAINED_MODELS - @test acctest(DenseNet(sz; pretrain = true)) - else - @test_throws ArgumentError DenseNet(sz, pretrain = true) +@testset "DenseNet" begin + @testset for sz in [121, 161, 169, 201] + m = DenseNet(sz) + @test size(m(x_224)) == (1000, 1) + if (DenseNet, sz) in PRETRAINED_MODELS + @test acctest(DenseNet(sz, pretrain = true)) + else + @test_throws ArgumentError DenseNet(sz, pretrain = true) + end + @test gradtest(m, x_224) + _gc() end - @test gradtest(m, x_224) - _gc() -end end +end @testset "MobileNets (width = $width_mult)" for width_mult in [0.5, 0.75, 1, 1.3] @testset "MobileNetv1" begin m = MobileNetv1(width_mult) @test size(m(x_224)) == (1000, 1) if (MobileNetv1, width_mult) in PRETRAINED_MODELS - @test acctest(MobileNetv1(; pretrain = true)) + @test acctest(MobileNetv1(pretrain = true)) else @test_throws ArgumentError MobileNetv1(pretrain = true) end @@ -286,53 +295,55 @@ end end m = MobileNetv2(width_mult) @test size(m(x_224)) == (1000, 1) if (MobileNetv2, width_mult) in PRETRAINED_MODELS - @test acctest(MobileNetv2(; pretrain = true)) + @test acctest(MobileNetv2(pretrain = true)) else @test_throws ArgumentError MobileNetv2(pretrain = true) end @test gradtest(m, x_224) end _gc() - @testset "MobileNetv3" verbose=true begin @testset for config in [:small, :large] - m = MobileNetv3(config; width_mult) - @test size(m(x_224)) == (1000, 1) - if (MobileNetv3, config, width_mult) in PRETRAINED_MODELS - @test acctest(MobileNetv3(config; pretrain = true)) - else - @test_throws ArgumentError MobileNetv3(config; pretrain = true) + @testset "MobileNetv3" verbose = true begin + @testset for config in [:small, :large] + m = MobileNetv3(config; width_mult) + @test size(m(x_224)) == (1000, 1) + if (MobileNetv3, config, width_mult) in PRETRAINED_MODELS + @test acctest(MobileNetv3(config; pretrain = true)) + else + @test_throws ArgumentError MobileNetv3(config; pretrain = true) + end + @test gradtest(m, x_224) + _gc() + end + end + @testset "MNASNet" verbose = true begin + @testset for config in [:A1, :B1] + m = MNASNet(config; width_mult) + @test size(m(x_224)) == (1000, 1) + if (MNASNet, config, width_mult) in PRETRAINED_MODELS + @test acctest(MNASNet(config; pretrain = true)) + else + 
@test_throws ArgumentError MNASNet(config; pretrain = true) + end + @test gradtest(m, x_224) + _gc() end + end +end + +@testset "ConvNeXt" verbose = true begin + @testset for config in [:small, :base, :large, :tiny, :xlarge] + m = ConvNeXt(config) + @test size(m(x_224)) == (1000, 1) @test gradtest(m, x_224) _gc() - end end - @testset "MNASNet" verbose=true begin @testset for config in [:A1, :B1] - m = MNASNet(config; width_mult) + end +end + +@testset "ConvMixer" verbose = true begin + @testset for config in [:small, :base, :large] + m = ConvMixer(config) @test size(m(x_224)) == (1000, 1) - if (MNASNet, config, width_mult) in PRETRAINED_MODELS - @test acctest(MNASNet(config; pretrain = true)) - else - @test_throws ArgumentError MNASNet(config; pretrain = true) - end @test gradtest(m, x_224) _gc() - end end + end end - -@testset "ConvNeXt" verbose=true begin @testset for config in [ - :small, - :base, - :large, - :tiny, - :xlarge, -] - m = ConvNeXt(config) - @test size(m(x_224)) == (1000, 1) - @test gradtest(m, x_224) - _gc() -end end - -@testset "ConvMixer" verbose=true begin @testset for config in [:small, :base, :large] - m = ConvMixer(config) - @test size(m(x_224)) == (1000, 1) - @test gradtest(m, x_224) - _gc() -end end diff --git a/test/runtests.jl b/test/runtests.jl index d5af101e7..cd5c8ab99 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -18,12 +18,12 @@ const PRETRAINED_MODELS = [ (WideResNet, 101), (ResNeXt, 50, 32, 4), (ResNeXt, 101, 64, 4), - (ResNeXt, 101, 32, 8), + (ResNeXt, 101, 32, 8) ] function _gc() GC.safepoint() - return GC.gc(true) + GC.gc(true) end function gradtest(model, input) @@ -43,8 +43,7 @@ end const TEST_PATH = download("https://cdn.pixabay.com/photo/2015/05/07/11/02/guitar-756326_960_720.jpg") const TEST_IMG = imresize(Images.load(TEST_PATH), (224, 224)) # CHW -> WHC -const TEST_X = permutedims(convert(Array{Float32}, channelview(TEST_IMG)), (3, 2, 1)) |> - normalize_imagenet +const TEST_X = permutedims(convert(Array{Float32}, channelview(TEST_IMG)), (3, 2, 1)) |> normalize_imagenet # ImageNet labels const TEST_LBLS = readlines(download("https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt")) @@ -59,10 +58,16 @@ x_224 = rand(Float32, 224, 224, 3, 1) x_256 = rand(Float32, 256, 256, 3, 1) # CNN tests -@testset verbose=true "ConvNets" begin include("convnets.jl") end +@testset verbose = true "ConvNets" begin + include("convnets.jl") +end # Mixer tests -@testset verbose=true "Mixers" begin include("mixers.jl") end +@testset verbose = true "Mixers" begin + include("mixers.jl") +end # ViT tests -@testset verbose=true "ViTs" begin include("vits.jl") end +@testset verbose = true "ViTs" begin + include("vits.jl") +end diff --git a/test/vits.jl b/test/vits.jl index 3c737210c..7561cfdb5 100644 --- a/test/vits.jl +++ b/test/vits.jl @@ -1,6 +1,8 @@ -@testset "ViT" begin for config in [:tiny, :small, :base, :large, :huge] # :giant, :gigantic] - m = ViT(config) - @test size(m(x_256)) == (1000, 1) - @test gradtest(m, x_256) - _gc() -end end +@testset "ViT" begin + for config in [:tiny, :small, :base, :large, :huge] # :giant, :gigantic] + m = ViT(config) + @test size(m(x_256)) == (1000, 1) + @test gradtest(m, x_256) + _gc() + end +end From 99f07ad919e70efa3897158cfc35c1d27e792fd2 Mon Sep 17 00:00:00 2001 From: Anshuman Mishra <51750587+shivance@users.noreply.github.com> Date: Tue, 10 Jan 2023 14:34:58 +0530 Subject: [PATCH 08/29] adapting to fastai's unet impl --- src/Metalhead.jl | 5 +- src/convnets/unet.jl | 130 
+++++++++++++++++++++++++------------------ src/utilities.jl | 43 +++++++------- 3 files changed, 103 insertions(+), 75 deletions(-) diff --git a/src/Metalhead.jl b/src/Metalhead.jl index b67bf1faf..8da7230d2 100644 --- a/src/Metalhead.jl +++ b/src/Metalhead.jl @@ -53,6 +53,7 @@ include("convnets/densenet.jl") include("convnets/squeezenet.jl") include("convnets/convnext.jl") include("convnets/convmixer.jl") +include("convnets/unet.jl") # Mixers include("mixers/core.jl") @@ -73,7 +74,7 @@ export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19, GoogLeNet, Inception3, Inceptionv3, Inceptionv4, InceptionResNetv2, Xception, SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3, MNASNet, EfficientNet, EfficientNetv2, ConvMixer, ConvNeXt, - MLPMixer, ResMLP, gMLP, ViT + MLPMixer, ResMLP, gMLP, ViT, UNet # use Flux._big_show to pretty print large models for T in (:AlexNet, :VGG, :SqueezeNet, :ResNet, :WideResNet, :ResNeXt, @@ -81,7 +82,7 @@ for T in (:AlexNet, :VGG, :SqueezeNet, :ResNet, :WideResNet, :ResNeXt, :Inceptionv3, :Inceptionv4, :InceptionResNetv2, :Xception, :MobileNetv1, :MobileNetv2, :MobileNetv3, :MNASNet, :EfficientNet, :EfficientNetv2, :ConvMixer, :ConvNeXt, - :MLPMixer, :ResMLP, :gMLP, :ViT) + :MLPMixer, :ResMLP, :gMLP, :ViT, :UNet) @eval Base.show(io::IO, ::MIME"text/plain", model::$T) = _maybe_big_show(io, model) end diff --git a/src/convnets/unet.jl b/src/convnets/unet.jl index 99e9f4809..cb698f4ea 100644 --- a/src/convnets/unet.jl +++ b/src/convnets/unet.jl @@ -1,73 +1,97 @@ -function unet_block(in_chs::Int, out_chs::Int, kernel = (3, 3)) - return Chain(conv1 = Conv(kernel, in_chs => out_chs; pad = (1, 1)), - norm1 = BatchNorm(out_chs, relu), - conv2 = Conv(kernel, out_chs => out_chs; pad = (1, 1)), - norm2 = BatchNorm(out_chs, relu)) +function PixelShuffleICNR(inplanes, outplanes; r = 2) + return Chain(basic_conv_bn((1, 1), inplanes, outplanes * (r^2)), + Flux.PixelShuffle(r)) end -function upconv_block(in_chs::Int, out_chs::Int, kernel = (2, 2)) - return ConvTranspose(kernel, in_chs => out_chs; stride = (2, 2)) +function UNetCombineLayer(inplanes, outplanes) + return Chain(basic_conv_bn((3, 3), inplanes, outplanes; pad = 1), + basic_conv_bn((3, 3), outplanes, outplanes; pad = 1)) end -function unet(in_channels::Integer = 3, out_channels::Integer = in_channels, - features::Integer = 32) - encoder_conv = [] - push!(encoder_conv, - unet_block(in_channels, features)) - - append!(encoder_conv, - [unet_block(features * 2^i, features * 2^(i + 1)) for i in 0:2]) - - encoder_conv = Chain(encoder_conv) - encoder_pool = [Chain(encoder_conv[i], MaxPool((2, 2); stride = (2, 2))) for i in 1:4] - - bottleneck = unet_block(features * 8, features * 16) - layers = Chain(encoder_conv, bottleneck) - - upconv = Chain([upconv_block(features * 2^(i + 1), features * 2^i) - for i in 0:3]...) +function UNetMiddleBlock(inplanes) + return Chain(basic_conv_bn((3, 3), inplanes, 2inplanes; pad = 1), + basic_conv_bn((3, 3), 2inplanes, inplanes; pad = 1)) +end - concat_layer = Chain([Parallel(cat_channels, - encoder_pool[i], - upconv[i]) - for i in 1:4]...) +function UNetFinalBlock(inplanes, outplanes) + return Chain(basicblock(inplanes, inplanes; reduction_factor = 1), + basic_conv_bn((1, 1), inplanes, outplanes)) +end - decoder_layer = Chain([unet_block(features * 2^(i + 1), features * 2^i) for i in 3:-1:0]...) 
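+# unetlayers recursively rebuilds the backbone as a U: layers that preserve the
+# spatial size are chained through unchanged, while each layer that halves the
+# resolution starts a child U-Net whose output is upsampled and concatenated
+# with the skip connection taken at that depth.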
+function unetlayers(layers, sz; outplanes = nothing, skip_upscale = 0, + m_middle = _ -> (identity,)) + isempty(layers) && return m_middle(sz[end - 1]) + + layer, layers = layers[1], layers[2:end] + outsz = Flux.outputsize(layer, sz) + does_downscale = sz[1] ÷ 2 == outsz[1] + + if !does_downscale + return Chain(layer, unetlayers(layers, outsz; outplanes, skip_upscale)...) + elseif does_downscale && skip_upscale > 0 + return Chain(layer, + unetlayers(layers, outsz; skip_upscale = skip_upscale - 1, + outplanes)...) + else + childunet = Chain(unetlayers(layers, outsz; skip_upscale)...) + outsz = Flux.outputsize(childunet, outsz) + + inplanes = sz[end - 1] + midplanes = outsz[end - 1] + outplanes = isnothing(outplanes) ? inplanes : outplanes + + return UNetBlock(Chain(layer, childunet), + inplanes, midplanes, outplanes) + end +end - layers = Chain(layers, decoder_layer) +function UNetBlock(m_child, inplanes, midplanes, outplanes = 2inplanes) + return Chain(; + upsample = SkipConnection(Chain(; child = m_child, + upsample = PixelShuffleICNR(midplanes, + midplanes)), + Parallel(cat_channels, identity, + BatchNorm(inplanes))), + act = xs -> relu.(xs), + combine = UNetCombineLayer(inplanes + midplanes, outplanes)) +end - decoder = Chain([Chain([ - concat_layer[i], - decoder_layer[5-i]]) - for i in 4:-1:1]...) +""" + UNet(backbone, inputsize) + DenseNet(transition_configs::NTuple{N,Integer}) - final_conv = Conv((1, 1), features => out_channels, σ) +Creates a UNet model with specified backbone. Backbone of Any Metalhead model can be used as +encoder . +Set `pretrain = true` to load the model with pre-trained weights for ImageNet. - decoder = Chain(decoder, final_conv) +!!! warning - return layers -end + `UNet` does not currently support pretrained weights. -""" - UNet(in_channels::Integer = 3, inplanes::Integer = 32, outplanes::Integer = inplanes) - Create a UNet model - ([reference](https://arxiv.org/abs/1505.04597v1)) - # Arguments - - `in_channels`: The number of input channels - - `inplanes`: The number of input features to the network - - `outplanes`: The number of output features -!!! warning - `UNet` does not currently support pretrained weights. +See also [`Metalhead.UNet`](@ref). """ struct UNet - layers::Any + layers::Any end @functor UNet -function UNet(in_channels::Integer = 3, inplanes::Integer = 32, - outplanes::Integer = inplanes) - layers = unet(in_channels, inplanes, outplanes) - return UNet(layers) +function UNet(backbone, + inputsize, + outplanes, + final = UNetFinalBlock, + fdownscale::Integer = 0, + kwargs...) + backbonelayers = collect(iterlayers(backbone)) + layers = unetlayers(backbonelayers, + inputsize; + m_middle = UNetMiddleBlock, + skip_upscale = fdownscale, + kwargs...) + + outsz = Flux.outputsize(layers, inputsize) + layers = Chain(layers, final(outsz[end - 1], outplanes)) + + return UNet(layers) end (m::UNet)(x::AbstractArray) = m.layers(x) diff --git a/src/utilities.jl b/src/utilities.jl index 9f54107fd..bdbe8a42e 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -4,13 +4,13 @@ seconddimmean(x) = dropdims(mean(x; dims = 2); dims = 2) # utility function for making sure that all layers have a channel size divisible by 8 # used by MobileNet variants function _round_channels(channels::Number, divisor::Integer = 8, min_value::Integer = 0) - new_channels = max(min_value, floor(Int, channels + divisor / 2) ÷ divisor * divisor) - # Make sure that round down does not go down by more than 10% - return new_channels < 0.9 * channels ? 
new_channels + divisor : new_channels + new_channels = max(min_value, floor(Int, channels + divisor / 2) ÷ divisor * divisor) + # Make sure that round down does not go down by more than 10% + return new_channels < 0.9 * channels ? new_channels + divisor : new_channels end """ - addact(activation = relu, xs...) + addact(activation = relu, xs...) Convenience function for applying an activation function to the output after summing up the input arrays. Useful as the `connection` argument for the block @@ -19,7 +19,7 @@ function in [`resnet`](@ref). addact(activation = relu, xs...) = activation(sum(xs)) """ - actadd(activation = relu, xs...) + actadd(activation = relu, xs...) Convenience function for adding input arrays after applying an activation function to them. Useful as the `connection` argument for the block function in @@ -28,7 +28,7 @@ function to them. Useful as the `connection` argument for the block function in actadd(activation = relu, xs...) = sum(activation.(x) for x in xs) """ - cat_channels(x, y, zs...) + cat_channels(x, y, zs...) Concatenate `x` and `y` (and any `z`s) along the channel dimension (third dimension). Equivalent to `cat(x, y, zs...; dims=3)`. @@ -40,7 +40,7 @@ cat_channels(x::Tuple, y::AbstractArray...) = cat_channels(x..., y...) cat_channels(x::Tuple) = cat_channels(x...) """ - swapdims(perm) + swapdims(perm) Convenience function for permuting the dimensions of an array. `perm` is a vector or tuple specifying a permutation of the input dimensions. @@ -50,20 +50,20 @@ swapdims(perm) = Base.Fix2(permutedims, perm) # Utility function for pretty printing large models function _maybe_big_show(io, model) - if isdefined(Flux, :_big_show) - if isnothing(get(io, :typeinfo, nothing)) # e.g. top level in REPL - Flux._big_show(io, model) - else - show(io, model) - end - else - show(io, model) - end + if isdefined(Flux, :_big_show) + if isnothing(get(io, :typeinfo, nothing)) # e.g. top level in REPL + Flux._big_show(io, model) + else + show(io, model) + end + else + show(io, model) + end end """ - linear_scheduler(drop_prob = 0.0; start_value = 0.0, depth) - linear_scheduler(drop_prob::Nothing; depth::Integer) + linear_scheduler(drop_prob = 0.0; start_value = 0.0, depth) + linear_scheduler(drop_prob::Nothing; depth::Integer) Returns the dropout probabilities for a given depth using the linear scaling rule. Note that this returns evenly spaced values between `start_value` and `drop_prob`, not including @@ -71,11 +71,14 @@ that this returns evenly spaced values between `start_value` and `drop_prob`, no values equal to `nothing`. """ function linear_scheduler(drop_prob = 0.0; depth::Integer, start_value = 0.0) - return LinRange(start_value, drop_prob, depth + 1)[1:depth] + return LinRange(start_value, drop_prob, depth + 1)[1:depth] end linear_scheduler(drop_prob::Nothing; depth::Integer) = fill(drop_prob, depth) # Utility function for depth and configuration checks in models function _checkconfig(config, configs) - @assert config in configs "Invalid configuration. Must be one of $(sort(collect(configs)))." + @assert config in configs "Invalid configuration. Must be one of $(sort(collect(configs)))." 
end + +# Utility function to return Iterator over layers, adopted from FastAI.jl +iterlayers(m::Chain) = Iterators.flatten(iterlayers(l) for l in m.layers) From fc756d9444eab84ca4b8c30b8c14c6a205fc5fb6 Mon Sep 17 00:00:00 2001 From: Anshuman Mishra <51750587+shivance@users.noreply.github.com> Date: Tue, 10 Jan 2023 14:52:46 +0530 Subject: [PATCH 09/29] undoing utilities formatting --- src/utilities.jl | 40 ++++++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/src/utilities.jl b/src/utilities.jl index bdbe8a42e..f52d1f97e 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -4,13 +4,13 @@ seconddimmean(x) = dropdims(mean(x; dims = 2); dims = 2) # utility function for making sure that all layers have a channel size divisible by 8 # used by MobileNet variants function _round_channels(channels::Number, divisor::Integer = 8, min_value::Integer = 0) - new_channels = max(min_value, floor(Int, channels + divisor / 2) ÷ divisor * divisor) - # Make sure that round down does not go down by more than 10% - return new_channels < 0.9 * channels ? new_channels + divisor : new_channels + new_channels = max(min_value, floor(Int, channels + divisor / 2) ÷ divisor * divisor) + # Make sure that round down does not go down by more than 10% + return new_channels < 0.9 * channels ? new_channels + divisor : new_channels end """ - addact(activation = relu, xs...) + addact(activation = relu, xs...) Convenience function for applying an activation function to the output after summing up the input arrays. Useful as the `connection` argument for the block @@ -19,7 +19,7 @@ function in [`resnet`](@ref). addact(activation = relu, xs...) = activation(sum(xs)) """ - actadd(activation = relu, xs...) + actadd(activation = relu, xs...) Convenience function for adding input arrays after applying an activation function to them. Useful as the `connection` argument for the block function in @@ -28,7 +28,7 @@ function to them. Useful as the `connection` argument for the block function in actadd(activation = relu, xs...) = sum(activation.(x) for x in xs) """ - cat_channels(x, y, zs...) + cat_channels(x, y, zs...) Concatenate `x` and `y` (and any `z`s) along the channel dimension (third dimension). Equivalent to `cat(x, y, zs...; dims=3)`. @@ -40,7 +40,7 @@ cat_channels(x::Tuple, y::AbstractArray...) = cat_channels(x..., y...) cat_channels(x::Tuple) = cat_channels(x...) """ - swapdims(perm) + swapdims(perm) Convenience function for permuting the dimensions of an array. `perm` is a vector or tuple specifying a permutation of the input dimensions. @@ -50,20 +50,20 @@ swapdims(perm) = Base.Fix2(permutedims, perm) # Utility function for pretty printing large models function _maybe_big_show(io, model) - if isdefined(Flux, :_big_show) - if isnothing(get(io, :typeinfo, nothing)) # e.g. top level in REPL - Flux._big_show(io, model) - else - show(io, model) - end - else - show(io, model) - end + if isdefined(Flux, :_big_show) + if isnothing(get(io, :typeinfo, nothing)) # e.g. top level in REPL + Flux._big_show(io, model) + else + show(io, model) + end + else + show(io, model) + end end """ - linear_scheduler(drop_prob = 0.0; start_value = 0.0, depth) - linear_scheduler(drop_prob::Nothing; depth::Integer) + linear_scheduler(drop_prob = 0.0; start_value = 0.0, depth) + linear_scheduler(drop_prob::Nothing; depth::Integer) Returns the dropout probabilities for a given depth using the linear scaling rule. 
Note that this returns evenly spaced values between `start_value` and `drop_prob`, not including @@ -71,13 +71,13 @@ that this returns evenly spaced values between `start_value` and `drop_prob`, no values equal to `nothing`. """ function linear_scheduler(drop_prob = 0.0; depth::Integer, start_value = 0.0) - return LinRange(start_value, drop_prob, depth + 1)[1:depth] + return LinRange(start_value, drop_prob, depth + 1)[1:depth] end linear_scheduler(drop_prob::Nothing; depth::Integer) = fill(drop_prob, depth) # Utility function for depth and configuration checks in models function _checkconfig(config, configs) - @assert config in configs "Invalid configuration. Must be one of $(sort(collect(configs)))." + @assert config in configs "Invalid configuration. Must be one of $(sort(collect(configs)))." end # Utility function to return Iterator over layers, adopted from FastAI.jl From 60b082c87108c8ed4de523742a3d073ee648976d Mon Sep 17 00:00:00 2001 From: Anshuman Mishra <51750587+shivance@users.noreply.github.com> Date: Tue, 10 Jan 2023 22:12:38 +0530 Subject: [PATCH 10/29] formatting + documentation + func signature --- src/convnets/unet.jl | 155 +++++++++++++++++++++++++------------------ 1 file changed, 91 insertions(+), 64 deletions(-) diff --git a/src/convnets/unet.jl b/src/convnets/unet.jl index cb698f4ea..44033bce0 100644 --- a/src/convnets/unet.jl +++ b/src/convnets/unet.jl @@ -1,97 +1,124 @@ function PixelShuffleICNR(inplanes, outplanes; r = 2) - return Chain(basic_conv_bn((1, 1), inplanes, outplanes * (r^2)), - Flux.PixelShuffle(r)) + return Chain(basic_conv_bn((1, 1), inplanes, outplanes * (r^2)), + Flux.PixelShuffle(r)) end function UNetCombineLayer(inplanes, outplanes) - return Chain(basic_conv_bn((3, 3), inplanes, outplanes; pad = 1), - basic_conv_bn((3, 3), outplanes, outplanes; pad = 1)) + return Chain(basic_conv_bn((3, 3), inplanes, outplanes; pad = 1), + basic_conv_bn((3, 3), outplanes, outplanes; pad = 1)) end function UNetMiddleBlock(inplanes) - return Chain(basic_conv_bn((3, 3), inplanes, 2inplanes; pad = 1), - basic_conv_bn((3, 3), 2inplanes, inplanes; pad = 1)) + return Chain(basic_conv_bn((3, 3), inplanes, 2 * inplanes; pad = 1), + basic_conv_bn((3, 3), 2 * inplanes, inplanes; pad = 1)) end function UNetFinalBlock(inplanes, outplanes) - return Chain(basicblock(inplanes, inplanes; reduction_factor = 1), - basic_conv_bn((1, 1), inplanes, outplanes)) + return Chain(basicblock(inplanes, inplanes; reduction_factor = 1), + basic_conv_bn((1, 1), inplanes, outplanes)) end function unetlayers(layers, sz; outplanes = nothing, skip_upscale = 0, - m_middle = _ -> (identity,)) - isempty(layers) && return m_middle(sz[end - 1]) - - layer, layers = layers[1], layers[2:end] - outsz = Flux.outputsize(layer, sz) - does_downscale = sz[1] ÷ 2 == outsz[1] - - if !does_downscale - return Chain(layer, unetlayers(layers, outsz; outplanes, skip_upscale)...) - elseif does_downscale && skip_upscale > 0 - return Chain(layer, - unetlayers(layers, outsz; skip_upscale = skip_upscale - 1, - outplanes)...) - else - childunet = Chain(unetlayers(layers, outsz; skip_upscale)...) - outsz = Flux.outputsize(childunet, outsz) - - inplanes = sz[end - 1] - midplanes = outsz[end - 1] - outplanes = isnothing(outplanes) ? 
inplanes : outplanes - - return UNetBlock(Chain(layer, childunet), - inplanes, midplanes, outplanes) - end + m_middle = _ -> (identity,)) + isempty(layers) && return m_middle(sz[end-1]) + + layer, layers = layers[1], layers[2:end] + outsz = Flux.outputsize(layer, sz) + does_downscale = sz[1] ÷ 2 == outsz[1] + + if !does_downscale + return Chain(layer, unetlayers(layers, outsz; outplanes, skip_upscale)...) + elseif does_downscale && skip_upscale > 0 + return Chain(layer, + unetlayers(layers, outsz; skip_upscale = skip_upscale - 1, + outplanes)...) + else + childunet = Chain(unetlayers(layers, outsz; skip_upscale)...) + outsz = Flux.outputsize(childunet, outsz) + + inplanes = sz[end-1] + midplanes = outsz[end-1] + outplanes = isnothing(outplanes) ? inplanes : outplanes + + return UNetBlock(Chain(layer, childunet), + inplanes, midplanes, outplanes) + end +end + +function UNetBlock(m_child, inplanes, midplanes, outplanes = 2 * inplanes) + return Chain(SkipConnection(Chain(m_child, + PixelShuffleICNR(midplanes, midplanes)), + Parallel(cat_channels, identity, BatchNorm(inplanes))), + xs -> relu.(xs), + UNetCombineLayer(inplanes + midplanes, outplanes)) end -function UNetBlock(m_child, inplanes, midplanes, outplanes = 2inplanes) - return Chain(; - upsample = SkipConnection(Chain(; child = m_child, - upsample = PixelShuffleICNR(midplanes, - midplanes)), - Parallel(cat_channels, identity, - BatchNorm(inplanes))), - act = xs -> relu.(xs), - combine = UNetCombineLayer(inplanes + midplanes, outplanes)) +""" + unet(backbone::Vector{Any}; inputsize, outplanes::Integer = 3, + final::Any = UNetFinalBlock, fdownscale::Integer = 0, kwargs...) + +Creates a UNet model with specified backbone. Backbone of Any Metalhead model +can be used as encoder. +([reference](https://arxiv.org/abs/1505.04597)). + +# Arguments + + - `backbone`: The backbone layers to be used in encoder. + For example, `Metalhead.backbone(Metalhead.ResNet(18))` can be passed + to instantiate a UNet with layers of resnet18 as encoder. + - `inputsize`: size of input image + - `outplanes`: number of output feature planes + - `final`: final block as described in original paper + - `fdownscale`: downscale factor +""" +function unet(backbone::Vector{Any}, inputsize, outplanes::Integer, + final::Any = UNetFinalBlock, fdownscale::Integer = 0, kwargs...) + + backbonelayers = collect(iterlayers(backbone)) + layers = unetlayers(backbonelayers, inputsize; m_middle = UNetMiddleBlock, + skip_upscale = fdownscale, kwargs...) + + outsz = Flux.outputsize(layers, inputsize) + layers = Chain(layers, final(outsz[end-1], outplanes)) + + return layers end """ - UNet(backbone, inputsize) - DenseNet(transition_configs::NTuple{N,Integer}) + UNet(backbone::Vector{Any}; pretrain::Bool = false, inputsize::NTuple{4, Integer}, + outchannels::Integer = 3) Creates a UNet model with specified backbone. Backbone of Any Metalhead model can be used as -encoder . -Set `pretrain = true` to load the model with pre-trained weights for ImageNet. +encoder. +([reference](https://arxiv.org/abs/1505.04597)). + +# Arguments + + - `backbone`: The backbone layers to be used in encoder. + For example, `Metalhead.backbone(Metalhead.ResNet(18))` can be passed to instantiate a UNet with layers of + resnet18 as encoder. + - `pretrain`: Whether to load the pre-trained weights for ImageNet + - `inputsize`: size of input image + - `outchannels`: number of output channels. !!! warning - `UNet` does not currently support pretrained weights. + `UNet` does not currently support pretrained weights. 
See also [`Metalhead.UNet`](@ref). """ struct UNet - layers::Any + layers::Any end @functor UNet -function UNet(backbone, - inputsize, - outplanes, - final = UNetFinalBlock, - fdownscale::Integer = 0, - kwargs...) - backbonelayers = collect(iterlayers(backbone)) - layers = unetlayers(backbonelayers, - inputsize; - m_middle = UNetMiddleBlock, - skip_upscale = fdownscale, - kwargs...) - - outsz = Flux.outputsize(layers, inputsize) - layers = Chain(layers, final(outsz[end - 1], outplanes)) - - return UNet(layers) +function UNet(backbone::Vector{Any}; pretrain::Bool = false, inputsize::NTuple{4, Integer}, + outchannels::Integer = 3) + layers = unet(backbone, inputsize, outchannels) + if pretrain + loadpretrain!(layers, string("UNet")) + end + return UNet(layers) end (m::UNet)(x::AbstractArray) = m.layers(x) From 2f1cc6dedbb371d9dfd8dc42593ca5405e7fddc4 Mon Sep 17 00:00:00 2001 From: Anshuman Mishra <51750587+shivance@users.noreply.github.com> Date: Tue, 10 Jan 2023 22:24:59 +0530 Subject: [PATCH 11/29] adding unit tests for unet --- test/convnets.jl | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/test/convnets.jl b/test/convnets.jl index af15ff413..3ca3fbcec 100644 --- a/test/convnets.jl +++ b/test/convnets.jl @@ -347,3 +347,12 @@ end _gc() end end + +@testset "UNet" begin + encoder = backbone(ResNet(18)) + model = UNet(encoder; inputsize = (256, 256, 3, 1), outplanes = 10) + @test size(model(x_256)) == (256, 256, 10, 1) + @test_throws ArgumentError AlexNet(pretrain = true) + @test gradtest(model, x_256) + _gc() +end From 8d2ba2b8bb81fc2580076c2b5930e24342d82261 Mon Sep 17 00:00:00 2001 From: Anshuman Mishra <51750587+shivance@users.noreply.github.com> Date: Tue, 10 Jan 2023 22:31:18 +0530 Subject: [PATCH 12/29] configuring CI --- .github/workflows/CI.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 353819673..bcaa9add0 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -33,6 +33,7 @@ jobs: - x64 suite: - '["AlexNet", "VGG"]' + - '["UNet"]' - '["GoogLeNet", "SqueezeNet", "MobileNets"]' - '"EfficientNet"' - 'r"/*/ResNet*"' From 77a314805d6422fdb4ef921d70321a4c7b53e28a Mon Sep 17 00:00:00 2001 From: Anshuman Mishra <51750587+shivance@users.noreply.github.com> Date: Tue, 10 Jan 2023 22:32:06 +0530 Subject: [PATCH 13/29] configuring CI --- .github/workflows/CI.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index bcaa9add0..59ac97710 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -33,14 +33,13 @@ jobs: - x64 suite: - '["AlexNet", "VGG"]' - - '["UNet"]' - '["GoogLeNet", "SqueezeNet", "MobileNets"]' - '"EfficientNet"' - 'r"/*/ResNet*"' - 'r"/*/SEResNet*"' - '[r"Res2Net", r"Res2NeXt"]' - '"Inception"' - - '"DenseNet"' + - '["UNet", "DenseNet"]' - '["ConvNeXt", "ConvMixer"]' - 'r"Mixers"' - 'r"ViTs"' From 429096bb93cf7b03b665cb722b545bebd769823c Mon Sep 17 00:00:00 2001 From: Anshuman Mishra <51750587+shivance@users.noreply.github.com> Date: Tue, 10 Jan 2023 22:55:29 +0530 Subject: [PATCH 14/29] Update convnets.jl --- test/convnets.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/test/convnets.jl b/test/convnets.jl index 3ca3fbcec..53c656851 100644 --- a/test/convnets.jl +++ b/test/convnets.jl @@ -352,7 +352,6 @@ end encoder = backbone(ResNet(18)) model = UNet(encoder; inputsize = (256, 256, 3, 1), outplanes = 10) @test size(model(x_256)) == (256, 256, 10, 1) - @test_throws ArgumentError AlexNet(pretrain = 
true) @test gradtest(model, x_256) _gc() end From d761126ce13dd91d0c2e659566a7ce7642d6eebf Mon Sep 17 00:00:00 2001 From: Anshuman Mishra <51750587+shivance@users.noreply.github.com> Date: Tue, 10 Jan 2023 23:36:21 +0530 Subject: [PATCH 15/29] Update convnets.jl --- test/convnets.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/convnets.jl b/test/convnets.jl index 53c656851..a7051a571 100644 --- a/test/convnets.jl +++ b/test/convnets.jl @@ -349,7 +349,7 @@ end end @testset "UNet" begin - encoder = backbone(ResNet(18)) + encoder = Metalhead.backbone(ResNet(18)) model = UNet(encoder; inputsize = (256, 256, 3, 1), outplanes = 10) @test size(model(x_256)) == (256, 256, 10, 1) @test gradtest(model, x_256) From 1b5d2b7994b1321b85ac926acff750123f5819ec Mon Sep 17 00:00:00 2001 From: Anshuman Mishra <51750587+shivance@users.noreply.github.com> Date: Wed, 11 Jan 2023 12:04:28 +0530 Subject: [PATCH 16/29] updated test --- src/convnets/unet.jl | 139 +++++++++++++++++++++---------------------- test/convnets.jl | 2 +- 2 files changed, 70 insertions(+), 71 deletions(-) diff --git a/src/convnets/unet.jl b/src/convnets/unet.jl index 44033bce0..b42aa9bbd 100644 --- a/src/convnets/unet.jl +++ b/src/convnets/unet.jl @@ -1,124 +1,123 @@ function PixelShuffleICNR(inplanes, outplanes; r = 2) - return Chain(basic_conv_bn((1, 1), inplanes, outplanes * (r^2)), - Flux.PixelShuffle(r)) + return Chain(basic_conv_bn((1, 1), inplanes, outplanes * (r^2)), + Flux.PixelShuffle(r)) end function UNetCombineLayer(inplanes, outplanes) - return Chain(basic_conv_bn((3, 3), inplanes, outplanes; pad = 1), - basic_conv_bn((3, 3), outplanes, outplanes; pad = 1)) + return Chain(basic_conv_bn((3, 3), inplanes, outplanes; pad = 1), + basic_conv_bn((3, 3), outplanes, outplanes; pad = 1)) end function UNetMiddleBlock(inplanes) - return Chain(basic_conv_bn((3, 3), inplanes, 2 * inplanes; pad = 1), - basic_conv_bn((3, 3), 2 * inplanes, inplanes; pad = 1)) + return Chain(basic_conv_bn((3, 3), inplanes, 2 * inplanes; pad = 1), + basic_conv_bn((3, 3), 2 * inplanes, inplanes; pad = 1)) end function UNetFinalBlock(inplanes, outplanes) - return Chain(basicblock(inplanes, inplanes; reduction_factor = 1), - basic_conv_bn((1, 1), inplanes, outplanes)) + return Chain(basicblock(inplanes, inplanes; reduction_factor = 1), + basic_conv_bn((1, 1), inplanes, outplanes)) end function unetlayers(layers, sz; outplanes = nothing, skip_upscale = 0, - m_middle = _ -> (identity,)) - isempty(layers) && return m_middle(sz[end-1]) - - layer, layers = layers[1], layers[2:end] - outsz = Flux.outputsize(layer, sz) - does_downscale = sz[1] ÷ 2 == outsz[1] - - if !does_downscale - return Chain(layer, unetlayers(layers, outsz; outplanes, skip_upscale)...) - elseif does_downscale && skip_upscale > 0 - return Chain(layer, - unetlayers(layers, outsz; skip_upscale = skip_upscale - 1, - outplanes)...) - else - childunet = Chain(unetlayers(layers, outsz; skip_upscale)...) - outsz = Flux.outputsize(childunet, outsz) - - inplanes = sz[end-1] - midplanes = outsz[end-1] - outplanes = isnothing(outplanes) ? inplanes : outplanes - - return UNetBlock(Chain(layer, childunet), - inplanes, midplanes, outplanes) - end + m_middle = _ -> (identity,)) + isempty(layers) && return m_middle(sz[end - 1]) + + layer, layers = layers[1], layers[2:end] + outsz = Flux.outputsize(layer, sz) + does_downscale = sz[1] ÷ 2 == outsz[1] + + if !does_downscale + return Chain(layer, unetlayers(layers, outsz; outplanes, skip_upscale)...) 
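+    # while `skip_upscale` is positive, a downscaling layer is consumed without
+    # opening a skip connection (this is how `fdownscale` trims the decoder path)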
+ elseif does_downscale && skip_upscale > 0 + return Chain(layer, + unetlayers(layers, outsz; skip_upscale = skip_upscale - 1, + outplanes)...) + else + childunet = Chain(unetlayers(layers, outsz; skip_upscale)...) + outsz = Flux.outputsize(childunet, outsz) + + inplanes = sz[end - 1] + midplanes = outsz[end - 1] + outplanes = isnothing(outplanes) ? inplanes : outplanes + + return UNetBlock(Chain(layer, childunet), + inplanes, midplanes, outplanes) + end end function UNetBlock(m_child, inplanes, midplanes, outplanes = 2 * inplanes) - return Chain(SkipConnection(Chain(m_child, - PixelShuffleICNR(midplanes, midplanes)), - Parallel(cat_channels, identity, BatchNorm(inplanes))), - xs -> relu.(xs), - UNetCombineLayer(inplanes + midplanes, outplanes)) + return Chain(SkipConnection(Chain(m_child, + PixelShuffleICNR(midplanes, midplanes)), + Parallel(cat_channels, identity, BatchNorm(inplanes))), + xs -> relu.(xs), + UNetCombineLayer(inplanes + midplanes, outplanes)) end """ - unet(backbone::Vector{Any}; inputsize, outplanes::Integer = 3, - final::Any = UNetFinalBlock, fdownscale::Integer = 0, kwargs...) + unet(backbone::Vector{Any}; inputsize, outplanes::Integer = 3, + final::Any = UNetFinalBlock, fdownscale::Integer = 0, kwargs...) -Creates a UNet model with specified backbone. Backbone of Any Metalhead model -can be used as encoder. +Creates a UNet model with specified backbone. Backbone of Any Metalhead model +can be used as encoder. ([reference](https://arxiv.org/abs/1505.04597)). # Arguments - - `backbone`: The backbone layers to be used in encoder. - For example, `Metalhead.backbone(Metalhead.ResNet(18))` can be passed - to instantiate a UNet with layers of resnet18 as encoder. - - `inputsize`: size of input image - - `outplanes`: number of output feature planes - - `final`: final block as described in original paper - - `fdownscale`: downscale factor + - `backbone`: The backbone layers to be used in encoder. + For example, `Metalhead.backbone(Metalhead.ResNet(18))` can be passed + to instantiate a UNet with layers of resnet18 as encoder. + - `inputsize`: size of input image + - `outplanes`: number of output feature planes + - `final`: final block as described in original paper + - `fdownscale`: downscale factor """ function unet(backbone::Vector{Any}, inputsize, outplanes::Integer, - final::Any = UNetFinalBlock, fdownscale::Integer = 0, kwargs...) + final::Any = UNetFinalBlock, fdownscale::Integer = 0, kwargs...) + backbonelayers = collect(iterlayers(backbone)) + layers = unetlayers(backbonelayers, inputsize; m_middle = UNetMiddleBlock, + skip_upscale = fdownscale, kwargs...) - backbonelayers = collect(iterlayers(backbone)) - layers = unetlayers(backbonelayers, inputsize; m_middle = UNetMiddleBlock, - skip_upscale = fdownscale, kwargs...) + outsz = Flux.outputsize(layers, inputsize) + layers = Chain(layers, final(outsz[end - 1], outplanes)) - outsz = Flux.outputsize(layers, inputsize) - layers = Chain(layers, final(outsz[end-1], outplanes)) - - return layers + return layers end """ - UNet(backbone::Vector{Any}; pretrain::Bool = false, inputsize::NTuple{4, Integer}, - outchannels::Integer = 3) + UNet(backbone::Vector{Any}; pretrain::Bool = false, inputsize::NTuple{4, Integer}, + outchannels::Integer = 3) Creates a UNet model with specified backbone. Backbone of Any Metalhead model can be used as -encoder. +encoder. ([reference](https://arxiv.org/abs/1505.04597)). # Arguments - - `backbone`: The backbone layers to be used in encoder. 
- For example, `Metalhead.backbone(Metalhead.ResNet(18))` can be passed to instantiate a UNet with layers of - resnet18 as encoder. - - `pretrain`: Whether to load the pre-trained weights for ImageNet - - `inputsize`: size of input image - - `outchannels`: number of output channels. + - `backbone`: The backbone layers to be used in encoder. + For example, `Metalhead.backbone(Metalhead.ResNet(18))` can be passed to instantiate a UNet with layers of + resnet18 as encoder. + - `pretrain`: Whether to load the pre-trained weights for ImageNet + - `inputsize`: size of input image + - `outchannels`: number of output channels. !!! warning - `UNet` does not currently support pretrained weights. + `UNet` does not currently support pretrained weights. See also [`Metalhead.UNet`](@ref). """ struct UNet - layers::Any + layers::Any end @functor UNet function UNet(backbone::Vector{Any}; pretrain::Bool = false, inputsize::NTuple{4, Integer}, - outchannels::Integer = 3) - layers = unet(backbone, inputsize, outchannels) + outchannels::Integer = 3) + layers = unet(backbone, inputsize, outchannels) if pretrain loadpretrain!(layers, string("UNet")) end - return UNet(layers) + return UNet(layers) end (m::UNet)(x::AbstractArray) = m.layers(x) diff --git a/test/convnets.jl b/test/convnets.jl index a7051a571..81da6517f 100644 --- a/test/convnets.jl +++ b/test/convnets.jl @@ -350,7 +350,7 @@ end @testset "UNet" begin encoder = Metalhead.backbone(ResNet(18)) - model = UNet(encoder; inputsize = (256, 256, 3, 1), outplanes = 10) + model = UNet(encoder; inputsize = (256, 256, 3, 1), outchannels = 10) @test size(model(x_256)) == (256, 256, 10, 1) @test gradtest(model, x_256) _gc() From 354e3c457497347425e2ce9ef4884fabc1719afd Mon Sep 17 00:00:00 2001 From: Anshuman Mishra <51750587+shivance@users.noreply.github.com> Date: Thu, 12 Jan 2023 10:28:57 +0530 Subject: [PATCH 17/29] minor fixes --- src/convnets/unet.jl | 6 +++--- test/convnets.jl | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/convnets/unet.jl b/src/convnets/unet.jl index b42aa9bbd..ef7abc1ab 100644 --- a/src/convnets/unet.jl +++ b/src/convnets/unet.jl @@ -111,9 +111,9 @@ struct UNet end @functor UNet -function UNet(backbone::Vector{Any}; pretrain::Bool = false, inputsize::NTuple{4, Integer}, - outchannels::Integer = 3) - layers = unet(backbone, inputsize, outchannels) +function UNet(backbone, inputsize::NTuple{4, Integer}, outplanes::Integer = 3; + pretrain::Bool = false) + layers = unet(backbone, inputsize, outplanes) if pretrain loadpretrain!(layers, string("UNet")) end diff --git a/test/convnets.jl b/test/convnets.jl index 81da6517f..d29bccc8f 100644 --- a/test/convnets.jl +++ b/test/convnets.jl @@ -350,7 +350,7 @@ end @testset "UNet" begin encoder = Metalhead.backbone(ResNet(18)) - model = UNet(encoder; inputsize = (256, 256, 3, 1), outchannels = 10) + model = UNet(encoder, (256, 256, 3, 1), 10) @test size(model(x_256)) == (256, 256, 10, 1) @test gradtest(model, x_256) _gc() From 6494be71bdeaf9e2fb9c0d07fe4de30454117fb8 Mon Sep 17 00:00:00 2001 From: Anshuman Mishra <51750587+shivance@users.noreply.github.com> Date: Thu, 12 Jan 2023 11:17:41 +0530 Subject: [PATCH 18/29] typing fix --- src/convnets/unet.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/convnets/unet.jl b/src/convnets/unet.jl index ef7abc1ab..2e77404af 100644 --- a/src/convnets/unet.jl +++ b/src/convnets/unet.jl @@ -54,7 +54,7 @@ function UNetBlock(m_child, inplanes, midplanes, outplanes = 2 * inplanes) end """ - 
unet(backbone::Vector{Any}; inputsize, outplanes::Integer = 3, + unet(backbone; inputsize::NTuple{4, Integer}, outplanes::Integer = 3, final::Any = UNetFinalBlock, fdownscale::Integer = 0, kwargs...) Creates a UNet model with specified backbone. Backbone of Any Metalhead model @@ -71,7 +71,7 @@ can be used as encoder. - `final`: final block as described in original paper - `fdownscale`: downscale factor """ -function unet(backbone::Vector{Any}, inputsize, outplanes::Integer, +function unet(backbone, inputsize::NTuple{4, Integer}, outplanes::Integer, final::Any = UNetFinalBlock, fdownscale::Integer = 0, kwargs...) backbonelayers = collect(iterlayers(backbone)) layers = unetlayers(backbonelayers, inputsize; m_middle = UNetMiddleBlock, @@ -84,8 +84,8 @@ function unet(backbone::Vector{Any}, inputsize, outplanes::Integer, end """ - UNet(backbone::Vector{Any}; pretrain::Bool = false, inputsize::NTuple{4, Integer}, - outchannels::Integer = 3) + UNet(backbone, inputsize::NTuple{4, Integer}, outplanes::Integer = 3; + pretrain::Bool = false) Creates a UNet model with specified backbone. Backbone of Any Metalhead model can be used as encoder. @@ -96,9 +96,9 @@ encoder. - `backbone`: The backbone layers to be used in encoder. For example, `Metalhead.backbone(Metalhead.ResNet(18))` can be passed to instantiate a UNet with layers of resnet18 as encoder. - - `pretrain`: Whether to load the pre-trained weights for ImageNet - `inputsize`: size of input image - - `outchannels`: number of output channels. + - `outplanes`: number of output feature planes. + - `pretrain`: Whether to load the pre-trained weights for ImageNet !!! warning From 2d68f61c8574cb80e9b661cd3ac64f0ad47961dc Mon Sep 17 00:00:00 2001 From: Anshuman Mishra <51750587+shivance@users.noreply.github.com> Date: Thu, 12 Jan 2023 12:52:59 +0530 Subject: [PATCH 19/29] Update src/utilities.jl Co-authored-by: Priya Nagda <64613009+pri1311@users.noreply.github.com> --- src/utilities.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/utilities.jl b/src/utilities.jl index f52d1f97e..63b33ae8d 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -82,3 +82,4 @@ end # Utility function to return Iterator over layers, adopted from FastAI.jl iterlayers(m::Chain) = Iterators.flatten(iterlayers(l) for l in m.layers) +iterlayers(m) = (m,) From 627480f8ab76f309e38bd86d3277657a6d42d8cb Mon Sep 17 00:00:00 2001 From: Anshuman Mishra <51750587+shivance@users.noreply.github.com> Date: Thu, 12 Jan 2023 17:17:06 +0530 Subject: [PATCH 20/29] fixing ci --- src/convnets/unet.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/convnets/unet.jl b/src/convnets/unet.jl index 2e77404af..bdc5ec952 100644 --- a/src/convnets/unet.jl +++ b/src/convnets/unet.jl @@ -1,21 +1,21 @@ function PixelShuffleICNR(inplanes, outplanes; r = 2) - return Chain(basic_conv_bn((1, 1), inplanes, outplanes * (r^2)), + return Chain(Chain(basic_conv_bn((1, 1), inplanes, outplanes * (r^2)...)), Flux.PixelShuffle(r)) end function UNetCombineLayer(inplanes, outplanes) - return Chain(basic_conv_bn((3, 3), inplanes, outplanes; pad = 1), - basic_conv_bn((3, 3), outplanes, outplanes; pad = 1)) + return Chain(Chain(basic_conv_bn((3, 3), inplanes, outplanes; pad = 1)...), + Chain(basic_conv_bn((3, 3), outplanes, outplanes; pad = 1)...)) end function UNetMiddleBlock(inplanes) - return Chain(basic_conv_bn((3, 3), inplanes, 2 * inplanes; pad = 1), - basic_conv_bn((3, 3), 2 * inplanes, inplanes; pad = 1)) + return Chain(Chain(basic_conv_bn((3, 3), inplanes, 2 * inplanes; 
pad = 1)...), + Chain(basic_conv_bn((3, 3), 2 * inplanes, inplanes; pad = 1)...)) end function UNetFinalBlock(inplanes, outplanes) return Chain(basicblock(inplanes, inplanes; reduction_factor = 1), - basic_conv_bn((1, 1), inplanes, outplanes)) + Chain(basic_conv_bn((1, 1), inplanes, outplanes)...)) end function unetlayers(layers, sz; outplanes = nothing, skip_upscale = 0, From 4012fb2847a6689a24438fc80fc1e043177892e4 Mon Sep 17 00:00:00 2001 From: Anshuman Mishra <51750587+shivance@users.noreply.github.com> Date: Mon, 16 Jan 2023 23:43:56 +0530 Subject: [PATCH 21/29] renaming: --- src/convnets/unet.jl | 47 ++++++++++++++++++++++---------------------- src/utilities.jl | 18 ++++++++++++++--- 2 files changed, 39 insertions(+), 26 deletions(-) diff --git a/src/convnets/unet.jl b/src/convnets/unet.jl index bdc5ec952..ef9965f83 100644 --- a/src/convnets/unet.jl +++ b/src/convnets/unet.jl @@ -1,19 +1,19 @@ -function PixelShuffleICNR(inplanes, outplanes; r = 2) +function pixel_shuffle_icnr(inplanes, outplanes; r = 2) return Chain(Chain(basic_conv_bn((1, 1), inplanes, outplanes * (r^2)...)), Flux.PixelShuffle(r)) end -function UNetCombineLayer(inplanes, outplanes) +function unet_combine_layer(inplanes, outplanes) return Chain(Chain(basic_conv_bn((3, 3), inplanes, outplanes; pad = 1)...), Chain(basic_conv_bn((3, 3), outplanes, outplanes; pad = 1)...)) end -function UNetMiddleBlock(inplanes) +function unet_middle_block(inplanes) return Chain(Chain(basic_conv_bn((3, 3), inplanes, 2 * inplanes; pad = 1)...), Chain(basic_conv_bn((3, 3), 2 * inplanes, inplanes; pad = 1)...)) end -function UNetFinalBlock(inplanes, outplanes) +function unet_final_block(inplanes, outplanes) return Chain(basicblock(inplanes, inplanes; reduction_factor = 1), Chain(basic_conv_bn((1, 1), inplanes, outplanes)...)) end @@ -40,22 +40,22 @@ function unetlayers(layers, sz; outplanes = nothing, skip_upscale = 0, midplanes = outsz[end - 1] outplanes = isnothing(outplanes) ? inplanes : outplanes - return UNetBlock(Chain(layer, childunet), - inplanes, midplanes, outplanes) + return unet_block(Chain(layer, childunet), + inplanes, midplanes, outplanes) end end -function UNetBlock(m_child, inplanes, midplanes, outplanes = 2 * inplanes) +function unet_block(m_child, inplanes, midplanes, outplanes = 2 * inplanes) return Chain(SkipConnection(Chain(m_child, - PixelShuffleICNR(midplanes, midplanes)), + pixel_shuffle_icnr(midplanes, midplanes)), Parallel(cat_channels, identity, BatchNorm(inplanes))), - xs -> relu.(xs), - UNetCombineLayer(inplanes + midplanes, outplanes)) + relu, + unet_combine_layer(inplanes + midplanes, outplanes)) end """ - unet(backbone; inputsize::NTuple{4, Integer}, outplanes::Integer = 3, - final::Any = UNetFinalBlock, fdownscale::Integer = 0, kwargs...) + build_unet(backbone, imgdims, outplanes::Integer, + final::Any = unet_final_block, fdownscale::Integer = 0, kwargs...) Creates a UNet model with specified backbone. Backbone of Any Metalhead model can be used as encoder. @@ -71,21 +71,21 @@ can be used as encoder. - `final`: final block as described in original paper - `fdownscale`: downscale factor """ -function unet(backbone, inputsize::NTuple{4, Integer}, outplanes::Integer, - final::Any = UNetFinalBlock, fdownscale::Integer = 0, kwargs...) - backbonelayers = collect(iterlayers(backbone)) - layers = unetlayers(backbonelayers, inputsize; m_middle = UNetMiddleBlock, +function build_unet(backbone, imgdims, outplanes::Integer, + final::Any = unet_final_block, fdownscale::Integer = 0, kwargs...) 
+ backbonelayers = collect(flatten_chains(backbone)) + layers = unetlayers(backbonelayers, imgdims; m_middle = unet_middle_block, skip_upscale = fdownscale, kwargs...) - outsz = Flux.outputsize(layers, inputsize) + outsz = Flux.outputsize(layers, imgdims) layers = Chain(layers, final(outsz[end - 1], outplanes)) return layers end """ - UNet(backbone, inputsize::NTuple{4, Integer}, outplanes::Integer = 3; - pretrain::Bool = false) +UNet(backbone = Metalhead.backbone(DenseNet(122)), imsize::Dims{2} = (256, 256), +inchannels::Integer = 3, outplanes::Integer = 3; pretrain::Bool = false) Creates a UNet model with specified backbone. Backbone of Any Metalhead model can be used as encoder. @@ -96,7 +96,8 @@ encoder. - `backbone`: The backbone layers to be used in encoder. For example, `Metalhead.backbone(Metalhead.ResNet(18))` can be passed to instantiate a UNet with layers of resnet18 as encoder. - - `inputsize`: size of input image + - `imsize`: size of input image + - `inchannels`: number of channels in input image - `outplanes`: number of output feature planes. - `pretrain`: Whether to load the pre-trained weights for ImageNet @@ -111,9 +112,9 @@ struct UNet end @functor UNet -function UNet(backbone, inputsize::NTuple{4, Integer}, outplanes::Integer = 3; - pretrain::Bool = false) - layers = unet(backbone, inputsize, outplanes) +function UNet(backbone = Metalhead.backbone(DenseNet(122)), imsize::Dims{2} = (256, 256), + inchannels::Integer = 3, outplanes::Integer = 3; pretrain::Bool = false) + layers = build_unet(backbone, (imsize..., inchannels, 1), outplanes) if pretrain loadpretrain!(layers, string("UNet")) end diff --git a/src/utilities.jl b/src/utilities.jl index 63b33ae8d..023cb699b 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -80,6 +80,18 @@ function _checkconfig(config, configs) @assert config in configs "Invalid configuration. Must be one of $(sort(collect(configs)))." end -# Utility function to return Iterator over layers, adopted from FastAI.jl -iterlayers(m::Chain) = Iterators.flatten(iterlayers(l) for l in m.layers) -iterlayers(m) = (m,) + +# adopted from FastAI.jl +""" +The flatten_chains function takes in a single argument m, which can be of `Any` type. +If the input m is of type `Chain`, the function returns an iterator that recursively flattens +the layers of the input Chain object and all its sub-layers. +It does this by calling the flatten_chains function on each element of m.layers and +passing it into the Iterators.flatten function, which flatten the layers into a single iterator. +If the input m is not of type Chain, the function returns a tuple containing the single input m. +This function can be useful when you want to traverse nested layers of a Chain object and flatten +them into a single iterator. 
+""" +flatten_chains(m::Chain) = Iterators.flatten(flatten_chains(l) for l in m.layers) +flatten_chains(m) = (m,) + From 016cef4fc13cb47e6276ba53c22b341e76886043 Mon Sep 17 00:00:00 2001 From: Anshuman Mishra <51750587+shivance@users.noreply.github.com> Date: Sun, 22 Jan 2023 12:57:07 +0530 Subject: [PATCH 22/29] fixing test --- src/convnets/unet.jl | 35 ++-- test/convnets.jl | 380 +++++++++++++++++++++---------------------- 2 files changed, 204 insertions(+), 211 deletions(-) diff --git a/src/convnets/unet.jl b/src/convnets/unet.jl index ef9965f83..f582d6a0a 100644 --- a/src/convnets/unet.jl +++ b/src/convnets/unet.jl @@ -18,6 +18,14 @@ function unet_final_block(inplanes, outplanes) Chain(basic_conv_bn((1, 1), inplanes, outplanes)...)) end +function unet_block(m_child, inplanes, midplanes, outplanes = 2 * inplanes) + return Chain(SkipConnection(Chain(m_child, + pixel_shuffle_icnr(midplanes, midplanes)), + Parallel(cat_channels, identity, BatchNorm(inplanes))), + relu, + unet_combine_layer(inplanes + midplanes, outplanes)) +end + function unetlayers(layers, sz; outplanes = nothing, skip_upscale = 0, m_middle = _ -> (identity,)) isempty(layers) && return m_middle(sz[end - 1]) @@ -45,14 +53,6 @@ function unetlayers(layers, sz; outplanes = nothing, skip_upscale = 0, end end -function unet_block(m_child, inplanes, midplanes, outplanes = 2 * inplanes) - return Chain(SkipConnection(Chain(m_child, - pixel_shuffle_icnr(midplanes, midplanes)), - Parallel(cat_channels, identity, BatchNorm(inplanes))), - relu, - unet_combine_layer(inplanes + midplanes, outplanes)) -end - """ build_unet(backbone, imgdims, outplanes::Integer, final::Any = unet_final_block, fdownscale::Integer = 0, kwargs...) @@ -84,21 +84,21 @@ function build_unet(backbone, imgdims, outplanes::Integer, end """ -UNet(backbone = Metalhead.backbone(DenseNet(122)), imsize::Dims{2} = (256, 256), -inchannels::Integer = 3, outplanes::Integer = 3; pretrain::Bool = false) +UNet(imsize::Dims{2} = (256, 256), inchannels::Integer = 3, outplanes::Integer = 3, +backbone = Metalhead.backbone(DenseNet(121)); pretrain::Bool = false) -Creates a UNet model with specified backbone. Backbone of Any Metalhead model can be used as -encoder. +Creates a UNet model with an encoder built of specified backbone. By default it uses DenseNet Backbone, however any of +Metalhead model can be used for encoder backbone ([reference](https://arxiv.org/abs/1505.04597)). # Arguments - - `backbone`: The backbone layers to be used in encoder. - For example, `Metalhead.backbone(Metalhead.ResNet(18))` can be passed to instantiate a UNet with layers of - resnet18 as encoder. - `imsize`: size of input image - `inchannels`: number of channels in input image - `outplanes`: number of output feature planes. + - `backbone`: The backbone layers to be used in encoder. + For example, `Metalhead.backbone(Metalhead.ResNet(18))` can be passed to instantiate a UNet with layers of + resnet18 as encoder. - `pretrain`: Whether to load the pre-trained weights for ImageNet !!! 
warning @@ -112,9 +112,10 @@ struct UNet end @functor UNet -function UNet(backbone = Metalhead.backbone(DenseNet(122)), imsize::Dims{2} = (256, 256), - inchannels::Integer = 3, outplanes::Integer = 3; pretrain::Bool = false) +function UNet(imsize::Dims{2} = (256, 256), inchannels::Integer = 3, outplanes::Integer = 3, + backbone = Metalhead.backbone(DenseNet(121)); pretrain::Bool = false) layers = build_unet(backbone, (imsize..., inchannels, 1), outplanes) + if pretrain loadpretrain!(layers, string("UNet")) end diff --git a/test/convnets.jl b/test/convnets.jl index d29bccc8f..917b4beaf 100644 --- a/test/convnets.jl +++ b/test/convnets.jl @@ -6,19 +6,19 @@ _gc() end -@testset "VGG" begin - @testset "VGG($sz, batchnorm=$bn)" for sz in [11, 13, 16, 19], bn in [true, false] - m = VGG(sz, batchnorm = bn) - @test size(m(x_224)) == (1000, 1) - if (VGG, sz, bn) in PRETRAINED_MODELS - @test acctest(VGG(sz, batchnorm = bn, pretrain = true)) - else - @test_throws ArgumentError VGG(sz, batchnorm = bn, pretrain = true) - end - @test gradtest(m, x_224) - _gc() +@testset "VGG" begin @testset "VGG($sz, batchnorm=$bn)" for sz in [11, 13, 16, 19], + bn in [true, false] + + m = VGG(sz; batchnorm = bn) + @test size(m(x_224)) == (1000, 1) + if (VGG, sz, bn) in PRETRAINED_MODELS + @test acctest(VGG(sz; batchnorm = bn, pretrain = true)) + else + @test_throws ArgumentError VGG(sz, batchnorm = bn, pretrain = true) end -end + @test gradtest(m, x_224) + _gc() +end end @testset "ResNet" begin # Tests for pretrained ResNets @@ -26,100 +26,94 @@ end m = ResNet(sz) @test size(m(x_224)) == (1000, 1) if (ResNet, sz) in PRETRAINED_MODELS - @test acctest(ResNet(sz, pretrain = true)) + @test acctest(ResNet(sz; pretrain = true)) else @test_throws ArgumentError ResNet(sz, pretrain = true) end end - @testset "resnet" begin - @testset for block_fn in [Metalhead.basicblock, Metalhead.bottleneck] - layer_list = [ - [2, 2, 2, 2], - [3, 4, 6, 3], - [3, 4, 23, 3], - [3, 8, 36, 3] + @testset "resnet" begin @testset for block_fn in [ + Metalhead.basicblock, + Metalhead.bottleneck, + ] + layer_list = [ + [2, 2, 2, 2], + [3, 4, 6, 3], + [3, 4, 23, 3], + [3, 8, 36, 3], + ] + @testset for layers in layer_list + drop_list = [ + (dropout_prob = 0.1, stochastic_depth_prob = 0.1, dropblock_prob = 0.1), + (dropout_prob = 0.5, stochastic_depth_prob = 0.5, dropblock_prob = 0.5), + (dropout_prob = 0.8, stochastic_depth_prob = 0.8, dropblock_prob = 0.8), ] - @testset for layers in layer_list - drop_list = [ - (dropout_prob = 0.1, stochastic_depth_prob = 0.1, dropblock_prob = 0.1), - (dropout_prob = 0.5, stochastic_depth_prob = 0.5, dropblock_prob = 0.5), - (dropout_prob = 0.8, stochastic_depth_prob = 0.8, dropblock_prob = 0.8), - ] - @testset for drop_probs in drop_list - m = Metalhead.resnet(block_fn, layers; drop_probs...) - @test size(m(x_224)) == (1000, 1) - @test gradtest(m, x_224) - _gc() - end + @testset for drop_probs in drop_list + m = Metalhead.resnet(block_fn, layers; drop_probs...) 
+ @test size(m(x_224)) == (1000, 1) + @test gradtest(m, x_224) + _gc() end end - end + end end - @testset "WideResNet" begin - @testset "WideResNet($sz)" for sz in [50, 101] - m = WideResNet(sz) - @test size(m(x_224)) == (1000, 1) - @test gradtest(m, x_224) - _gc() - if (WideResNet, sz) in PRETRAINED_MODELS - @test acctest(ResNet(sz, pretrain = true)) - else - @test_throws ArgumentError WideResNet(sz, pretrain = true) - end + @testset "WideResNet" begin @testset "WideResNet($sz)" for sz in [50, 101] + m = WideResNet(sz) + @test size(m(x_224)) == (1000, 1) + @test gradtest(m, x_224) + _gc() + if (WideResNet, sz) in PRETRAINED_MODELS + @test acctest(ResNet(sz; pretrain = true)) + else + @test_throws ArgumentError WideResNet(sz, pretrain = true) end - end + end end end -@testset "ResNeXt" begin - @testset for depth in [50, 101, 152] - @testset for cardinality in [32, 64] - @testset for base_width in [4, 8] - m = ResNeXt(depth; cardinality, base_width) - @test size(m(x_224)) == (1000, 1) - if (ResNeXt, depth, cardinality, base_width) in PRETRAINED_MODELS - @test acctest(ResNeXt(depth; cardinality, base_width, pretrain = true)) - else - @test_throws ArgumentError ResNeXt(depth; cardinality, base_width, pretrain = true) - end - @test gradtest(m, x_224) - _gc() +@testset "ResNeXt" begin @testset for depth in [50, 101, 152] + @testset for cardinality in [32, 64] + @testset for base_width in [4, 8] + m = ResNeXt(depth; cardinality, base_width) + @test size(m(x_224)) == (1000, 1) + if (ResNeXt, depth, cardinality, base_width) in PRETRAINED_MODELS + @test acctest(ResNeXt(depth; cardinality, base_width, pretrain = true)) + else + @test_throws ArgumentError ResNeXt(depth; cardinality, base_width, + pretrain = true) end + @test gradtest(m, x_224) + _gc() end end -end +end end -@testset "SEResNet" begin - @testset for depth in [18, 34, 50, 101, 152] - m = SEResNet(depth) - @test size(m(x_224)) == (1000, 1) - if (SEResNet, depth) in PRETRAINED_MODELS - @test acctest(SEResNet(depth, pretrain = true)) - else - @test_throws ArgumentError SEResNet(depth, pretrain = true) - end - @test gradtest(m, x_224) - _gc() +@testset "SEResNet" begin @testset for depth in [18, 34, 50, 101, 152] + m = SEResNet(depth) + @test size(m(x_224)) == (1000, 1) + if (SEResNet, depth) in PRETRAINED_MODELS + @test acctest(SEResNet(depth; pretrain = true)) + else + @test_throws ArgumentError SEResNet(depth, pretrain = true) end -end + @test gradtest(m, x_224) + _gc() +end end -@testset "SEResNeXt" begin - @testset for depth in [50, 101, 152] - @testset for cardinality in [32, 64] - @testset for base_width in [4, 8] - m = SEResNeXt(depth; cardinality, base_width) - @test size(m(x_224)) == (1000, 1) - if (SEResNeXt, depth, cardinality, base_width) in PRETRAINED_MODELS - @test acctest(SEResNeXt(depth, pretrain = true)) - else - @test_throws ArgumentError SEResNeXt(depth, pretrain = true) - end - @test gradtest(m, x_224) - _gc() +@testset "SEResNeXt" begin @testset for depth in [50, 101, 152] + @testset for cardinality in [32, 64] + @testset for base_width in [4, 8] + m = SEResNeXt(depth; cardinality, base_width) + @test size(m(x_224)) == (1000, 1) + if (SEResNeXt, depth, cardinality, base_width) in PRETRAINED_MODELS + @test acctest(SEResNeXt(depth; pretrain = true)) + else + @test_throws ArgumentError SEResNeXt(depth, pretrain = true) end + @test gradtest(m, x_224) + _gc() end end -end +end end @testset "Res2Net" begin @testset for (base_width, scale) in [(26, 4), (48, 2), (14, 8), (26, 6), (26, 8)] @@ -146,64 +140,63 @@ end end end 
-@testset "Res2NeXt" begin - @testset for depth in [50, 101] - m = Res2NeXt(depth) - @test size(m(x_224)) == (1000, 1) - if (Res2NeXt, depth) in PRETRAINED_MODELS - @test acctest(Res2NeXt(depth, pretrain = true)) - else - @test_throws ArgumentError Res2NeXt(depth, pretrain = true) - end - @test gradtest(m, x_224) - _gc() +@testset "Res2NeXt" begin @testset for depth in [50, 101] + m = Res2NeXt(depth) + @test size(m(x_224)) == (1000, 1) + if (Res2NeXt, depth) in PRETRAINED_MODELS + @test acctest(Res2NeXt(depth; pretrain = true)) + else + @test_throws ArgumentError Res2NeXt(depth, pretrain = true) end -end + @test gradtest(m, x_224) + _gc() +end end -@testset "EfficientNet" begin - @testset "EfficientNet($config)" for config in [:b0, :b1, :b2, :b3, :b4, :b5,] #:b6, :b7, :b8] - # preferred image resolution scaling - r = Metalhead.EFFICIENTNET_GLOBAL_CONFIGS[config][1] - x = rand(Float32, r, r, 3, 1) - m = EfficientNet(config) - @test size(m(x)) == (1000, 1) - if (EfficientNet, config) in PRETRAINED_MODELS - @test acctest(EfficientNet(config, pretrain = true)) - else - @test_throws ArgumentError EfficientNet(config, pretrain = true) - end - @test gradtest(m, x) - _gc() +@testset "EfficientNet" begin @testset "EfficientNet($config)" for config in [ + :b0, + :b1, + :b2, + :b3, + :b4, + :b5, +] #:b6, :b7, :b8] + # preferred image resolution scaling + r = Metalhead.EFFICIENTNET_GLOBAL_CONFIGS[config][1] + x = rand(Float32, r, r, 3, 1) + m = EfficientNet(config) + @test size(m(x)) == (1000, 1) + if (EfficientNet, config) in PRETRAINED_MODELS + @test acctest(EfficientNet(config; pretrain = true)) + else + @test_throws ArgumentError EfficientNet(config, pretrain = true) end -end + @test gradtest(m, x) + _gc() +end end -@testset "EfficientNetv2" begin - @testset for config in [:small, :medium, :large] # :xlarge] - m = EfficientNetv2(config) - @test size(m(x_224)) == (1000, 1) - if (EfficientNetv2, config) in PRETRAINED_MODELS - @test acctest(EfficientNetv2(config, pretrain = true)) - else - @test_throws ArgumentError EfficientNetv2(config, pretrain = true) - end - @test gradtest(m, x_224) - _gc() +@testset "EfficientNetv2" begin @testset for config in [:small, :medium, :large] # :xlarge] + m = EfficientNetv2(config) + @test size(m(x_224)) == (1000, 1) + if (EfficientNetv2, config) in PRETRAINED_MODELS + @test acctest(EfficientNetv2(config; pretrain = true)) + else + @test_throws ArgumentError EfficientNetv2(config, pretrain = true) end -end + @test gradtest(m, x_224) + _gc() +end end -@testset "GoogLeNet" begin - @testset for bn in [true, false] - m = GoogLeNet(batchnorm = bn) - @test size(m(x_224)) == (1000, 1) - if (GoogLeNet, bn) in PRETRAINED_MODELS - @test acctest(GoogLeNet(batchnorm = bn, pretrain = true)) - else - @test_throws ArgumentError GoogLeNet(batchnorm = bn, pretrain = true) - end - @test gradtest(m, x_224) - _gc() +@testset "GoogLeNet" begin @testset for bn in [true, false] + m = GoogLeNet(; batchnorm = bn) + @test size(m(x_224)) == (1000, 1) + if (GoogLeNet, bn) in PRETRAINED_MODELS + @test acctest(GoogLeNet(; batchnorm = bn, pretrain = true)) + else + @test_throws ArgumentError GoogLeNet(batchnorm = bn, pretrain = true) end -end + @test gradtest(m, x_224) + _gc() +end end @testset "Inception" begin x_299 = rand(Float32, 299, 299, 3, 2) @@ -211,7 +204,7 @@ end m = Inceptionv3() @test size(m(x_299)) == (1000, 2) if Inceptionv3 in PRETRAINED_MODELS - @test acctest(Inceptionv3(pretrain = true)) + @test acctest(Inceptionv3(; pretrain = true)) else @test_throws ArgumentError 
Inceptionv3(pretrain = true) end @@ -222,7 +215,7 @@ end m = Inceptionv4() @test size(m(x_299)) == (1000, 2) if Inceptionv4 in PRETRAINED_MODELS - @test acctest(Inceptionv4(pretrain = true)) + @test acctest(Inceptionv4(; pretrain = true)) else @test_throws ArgumentError Inceptionv4(pretrain = true) end @@ -233,7 +226,7 @@ end m = InceptionResNetv2() @test size(m(x_299)) == (1000, 2) if InceptionResNetv2 in PRETRAINED_MODELS - @test acctest(InceptionResNetv2(pretrain = true)) + @test acctest(InceptionResNetv2(; pretrain = true)) else @test_throws ArgumentError InceptionResNetv2(pretrain = true) end @@ -244,7 +237,7 @@ end m = Xception() @test size(m(x_299)) == (1000, 2) if Xception in PRETRAINED_MODELS - @test acctest(Xception(pretrain = true)) + @test acctest(Xception(; pretrain = true)) else @test_throws ArgumentError Xception(pretrain = true) end @@ -257,7 +250,7 @@ end m = SqueezeNet() @test size(m(x_224)) == (1000, 1) if SqueezeNet in PRETRAINED_MODELS - @test acctest(SqueezeNet(pretrain = true)) + @test acctest(SqueezeNet(; pretrain = true)) else @test_throws ArgumentError SqueezeNet(pretrain = true) end @@ -265,26 +258,24 @@ end _gc() end -@testset "DenseNet" begin - @testset for sz in [121, 161, 169, 201] - m = DenseNet(sz) - @test size(m(x_224)) == (1000, 1) - if (DenseNet, sz) in PRETRAINED_MODELS - @test acctest(DenseNet(sz, pretrain = true)) - else - @test_throws ArgumentError DenseNet(sz, pretrain = true) - end - @test gradtest(m, x_224) - _gc() +@testset "DenseNet" begin @testset for sz in [121, 161, 169, 201] + m = DenseNet(sz) + @test size(m(x_224)) == (1000, 1) + if (DenseNet, sz) in PRETRAINED_MODELS + @test acctest(DenseNet(sz; pretrain = true)) + else + @test_throws ArgumentError DenseNet(sz, pretrain = true) end -end + @test gradtest(m, x_224) + _gc() +end end @testset "MobileNets (width = $width_mult)" for width_mult in [0.5, 0.75, 1, 1.3] @testset "MobileNetv1" begin m = MobileNetv1(width_mult) @test size(m(x_224)) == (1000, 1) if (MobileNetv1, width_mult) in PRETRAINED_MODELS - @test acctest(MobileNetv1(pretrain = true)) + @test acctest(MobileNetv1(; pretrain = true)) else @test_throws ArgumentError MobileNetv1(pretrain = true) end @@ -295,63 +286,64 @@ end m = MobileNetv2(width_mult) @test size(m(x_224)) == (1000, 1) if (MobileNetv2, width_mult) in PRETRAINED_MODELS - @test acctest(MobileNetv2(pretrain = true)) + @test acctest(MobileNetv2(; pretrain = true)) else @test_throws ArgumentError MobileNetv2(pretrain = true) end @test gradtest(m, x_224) end _gc() - @testset "MobileNetv3" verbose = true begin - @testset for config in [:small, :large] - m = MobileNetv3(config; width_mult) - @test size(m(x_224)) == (1000, 1) - if (MobileNetv3, config, width_mult) in PRETRAINED_MODELS - @test acctest(MobileNetv3(config; pretrain = true)) - else - @test_throws ArgumentError MobileNetv3(config; pretrain = true) - end - @test gradtest(m, x_224) - _gc() - end - end - @testset "MNASNet" verbose = true begin - @testset for config in [:A1, :B1] - m = MNASNet(config; width_mult) - @test size(m(x_224)) == (1000, 1) - if (MNASNet, config, width_mult) in PRETRAINED_MODELS - @test acctest(MNASNet(config; pretrain = true)) - else - @test_throws ArgumentError MNASNet(config; pretrain = true) - end - @test gradtest(m, x_224) - _gc() - end - end -end - -@testset "ConvNeXt" verbose = true begin - @testset for config in [:small, :base, :large, :tiny, :xlarge] - m = ConvNeXt(config) + @testset "MobileNetv3" verbose=true begin @testset for config in [:small, :large] + m = MobileNetv3(config; 
width_mult) @test size(m(x_224)) == (1000, 1) + if (MobileNetv3, config, width_mult) in PRETRAINED_MODELS + @test acctest(MobileNetv3(config; pretrain = true)) + else + @test_throws ArgumentError MobileNetv3(config; pretrain = true) + end @test gradtest(m, x_224) _gc() - end -end - -@testset "ConvMixer" verbose = true begin - @testset for config in [:small, :base, :large] - m = ConvMixer(config) + end end + @testset "MNASNet" verbose=true begin @testset for config in [:A1, :B1] + m = MNASNet(config; width_mult) @test size(m(x_224)) == (1000, 1) + if (MNASNet, config, width_mult) in PRETRAINED_MODELS + @test acctest(MNASNet(config; pretrain = true)) + else + @test_throws ArgumentError MNASNet(config; pretrain = true) + end @test gradtest(m, x_224) _gc() - end + end end end +@testset "ConvNeXt" verbose=true begin @testset for config in [ + :small, + :base, + :large, + :tiny, + :xlarge, +] + m = ConvNeXt(config) + @test size(m(x_224)) == (1000, 1) + @test gradtest(m, x_224) + _gc() +end end + +@testset "ConvMixer" verbose=true begin @testset for config in [:small, :base, :large] + m = ConvMixer(config) + @test size(m(x_224)) == (1000, 1) + @test gradtest(m, x_224) + _gc() +end end + @testset "UNet" begin encoder = Metalhead.backbone(ResNet(18)) - model = UNet(encoder, (256, 256, 3, 1), 10) + model = UNet((256, 256), 3, 10, encoder) @test size(model(x_256)) == (256, 256, 10, 1) @test gradtest(model, x_256) + + model = UNet() + @test size(model(x_256)) == (256, 256, 3, 1) _gc() end From 6097c57db4b899a6b1998035bc9f751c23c723e0 Mon Sep 17 00:00:00 2001 From: Anshuman Mishra <51750587+shivance@users.noreply.github.com> Date: Sun, 22 Jan 2023 21:01:17 +0530 Subject: [PATCH 23/29] Update .github/workflows/CI.yml Co-authored-by: Kyle Daruwalla --- .github/workflows/CI.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 59ac97710..b41c1f867 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -39,7 +39,8 @@ jobs: - 'r"/*/SEResNet*"' - '[r"Res2Net", r"Res2NeXt"]' - '"Inception"' - - '["UNet", "DenseNet"]' + - '"DenseNet"' + - '"UNet"' - '["ConvNeXt", "ConvMixer"]' - 'r"Mixers"' - 'r"ViTs"' From 98b4c30fd50b6d9fbe81e75771766ed4f8dc52b4 Mon Sep 17 00:00:00 2001 From: Anshuman Mishra <51750587+shivance@users.noreply.github.com> Date: Sun, 22 Jan 2023 21:12:36 +0530 Subject: [PATCH 24/29] Update src/convnets/unet.jl Co-authored-by: Kyle Daruwalla --- src/convnets/unet.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/convnets/unet.jl b/src/convnets/unet.jl index f582d6a0a..aae86bd8e 100644 --- a/src/convnets/unet.jl +++ b/src/convnets/unet.jl @@ -57,8 +57,8 @@ end build_unet(backbone, imgdims, outplanes::Integer, final::Any = unet_final_block, fdownscale::Integer = 0, kwargs...) -Creates a UNet model with specified backbone. Backbone of Any Metalhead model -can be used as encoder. +Creates a UNet model with specified convolutional backbone. +Backbone of any Metalhead ResNet-like model can be used as encoder ([reference](https://arxiv.org/abs/1505.04597)). 
# Arguments From 54c334f9a540a24aee19dfced26fcb3c2e497040 Mon Sep 17 00:00:00 2001 From: Anshuman Mishra <51750587+shivance@users.noreply.github.com> Date: Sun, 22 Jan 2023 21:13:06 +0530 Subject: [PATCH 25/29] Update src/convnets/unet.jl Co-authored-by: Kyle Daruwalla --- src/convnets/unet.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/convnets/unet.jl b/src/convnets/unet.jl index aae86bd8e..87615251c 100644 --- a/src/convnets/unet.jl +++ b/src/convnets/unet.jl @@ -87,8 +87,8 @@ end UNet(imsize::Dims{2} = (256, 256), inchannels::Integer = 3, outplanes::Integer = 3, backbone = Metalhead.backbone(DenseNet(121)); pretrain::Bool = false) -Creates a UNet model with an encoder built of specified backbone. By default it uses DenseNet Backbone, however any of -Metalhead model can be used for encoder backbone +Creates a UNet model with an encoder built of specified backbone. +By default it uses [`DenseNet`](@ref) backbone, but any ResNet-like Metalhead model can be used for the encoder ([reference](https://arxiv.org/abs/1505.04597)). # Arguments From 4fae8d639a6d6f801e09ce4f2f4be5f2d8184480 Mon Sep 17 00:00:00 2001 From: Anshuman Mishra <51750587+shivance@users.noreply.github.com> Date: Sun, 22 Jan 2023 22:24:42 +0530 Subject: [PATCH 26/29] incorporating suggestions --- src/convnets/unet.jl | 24 +-- src/utilities.jl | 15 +- test/convnets.jl | 375 ++++++++++++++++++++++--------------------- 3 files changed, 210 insertions(+), 204 deletions(-) diff --git a/src/convnets/unet.jl b/src/convnets/unet.jl index 87615251c..333577b07 100644 --- a/src/convnets/unet.jl +++ b/src/convnets/unet.jl @@ -54,8 +54,8 @@ function unetlayers(layers, sz; outplanes = nothing, skip_upscale = 0, end """ - build_unet(backbone, imgdims, outplanes::Integer, - final::Any = unet_final_block, fdownscale::Integer = 0, kwargs...) + unet(encoder_backbone, imgdims, outplanes::Integer, final::Any = unet_final_block, + fdownscale::Integer = 0, kwargs...) Creates a UNet model with specified convolutional backbone. Backbone of any Metalhead ResNet-like model can be used as encoder @@ -63,7 +63,7 @@ Backbone of any Metalhead ResNet-like model can be used as encoder # Arguments - - `backbone`: The backbone layers to be used in encoder. + - `encoder_backbone`: The backbone layers of specified model to be used as encoder. For example, `Metalhead.backbone(Metalhead.ResNet(18))` can be passed to instantiate a UNet with layers of resnet18 as encoder. - `inputsize`: size of input image @@ -71,11 +71,11 @@ Backbone of any Metalhead ResNet-like model can be used as encoder - `final`: final block as described in original paper - `fdownscale`: downscale factor """ -function build_unet(backbone, imgdims, outplanes::Integer, - final::Any = unet_final_block, fdownscale::Integer = 0, kwargs...) - backbonelayers = collect(flatten_chains(backbone)) +function unet(encoder_backbone, imgdims, outplanes::Integer, + final::Any = unet_final_block, fdownscale::Integer = 0) + backbonelayers = collect(flatten_chains(encoder_backbone)) layers = unetlayers(backbonelayers, imgdims; m_middle = unet_middle_block, - skip_upscale = fdownscale, kwargs...) 
+                        skip_upscale = fdownscale)
 
     outsz = Flux.outputsize(layers, imgdims)
     layers = Chain(layers, final(outsz[end - 1], outplanes))
 
@@ -84,8 +84,8 @@ function build_unet(backbone, imgdims, outplanes::Integer,
 end
 
 """
-UNet(imsize::Dims{2} = (256, 256), inchannels::Integer = 3, outplanes::Integer = 3,
-backbone = Metalhead.backbone(DenseNet(121)); pretrain::Bool = false)
+    UNet(imsize::Dims{2} = (256, 256), inchannels::Integer = 3, outplanes::Integer = 3,
+         encoder_backbone = Metalhead.backbone(DenseNet(121)); pretrain::Bool = false)
 
 Creates a UNet model with an encoder built of specified backbone.
 By default it uses [`DenseNet`](@ref) backbone, but any ResNet-like Metalhead model can be used for the encoder
 ([reference](https://arxiv.org/abs/1505.04597)).
 
 # Arguments
 
   - `imsize`: size of input image
   - `inchannels`: number of channels in input image
   - `outplanes`: number of output feature planes.
-  - `backbone`: The backbone layers to be used in encoder.
+  - `encoder_backbone`: The backbone layers of specified model to be used as encoder.
     For example, `Metalhead.backbone(Metalhead.ResNet(18))` can be passed to instantiate a UNet with layers of
     resnet18 as encoder.
   - `pretrain`: Whether to load the pre-trained weights for ImageNet
@@ -113,8 +113,8 @@ end
 @functor UNet
 
 function UNet(imsize::Dims{2} = (256, 256), inchannels::Integer = 3, outplanes::Integer = 3,
-              backbone = Metalhead.backbone(DenseNet(121)); pretrain::Bool = false)
-    layers = build_unet(backbone, (imsize..., inchannels, 1), outplanes)
+              encoder_backbone = Metalhead.backbone(DenseNet(121)); pretrain::Bool = false)
+    layers = unet(encoder_backbone, (imsize..., inchannels, 1), outplanes)
 
     if pretrain
         loadpretrain!(layers, string("UNet"))
diff --git a/src/utilities.jl b/src/utilities.jl
index 023cb699b..316d884c6 100644
--- a/src/utilities.jl
+++ b/src/utilities.jl
@@ -80,17 +80,12 @@ function _checkconfig(config, configs)
     @assert config in configs "Invalid configuration. Must be one of $(sort(collect(configs)))."
 end
 
-
-# adopted from FastAI.jl
 """
-The flatten_chains function takes in a single argument m, which can be of `Any` type.
-If the input m is of type `Chain`, the function returns an iterator that recursively flattens
-the layers of the input Chain object and all its sub-layers.
-It does this by calling the flatten_chains function on each element of m.layers and
-passing it into the Iterators.flatten function, which flatten the layers into a single iterator.
-If the input m is not of type Chain, the function returns a tuple containing the single input m.
-This function can be useful when you want to traverse nested layers of a Chain object and flatten
-them into a single iterator.
+    flatten_chains(m::Chain)
+    flatten_chains(m)
+
+Convenience function for traversing nested layers of a Chain object and flattening them
+into a single iterator.
""" flatten_chains(m::Chain) = Iterators.flatten(flatten_chains(l) for l in m.layers) flatten_chains(m) = (m,) diff --git a/test/convnets.jl b/test/convnets.jl index 917b4beaf..dede5bdb3 100644 --- a/test/convnets.jl +++ b/test/convnets.jl @@ -6,19 +6,19 @@ _gc() end -@testset "VGG" begin @testset "VGG($sz, batchnorm=$bn)" for sz in [11, 13, 16, 19], - bn in [true, false] - - m = VGG(sz; batchnorm = bn) - @test size(m(x_224)) == (1000, 1) - if (VGG, sz, bn) in PRETRAINED_MODELS - @test acctest(VGG(sz; batchnorm = bn, pretrain = true)) - else - @test_throws ArgumentError VGG(sz, batchnorm = bn, pretrain = true) +@testset "VGG" begin + @testset "VGG($sz, batchnorm=$bn)" for sz in [11, 13, 16, 19], bn in [true, false] + m = VGG(sz, batchnorm = bn) + @test size(m(x_224)) == (1000, 1) + if (VGG, sz, bn) in PRETRAINED_MODELS + @test acctest(VGG(sz, batchnorm = bn, pretrain = true)) + else + @test_throws ArgumentError VGG(sz, batchnorm = bn, pretrain = true) + end + @test gradtest(m, x_224) + _gc() end - @test gradtest(m, x_224) - _gc() -end end +end @testset "ResNet" begin # Tests for pretrained ResNets @@ -26,94 +26,100 @@ end end m = ResNet(sz) @test size(m(x_224)) == (1000, 1) if (ResNet, sz) in PRETRAINED_MODELS - @test acctest(ResNet(sz; pretrain = true)) + @test acctest(ResNet(sz, pretrain = true)) else @test_throws ArgumentError ResNet(sz, pretrain = true) end end - @testset "resnet" begin @testset for block_fn in [ - Metalhead.basicblock, - Metalhead.bottleneck, - ] - layer_list = [ - [2, 2, 2, 2], - [3, 4, 6, 3], - [3, 4, 23, 3], - [3, 8, 36, 3], - ] - @testset for layers in layer_list - drop_list = [ - (dropout_prob = 0.1, stochastic_depth_prob = 0.1, dropblock_prob = 0.1), - (dropout_prob = 0.5, stochastic_depth_prob = 0.5, dropblock_prob = 0.5), - (dropout_prob = 0.8, stochastic_depth_prob = 0.8, dropblock_prob = 0.8), + @testset "resnet" begin + @testset for block_fn in [Metalhead.basicblock, Metalhead.bottleneck] + layer_list = [ + [2, 2, 2, 2], + [3, 4, 6, 3], + [3, 4, 23, 3], + [3, 8, 36, 3] ] - @testset for drop_probs in drop_list - m = Metalhead.resnet(block_fn, layers; drop_probs...) - @test size(m(x_224)) == (1000, 1) - @test gradtest(m, x_224) - _gc() + @testset for layers in layer_list + drop_list = [ + (dropout_prob = 0.1, stochastic_depth_prob = 0.1, dropblock_prob = 0.1), + (dropout_prob = 0.5, stochastic_depth_prob = 0.5, dropblock_prob = 0.5), + (dropout_prob = 0.8, stochastic_depth_prob = 0.8, dropblock_prob = 0.8), + ] + @testset for drop_probs in drop_list + m = Metalhead.resnet(block_fn, layers; drop_probs...) 
+ @test size(m(x_224)) == (1000, 1) + @test gradtest(m, x_224) + _gc() + end end end - end end + end - @testset "WideResNet" begin @testset "WideResNet($sz)" for sz in [50, 101] - m = WideResNet(sz) - @test size(m(x_224)) == (1000, 1) - @test gradtest(m, x_224) - _gc() - if (WideResNet, sz) in PRETRAINED_MODELS - @test acctest(ResNet(sz; pretrain = true)) - else - @test_throws ArgumentError WideResNet(sz, pretrain = true) + @testset "WideResNet" begin + @testset "WideResNet($sz)" for sz in [50, 101] + m = WideResNet(sz) + @test size(m(x_224)) == (1000, 1) + @test gradtest(m, x_224) + _gc() + if (WideResNet, sz) in PRETRAINED_MODELS + @test acctest(ResNet(sz, pretrain = true)) + else + @test_throws ArgumentError WideResNet(sz, pretrain = true) + end end - end end + end end -@testset "ResNeXt" begin @testset for depth in [50, 101, 152] - @testset for cardinality in [32, 64] - @testset for base_width in [4, 8] - m = ResNeXt(depth; cardinality, base_width) - @test size(m(x_224)) == (1000, 1) - if (ResNeXt, depth, cardinality, base_width) in PRETRAINED_MODELS - @test acctest(ResNeXt(depth; cardinality, base_width, pretrain = true)) - else - @test_throws ArgumentError ResNeXt(depth; cardinality, base_width, - pretrain = true) +@testset "ResNeXt" begin + @testset for depth in [50, 101, 152] + @testset for cardinality in [32, 64] + @testset for base_width in [4, 8] + m = ResNeXt(depth; cardinality, base_width) + @test size(m(x_224)) == (1000, 1) + if (ResNeXt, depth, cardinality, base_width) in PRETRAINED_MODELS + @test acctest(ResNeXt(depth; cardinality, base_width, pretrain = true)) + else + @test_throws ArgumentError ResNeXt(depth; cardinality, base_width, pretrain = true) + end + @test gradtest(m, x_224) + _gc() end - @test gradtest(m, x_224) - _gc() end end -end end +end -@testset "SEResNet" begin @testset for depth in [18, 34, 50, 101, 152] - m = SEResNet(depth) - @test size(m(x_224)) == (1000, 1) - if (SEResNet, depth) in PRETRAINED_MODELS - @test acctest(SEResNet(depth; pretrain = true)) - else - @test_throws ArgumentError SEResNet(depth, pretrain = true) +@testset "SEResNet" begin + @testset for depth in [18, 34, 50, 101, 152] + m = SEResNet(depth) + @test size(m(x_224)) == (1000, 1) + if (SEResNet, depth) in PRETRAINED_MODELS + @test acctest(SEResNet(depth, pretrain = true)) + else + @test_throws ArgumentError SEResNet(depth, pretrain = true) + end + @test gradtest(m, x_224) + _gc() end - @test gradtest(m, x_224) - _gc() -end end +end -@testset "SEResNeXt" begin @testset for depth in [50, 101, 152] - @testset for cardinality in [32, 64] - @testset for base_width in [4, 8] - m = SEResNeXt(depth; cardinality, base_width) - @test size(m(x_224)) == (1000, 1) - if (SEResNeXt, depth, cardinality, base_width) in PRETRAINED_MODELS - @test acctest(SEResNeXt(depth; pretrain = true)) - else - @test_throws ArgumentError SEResNeXt(depth, pretrain = true) +@testset "SEResNeXt" begin + @testset for depth in [50, 101, 152] + @testset for cardinality in [32, 64] + @testset for base_width in [4, 8] + m = SEResNeXt(depth; cardinality, base_width) + @test size(m(x_224)) == (1000, 1) + if (SEResNeXt, depth, cardinality, base_width) in PRETRAINED_MODELS + @test acctest(SEResNeXt(depth, pretrain = true)) + else + @test_throws ArgumentError SEResNeXt(depth, pretrain = true) + end + @test gradtest(m, x_224) + _gc() end - @test gradtest(m, x_224) - _gc() end end -end end +end @testset "Res2Net" begin @testset for (base_width, scale) in [(26, 4), (48, 2), (14, 8), (26, 6), (26, 8)] @@ -140,63 +146,64 @@ end end 
end end -@testset "Res2NeXt" begin @testset for depth in [50, 101] - m = Res2NeXt(depth) - @test size(m(x_224)) == (1000, 1) - if (Res2NeXt, depth) in PRETRAINED_MODELS - @test acctest(Res2NeXt(depth; pretrain = true)) - else - @test_throws ArgumentError Res2NeXt(depth, pretrain = true) +@testset "Res2NeXt" begin + @testset for depth in [50, 101] + m = Res2NeXt(depth) + @test size(m(x_224)) == (1000, 1) + if (Res2NeXt, depth) in PRETRAINED_MODELS + @test acctest(Res2NeXt(depth, pretrain = true)) + else + @test_throws ArgumentError Res2NeXt(depth, pretrain = true) + end + @test gradtest(m, x_224) + _gc() end - @test gradtest(m, x_224) - _gc() -end end +end -@testset "EfficientNet" begin @testset "EfficientNet($config)" for config in [ - :b0, - :b1, - :b2, - :b3, - :b4, - :b5, -] #:b6, :b7, :b8] - # preferred image resolution scaling - r = Metalhead.EFFICIENTNET_GLOBAL_CONFIGS[config][1] - x = rand(Float32, r, r, 3, 1) - m = EfficientNet(config) - @test size(m(x)) == (1000, 1) - if (EfficientNet, config) in PRETRAINED_MODELS - @test acctest(EfficientNet(config; pretrain = true)) - else - @test_throws ArgumentError EfficientNet(config, pretrain = true) +@testset "EfficientNet" begin + @testset "EfficientNet($config)" for config in [:b0, :b1, :b2, :b3, :b4, :b5,] #:b6, :b7, :b8] + # preferred image resolution scaling + r = Metalhead.EFFICIENTNET_GLOBAL_CONFIGS[config][1] + x = rand(Float32, r, r, 3, 1) + m = EfficientNet(config) + @test size(m(x)) == (1000, 1) + if (EfficientNet, config) in PRETRAINED_MODELS + @test acctest(EfficientNet(config, pretrain = true)) + else + @test_throws ArgumentError EfficientNet(config, pretrain = true) + end + @test gradtest(m, x) + _gc() end - @test gradtest(m, x) - _gc() -end end +end -@testset "EfficientNetv2" begin @testset for config in [:small, :medium, :large] # :xlarge] - m = EfficientNetv2(config) - @test size(m(x_224)) == (1000, 1) - if (EfficientNetv2, config) in PRETRAINED_MODELS - @test acctest(EfficientNetv2(config; pretrain = true)) - else - @test_throws ArgumentError EfficientNetv2(config, pretrain = true) +@testset "EfficientNetv2" begin + @testset for config in [:small, :medium, :large] # :xlarge] + m = EfficientNetv2(config) + @test size(m(x_224)) == (1000, 1) + if (EfficientNetv2, config) in PRETRAINED_MODELS + @test acctest(EfficientNetv2(config, pretrain = true)) + else + @test_throws ArgumentError EfficientNetv2(config, pretrain = true) + end + @test gradtest(m, x_224) + _gc() end - @test gradtest(m, x_224) - _gc() -end end +end -@testset "GoogLeNet" begin @testset for bn in [true, false] - m = GoogLeNet(; batchnorm = bn) - @test size(m(x_224)) == (1000, 1) - if (GoogLeNet, bn) in PRETRAINED_MODELS - @test acctest(GoogLeNet(; batchnorm = bn, pretrain = true)) - else - @test_throws ArgumentError GoogLeNet(batchnorm = bn, pretrain = true) +@testset "GoogLeNet" begin + @testset for bn in [true, false] + m = GoogLeNet(batchnorm = bn) + @test size(m(x_224)) == (1000, 1) + if (GoogLeNet, bn) in PRETRAINED_MODELS + @test acctest(GoogLeNet(batchnorm = bn, pretrain = true)) + else + @test_throws ArgumentError GoogLeNet(batchnorm = bn, pretrain = true) + end + @test gradtest(m, x_224) + _gc() end - @test gradtest(m, x_224) - _gc() -end end +end @testset "Inception" begin x_299 = rand(Float32, 299, 299, 3, 2) @@ -204,7 +211,7 @@ end end m = Inceptionv3() @test size(m(x_299)) == (1000, 2) if Inceptionv3 in PRETRAINED_MODELS - @test acctest(Inceptionv3(; pretrain = true)) + @test acctest(Inceptionv3(pretrain = true)) else @test_throws ArgumentError 
Inceptionv3(pretrain = true) end @@ -215,7 +222,7 @@ end end m = Inceptionv4() @test size(m(x_299)) == (1000, 2) if Inceptionv4 in PRETRAINED_MODELS - @test acctest(Inceptionv4(; pretrain = true)) + @test acctest(Inceptionv4(pretrain = true)) else @test_throws ArgumentError Inceptionv4(pretrain = true) end @@ -226,7 +233,7 @@ end end m = InceptionResNetv2() @test size(m(x_299)) == (1000, 2) if InceptionResNetv2 in PRETRAINED_MODELS - @test acctest(InceptionResNetv2(; pretrain = true)) + @test acctest(InceptionResNetv2(pretrain = true)) else @test_throws ArgumentError InceptionResNetv2(pretrain = true) end @@ -237,7 +244,7 @@ end end m = Xception() @test size(m(x_299)) == (1000, 2) if Xception in PRETRAINED_MODELS - @test acctest(Xception(; pretrain = true)) + @test acctest(Xception(pretrain = true)) else @test_throws ArgumentError Xception(pretrain = true) end @@ -250,7 +257,7 @@ end m = SqueezeNet() @test size(m(x_224)) == (1000, 1) if SqueezeNet in PRETRAINED_MODELS - @test acctest(SqueezeNet(; pretrain = true)) + @test acctest(SqueezeNet(pretrain = true)) else @test_throws ArgumentError SqueezeNet(pretrain = true) end @@ -258,24 +265,26 @@ end _gc() end -@testset "DenseNet" begin @testset for sz in [121, 161, 169, 201] - m = DenseNet(sz) - @test size(m(x_224)) == (1000, 1) - if (DenseNet, sz) in PRETRAINED_MODELS - @test acctest(DenseNet(sz; pretrain = true)) - else - @test_throws ArgumentError DenseNet(sz, pretrain = true) +@testset "DenseNet" begin + @testset for sz in [121, 161, 169, 201] + m = DenseNet(sz) + @test size(m(x_224)) == (1000, 1) + if (DenseNet, sz) in PRETRAINED_MODELS + @test acctest(DenseNet(sz, pretrain = true)) + else + @test_throws ArgumentError DenseNet(sz, pretrain = true) + end + @test gradtest(m, x_224) + _gc() end - @test gradtest(m, x_224) - _gc() -end end +end @testset "MobileNets (width = $width_mult)" for width_mult in [0.5, 0.75, 1, 1.3] @testset "MobileNetv1" begin m = MobileNetv1(width_mult) @test size(m(x_224)) == (1000, 1) if (MobileNetv1, width_mult) in PRETRAINED_MODELS - @test acctest(MobileNetv1(; pretrain = true)) + @test acctest(MobileNetv1(pretrain = true)) else @test_throws ArgumentError MobileNetv1(pretrain = true) end @@ -286,57 +295,59 @@ end end m = MobileNetv2(width_mult) @test size(m(x_224)) == (1000, 1) if (MobileNetv2, width_mult) in PRETRAINED_MODELS - @test acctest(MobileNetv2(; pretrain = true)) + @test acctest(MobileNetv2(pretrain = true)) else @test_throws ArgumentError MobileNetv2(pretrain = true) end @test gradtest(m, x_224) end _gc() - @testset "MobileNetv3" verbose=true begin @testset for config in [:small, :large] - m = MobileNetv3(config; width_mult) - @test size(m(x_224)) == (1000, 1) - if (MobileNetv3, config, width_mult) in PRETRAINED_MODELS - @test acctest(MobileNetv3(config; pretrain = true)) - else - @test_throws ArgumentError MobileNetv3(config; pretrain = true) + @testset "MobileNetv3" verbose = true begin + @testset for config in [:small, :large] + m = MobileNetv3(config; width_mult) + @test size(m(x_224)) == (1000, 1) + if (MobileNetv3, config, width_mult) in PRETRAINED_MODELS + @test acctest(MobileNetv3(config; pretrain = true)) + else + @test_throws ArgumentError MobileNetv3(config; pretrain = true) + end + @test gradtest(m, x_224) + _gc() + end + end + @testset "MNASNet" verbose = true begin + @testset for config in [:A1, :B1] + m = MNASNet(config; width_mult) + @test size(m(x_224)) == (1000, 1) + if (MNASNet, config, width_mult) in PRETRAINED_MODELS + @test acctest(MNASNet(config; pretrain = true)) + else + 
@test_throws ArgumentError MNASNet(config; pretrain = true)
+            end
+            @test gradtest(m, x_224)
+            _gc()
         end
+    end
+end
+
+@testset "ConvNeXt" verbose = true begin
+    @testset for config in [:small, :base, :large, :tiny, :xlarge]
+        m = ConvNeXt(config)
+        @test size(m(x_224)) == (1000, 1)
 @test gradtest(m, x_224)
 _gc()
-    end
-end
-
-@testset "ConvMixer" verbose = true begin
-    @testset for config in [:small, :base, :large]
-        m = ConvMixer(config)
+    end
+end
+
+@testset "ConvMixer" verbose = true begin
+    @testset for config in [:small, :base, :large]
+        m = ConvMixer(config)
 @test size(m(x_224)) == (1000, 1)
 @test gradtest(m, x_224)
 _gc()
-    end
+    end
 end
 
-@testset "ConvNeXt" verbose=true begin @testset for config in [
-    :small,
-    :base,
-    :large,
-    :tiny,
-    :xlarge,
-]
-    m = ConvNeXt(config)
-    @test size(m(x_224)) == (1000, 1)
-    @test gradtest(m, x_224)
-    _gc()
-end end
-
-@testset "ConvMixer" verbose=true begin @testset for config in [:small, :base, :large]
-    m = ConvMixer(config)
-    @test size(m(x_224)) == (1000, 1)
-    @test gradtest(m, x_224)
-    _gc()
-end end
-
 @testset "UNet" begin
     encoder = Metalhead.backbone(ResNet(18))
     model = UNet((256, 256), 3, 10, encoder)

From 4735dff7b3ba11a492e5add26e72570749d3ba0b Mon Sep 17 00:00:00 2001
From: Anshuman Mishra <51750587+shivance@users.noreply.github.com>
Date: Sun, 22 Jan 2023 22:29:25 +0530
Subject: [PATCH 27/29] minor change

---
 src/convnets/unet.jl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/convnets/unet.jl b/src/convnets/unet.jl
index 333577b07..458a37f12 100644
--- a/src/convnets/unet.jl
+++ b/src/convnets/unet.jl
@@ -87,8 +87,8 @@ end
     UNet(imsize::Dims{2} = (256, 256), inchannels::Integer = 3, outplanes::Integer = 3,
          encoder_backbone = Metalhead.backbone(DenseNet(121)); pretrain::Bool = false)
 
-Creates a UNet model with an encoder built of specified backbone.
-By default it uses [`DenseNet`](@ref) backbone, but any ResNet-like Metalhead model can be used for the encoder
+Creates a UNet model with an encoder built from the specified backbone. By default it uses
+[`DenseNet`](@ref) backbone, but any ResNet-like Metalhead model can be used for the encoder
 ([reference](https://arxiv.org/abs/1505.04597)).
 
 # Arguments

From 3bebe5aec221fefc00a9d3e9cc7a797b5783efd8 Mon Sep 17 00:00:00 2001
From: Anshuman Mishra <51750587+shivance@users.noreply.github.com>
Date: Sun, 22 Jan 2023 23:52:04 +0530
Subject: [PATCH 28/29] minor edit

---
 src/convnets/unet.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/convnets/unet.jl b/src/convnets/unet.jl
index 458a37f12..b4baa0cf9 100644
--- a/src/convnets/unet.jl
+++ b/src/convnets/unet.jl
@@ -55,7 +55,7 @@ end
 
 """
     unet(encoder_backbone, imgdims, outplanes::Integer, final::Any = unet_final_block,
-        fdownscale::Integer = 0, kwargs...)
+        fdownscale::Integer = 0)
 
 Creates a UNet model with specified convolutional backbone.
Backbone of any Metalhead ResNet-like model can be used as encoder

From 65aa5e8419236b7d8804df3d8e085192e7fd6782 Mon Sep 17 00:00:00 2001
From: Anshuman Mishra <51750587+shivance@users.noreply.github.com>
Date: Thu, 26 Jan 2023 07:55:06 +0530
Subject: [PATCH 29/29] Update src/convnets/unet.jl

Co-authored-by: Kyle Daruwalla
---
 src/convnets/unet.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/convnets/unet.jl b/src/convnets/unet.jl
index b4baa0cf9..91200c0ce 100644
--- a/src/convnets/unet.jl
+++ b/src/convnets/unet.jl
@@ -105,7 +105,7 @@ Creates a UNet model with an encoder built from the specified backbone. By defau
 
 `UNet` does not currently support pretrained weights.
 
-See also [`Metalhead.UNet`](@ref).
+See also [`Metalhead.unet`](@ref).
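
Reviewer note (not part of the patch series): a minimal usage sketch of the API as it stands after PATCH 29/29. The constructor calls and output shapes below mirror the `@testset "UNet"` block in test/convnets.jl; assuming, as in the test suite, that `UNet` and `ResNet` are exported by the package.

```julia
using Metalhead

# Default model: 256×256 RGB input, 3 output planes, DenseNet(121) encoder.
model = UNet()

# Custom encoder: any ResNet-like Metalhead backbone, here ResNet-18,
# with 10 output feature planes (the configuration exercised in the tests).
encoder = Metalhead.backbone(ResNet(18))
model = UNet((256, 256), 3, 10, encoder)

x = rand(Float32, 256, 256, 3, 1)
@assert size(model(x)) == (256, 256, 10, 1)
```

And a sketch of the `flatten_chains` helper added to src/utilities.jl (internal, hence the module qualifier), run on a hypothetical toy model:

```julia
using Flux, Metalhead

# A Chain nested inside a Chain: flatten_chains recursively yields the leaf layers,
# which is how the encoder backbone is decomposed before building the UNet.
m = Chain(Chain(Dense(2 => 4, relu), Dense(4 => 4)), Dense(4 => 1))
layers = collect(Metalhead.flatten_chains(m))
@assert length(layers) == 3  # the nested Chain is flattened away
```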