
Adding UNet Model #210

Merged: 30 commits, Jan 27, 2023

Changes from 6 commits

Commits (all authored by shivance)
ba54cf0  model implemented (Dec 27, 2022)
11c50d9  adding documentation (Dec 27, 2022)
ca73586  ran juliaformatter (Dec 28, 2022)
552a8fd  removed custom forward pass using Parallel (Jan 1, 2023)
c577aed  removing _random_normal (Jan 1, 2023)
fb642c4  incorporating suggested changes (Jan 2, 2023)
7c7b1ee  Revert "ran juliaformatter" (Jan 3, 2023)
99f07ad  adapting to fastai's unet impl (Jan 10, 2023)
fc756d9  undoing utilities formatting (Jan 10, 2023)
60b082c  formatting + documentation + func signature (Jan 10, 2023)
2f1cc6d  adding unit tests for unet (Jan 10, 2023)
8d2ba2b  configuring CI (Jan 10, 2023)
77a3148  configuring CI (Jan 10, 2023)
8aebd14  Merge branch 'master' into unet (Jan 10, 2023)
429096b  Update convnets.jl (Jan 10, 2023)
d761126  Update convnets.jl (Jan 10, 2023)
1b5d2b7  updated test (Jan 11, 2023)
354e3c4  minor fixes (Jan 12, 2023)
6494be7  typing fix (Jan 12, 2023)
2d68f61  Update src/utilities.jl (Jan 12, 2023)
627480f  fixing ci (Jan 12, 2023)
4012fb2  renaming: (Jan 16, 2023)
016cef4  fixing test (Jan 22, 2023)
6097c57  Update .github/workflows/CI.yml (Jan 22, 2023)
98b4c30  Update src/convnets/unet.jl (Jan 22, 2023)
54c334f  Update src/convnets/unet.jl (Jan 22, 2023)
4fae8d6  incorporating suggestions (Jan 22, 2023)
4735dff  minor change (Jan 22, 2023)
3bebe5a  minor edit (Jan 22, 2023)
65aa5e8  Update src/convnets/unet.jl (Jan 26, 2023)
42 changes: 25 additions & 17 deletions docs/make.jl
@@ -1,26 +1,34 @@
using Documenter, Metalhead, Artifacts, LazyArtifacts, Images, OneHotArrays, DataAugmentation, Flux
using Documenter, Metalhead, Artifacts, LazyArtifacts, Images, OneHotArrays,
DataAugmentation, Flux

DocMeta.setdocmeta!(Metalhead, :DocTestSetup, :(using Metalhead); recursive = true)

makedocs(modules = [Metalhead, Artifacts, LazyArtifacts, Images, OneHotArrays, DataAugmentation, Flux],
makedocs(;
modules = [
Metalhead,
Artifacts,
LazyArtifacts,
Images,
OneHotArrays,
DataAugmentation,
Flux,
],
sitename = "Metalhead.jl",
doctest = false,
pages = ["Home" => "index.md",
"Tutorials" => [
"tutorials/quickstart.md",
],
"Developer guide" => "contributing.md",
"API reference" => [
"api/reference.md",
],
],
format = Documenter.HTML(
canonical = "https://fluxml.ai/Metalhead.jl/stable/",
# analytics = "UA-36890222-9",
assets = ["assets/flux.css"],
prettyurls = get(ENV, "CI", nothing) == "true"),
)
"Tutorials" => [
"tutorials/quickstart.md",
],
"Developer guide" => "contributing.md",
"API reference" => [
"api/reference.md",
],
],
format = Documenter.HTML(; canonical = "https://fluxml.ai/Metalhead.jl/stable/",
# analytics = "UA-36890222-9",
assets = ["assets/flux.css"],
prettyurls = get(ENV, "CI", nothing) == "true"))

deploydocs(repo = "github.com/FluxML/Metalhead.jl.git",
deploydocs(; repo = "github.com/FluxML/Metalhead.jl.git",
target = "build",
push_preview = true)
1 change: 1 addition & 0 deletions docs/src/index.md
@@ -36,6 +36,7 @@ julia> ]add Metalhead
| [ViT](https://arxiv.org/abs/2010.11929) | [`ViT`](@ref) | N |
| [ConvNeXt](https://arxiv.org/abs/2201.03545) | [`ConvNeXt`](@ref) | N |
| [ConvMixer](https://arxiv.org/abs/2201.09792) | [`ConvMixer`](@ref) | N |
| [UNet](https://arxiv.org/abs/1505.04597v1) | [`UNet`](@ref) | N |

To contribute new models, see our [contributing docs](@ref Contributing-to-Metalhead.jl).

8 changes: 5 additions & 3 deletions src/Metalhead.jl
@@ -10,7 +10,7 @@ using MLUtils
using PartialFunctions
using Random

import Functors
using Functors: Functors

include("utilities.jl")

@@ -28,6 +28,8 @@ include("convnets/builders/stages.jl")
## AlexNet and VGG
include("convnets/alexnet.jl")
include("convnets/vgg.jl")
## Unet
include("convnets/unet.jl")
## ResNets
include("convnets/resnets/core.jl")
include("convnets/resnets/res2net.jl")
@@ -66,7 +68,7 @@ include("vit-based/vit.jl")
# Load pretrained weights
include("pretrain.jl")

export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19,
export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19, UNet,
ResNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152,
WideResNet, ResNeXt, SEResNet, SEResNeXt, Res2Net, Res2NeXt,
DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201,
@@ -76,7 +78,7 @@ export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19,
MLPMixer, ResMLP, gMLP, ViT

# use Flux._big_show to pretty print large models
for T in (:AlexNet, :VGG, :SqueezeNet, :ResNet, :WideResNet, :ResNeXt,
for T in (:AlexNet, :VGG, :UNet, :SqueezeNet, :ResNet, :WideResNet, :ResNeXt,
:SEResNet, :SEResNeXt, :Res2Net, :Res2NeXt, :GoogLeNet, :DenseNet,
:Inceptionv3, :Inceptionv4, :InceptionResNetv2, :Xception,
:MobileNetv1, :MobileNetv2, :MobileNetv3, :MNASNet,
2 changes: 1 addition & 1 deletion src/convnets/alexnet.jl
@@ -42,7 +42,7 @@ Create a `AlexNet`.
- `nclasses`: the number of output classes

!!! warning

`AlexNet` does not currently support pretrained weights.

See also [`alexnet`](@ref).
2 changes: 1 addition & 1 deletion src/convnets/builders/resnet.jl
@@ -5,7 +5,7 @@
Creates a generic ResNet-like model.

!!! info

This is a very generic, flexible but low level function that can be used to create any of the ResNet
variants. For a more user friendly function, see [`Metalhead.resnet`](@ref).

2 changes: 1 addition & 1 deletion src/convnets/convnext.jl
@@ -97,7 +97,7 @@ Creates a ConvNeXt model.
- `nclasses`: number of output classes

!!! warning

`ConvNeXt` does not currently support pretrained weights.

See also [`Metalhead.convnext`](@ref).
2 changes: 1 addition & 1 deletion src/convnets/densenet.jl
@@ -122,7 +122,7 @@ Create a DenseNet model with specified configuration. Currently supported values
Set `pretrain = true` to load the model with pre-trained weights for ImageNet.

!!! warning

`DenseNet` does not currently support pretrained weights.

See also [`Metalhead.densenet`](@ref).
2 changes: 1 addition & 1 deletion src/convnets/efficientnets/efficientnet.jl
@@ -73,7 +73,7 @@ Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)).
- `nclasses`: number of output classes.

!!! warning

EfficientNet does not currently support pretrained weights.

See also [`Metalhead.efficientnet`](@ref).
2 changes: 1 addition & 1 deletion src/convnets/efficientnets/efficientnetv2.jl
@@ -76,7 +76,7 @@ Create an EfficientNetv2 model ([reference](https://arxiv.org/abs/2104.00298)).
- `nclasses`: number of output classes

!!! warning

`EfficientNetv2` does not currently support pretrained weights.

See also [`efficientnet`](#).
2 changes: 1 addition & 1 deletion src/convnets/inceptions/googlenet.jl
@@ -86,7 +86,7 @@ Create an Inception-v1 model (commonly referred to as `GoogLeNet`)
- `bias`: set to `true` to use bias in the convolution layers

!!! warning

`GoogLeNet` does not currently support pretrained weights.

See also [`Metalhead.googlenet`](@ref).
2 changes: 1 addition & 1 deletion src/convnets/inceptions/inceptionresnetv2.jl
@@ -109,7 +109,7 @@ Creates an InceptionResNetv2 model.
- `nclasses`: the number of output classes.

!!! warning

`InceptionResNetv2` does not currently support pretrained weights.

See also [`Metalhead.inceptionresnetv2`](@ref).
2 changes: 1 addition & 1 deletion src/convnets/inceptions/inceptionv3.jl
@@ -170,7 +170,7 @@ Create an Inception-v3 model ([reference](https://arxiv.org/abs/1512.00567v3)).
- `nclasses`: the number of output classes

!!! warning

`Inceptionv3` does not currently support pretrained weights.

See also [`Metalhead.inceptionv3`](@ref).
2 changes: 1 addition & 1 deletion src/convnets/inceptions/inceptionv4.jl
@@ -124,7 +124,7 @@ Creates an Inceptionv4 model.
- `nclasses`: the number of output classes.

!!! warning

`Inceptionv4` does not currently support pretrained weights.

See also [`Metalhead.inceptionv4`](@ref).
2 changes: 1 addition & 1 deletion src/convnets/inceptions/xception.jl
@@ -80,7 +80,7 @@ Creates an Xception model.
- `nclasses`: the number of output classes.

!!! warning

`Xception` does not currently support pretrained weights.

See also [`Metalhead.xception`](@ref).
2 changes: 1 addition & 1 deletion src/convnets/mobilenets/mnasnet.jl
@@ -86,7 +86,7 @@ Creates a MNASNet model with the specified configuration.
- `nclasses`: The number of output classes

!!! warning

`MNASNet` does not currently support pretrained weights.

See also [`Metalhead.mnasnet`](@ref).
2 changes: 1 addition & 1 deletion src/convnets/mobilenets/mobilenetv1.jl
@@ -52,7 +52,7 @@ Create a MobileNetv1 model with the baseline configuration
- `nclasses`: The number of output classes

!!! warning

`MobileNetv1` does not currently support pretrained weights.

See also [`Metalhead.mobilenetv1`](@ref).
2 changes: 1 addition & 1 deletion src/convnets/mobilenets/mobilenetv2.jl
@@ -58,7 +58,7 @@ Create a MobileNetv2 model with the specified configuration.
- `nclasses`: The number of output classes

!!! warning

`MobileNetv2` does not currently support pretrained weights.

See also [`Metalhead.mobilenetv2`](@ref).
2 changes: 1 addition & 1 deletion src/convnets/mobilenets/mobilenetv3.jl
@@ -78,7 +78,7 @@ Set `pretrain = true` to load the model with pre-trained weights for ImageNet.
- `nclasses`: the number of output classes

!!! warning

`MobileNetv3` does not currently support pretrained weights.

See also [`Metalhead.mobilenetv3`](@ref).
4 changes: 2 additions & 2 deletions src/convnets/resnets/core.jl
@@ -191,7 +191,7 @@ If `outplanes > inplanes`, it maps the input to `outplanes` channels using a 1x1
layer and zero padding.

!!! warning

This does not currently support the scenario where `inplanes > outplanes`.

# Arguments
@@ -237,7 +237,7 @@ on how to use this function.
# Arguments

- `stem_type`: The type of stem to be built. One of `[:default, :deep, :deep_tiered]`.

+ `:default`: Builds a stem based on the default ResNet stem, which consists of a single
7x7 convolution with stride 2 and a normalisation layer followed by a 3x3 max pooling
layer with stride 2.
4 changes: 2 additions & 2 deletions src/convnets/resnets/res2net.jl
@@ -17,7 +17,7 @@ Creates a Res2Net model with the specified depth, scale, and base width.
- `nclasses`: the number of output classes

!!! warning

`Res2Net` does not currently support pretrained weights.

Advanced users who want more configuration options will be better served by using [`resnet`](@ref).
@@ -64,7 +64,7 @@ Creates a Res2NeXt model with the specified depth, scale, base width and cardinality
- `nclasses`: the number of output classes

!!! warning

`Res2NeXt` does not currently support pretrained weights.

Advanced users who want more configuration options will be better served by using [`resnet`](@ref).
2 changes: 1 addition & 1 deletion src/convnets/resnets/resnext.jl
@@ -11,7 +11,7 @@ Creates a ResNeXt model with the specified depth, cardinality, and base width.

- `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet.
Supported configurations are:

+ depth 50, cardinality of 32 and base width of 4.
+ depth 101, cardinality of 32 and base width of 8.
+ depth 101, cardinality of 64 and base width of 4.
4 changes: 2 additions & 2 deletions src/convnets/resnets/seresnet.jl
@@ -12,7 +12,7 @@ Creates a SEResNet model with the specified depth.
- `nclasses`: the number of output classes

!!! warning

`SEResNet` does not currently support pretrained weights.

Advanced users who want more configuration options will be better served by using [`resnet`](@ref).
@@ -55,7 +55,7 @@ Creates a SEResNeXt model with the specified depth, cardinality, and base width.
- `nclasses`: the number of output classes

!!! warning

`SEResNeXt` does not currently support pretrained weights.

Advanced users who want more configuration options will be better served by using [`resnet`](@ref).
77 changes: 77 additions & 0 deletions src/convnets/unet.jl
@@ -0,0 +1,77 @@
function unet_block(in_chs::Int, out_chs::Int, kernel = (3, 3))
return Chain(; conv1 = Conv(kernel, in_chs => out_chs; pad = (1, 1)),
norm1 = BatchNorm(out_chs, relu),
conv2 = Conv(kernel, out_chs => out_chs; pad = (1, 1)),
norm2 = BatchNorm(out_chs, relu))
end

function upconv_block(in_chs::Int, out_chs::Int, kernel = (2, 2))
return ConvTranspose(kernel, in_chs => out_chs; stride = (2, 2))
end
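# (Illustrative note, not part of the PR: with pad = (1, 1) the 3x3 convs in
# unet_block preserve spatial size, while upconv_block doubles it, e.g.
#   size(unet_block(3, 16)(rand(Float32, 32, 32, 3, 1)))    == (32, 32, 16, 1)
#   size(upconv_block(16, 8)(rand(Float32, 32, 32, 16, 1))) == (64, 64, 8, 1).)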

function unet(in_channels::Integer = 3, out_channels::Integer = in_channels,
features::Integer = 32)
encoder_conv = []
push!(encoder_conv,
unet_block(in_channels, features))

append!(encoder_conv,
[unet_block(features * 2^i, features * 2^(i + 1)) for i in 0:2])

encoder_conv = Chain(encoder_conv)
encoder_pool = [Chain(encoder_conv[i], MaxPool((2, 2); stride = (2, 2))) for i in 1:4]
[Review comment from a project Member]
I think you can combine the MaxPool and the unet_block functions into a separate block? That should help writing code with less of the indexing.
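A rough sketch of what that refactor could look like (illustrative only; the helper name `unet_enc_block` and the exact structure are not part of the PR):

# Hypothetical encoder block combining unet_block with downsampling,
# so encoder stages can be built without indexing into encoder_conv:
function unet_enc_block(in_chs::Int, out_chs::Int, kernel = (3, 3))
    return Chain(unet_block(in_chs, out_chs, kernel),
                 MaxPool((2, 2); stride = (2, 2)))
end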


bottleneck = unet_block(features * 8, features * 16)
layers = Chain(encoder_conv, bottleneck)

upconv = Chain([upconv_block(features * 2^(i + 1), features * 2^i)
for i in 0:3]...)

concat_layer = Chain([Parallel(cat_channels,
encoder_pool[i],
upconv[i])
for i in 1:4]...)

decoder_layer = Chain([unet_block(features * 2^(i + 1), features * 2^i) for i in 3:-1:0]...)

layers = Chain(layers, decoder_layer)

decoder = Chain([Chain([
concat_layer[i],
decoder_layer[5-i]])
for i in 4:-1:1]...)

final_conv = Conv((1, 1), features => out_channels, σ)

decoder = Chain(decoder, final_conv)

return layers
[Review comment from @pri1311, Jan 4, 2023]
Ok, figured out: you are returning `layers`, which does not have `decoder` and `final_conv` chained to it.

[Review comment from @pri1311, Jan 4, 2023]
Suggested change, replacing

decoder_layer = Chain([unet_block(features * 2^(i + 1), features * 2^i) for i in 3:-1:0]...)
layers = Chain(layers, decoder_layer)
decoder = Chain([Chain([
concat_layer[i],
decoder_layer[5-i]])
for i in 4:-1:1]...)
final_conv = Conv((1, 1), features => out_channels, σ)
decoder = Chain(decoder, final_conv)
return layers

with

decoder_layer = Chain([unet_block(features * 2^(i + 1), features * 2^i) for i in 3:-1:0]...)
decoder = Chain([Chain([
concat_layer[i],
decoder_layer[5-i]])
for i in 4:-1:1]...)
final_conv = Conv((1, 1), features => out_channels, σ)
layers = Chain(layers, decoder, final_conv)
return layers

Is this correct? Aah, ok, I think this gives a dimension mismatch.
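For context on that mismatch: `cat_channels` concatenates encoder and decoder features along the channel dimension, so the block that follows each skip connection has to expect the sum of both channel counts. A minimal sketch of the arithmetic (assuming `cat_channels` is the channel-dimension `cat` from Metalhead's utilities):

# Assumption: cat_channels concatenates along the channel (3rd) dimension.
cat_channels(xs...) = cat(xs...; dims = 3)

enc = rand(Float32, 32, 32, 64, 1)  # encoder skip features: 64 channels
dec = rand(Float32, 32, 32, 32, 1)  # upsampled decoder features: 32 channels
size(cat_channels(enc, dec))        # (32, 32, 96, 1): the next unet_block
                                    # must be built for 96 input channels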

end

"""
UNet(in_channels::Integer = 3, inplanes::Integer = 32, outplanes::Integer = inplanes)

Create a UNet model
([reference](https://arxiv.org/abs/1505.04597v1))

# Arguments
- `in_channels`: The number of input channels
- `inplanes`: The number of input features to the network
- `outplanes`: The number of output features

!!! warning

`UNet` does not currently support pretrained weights.
"""
struct UNet
layers::Any
end
@functor UNet

function UNet(in_channels::Integer = 3, inplanes::Integer = 32,
outplanes::Integer = inplanes)
layers = unet(in_channels, inplanes, outplanes)
return UNet(layers)
end

(m::UNet)(x::AbstractArray) = m.layers(x)
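A minimal usage sketch for the API above (hedged: this assumes the finished version of the PR, where encoder and decoder are chained together; the intermediate revision shown in this diff still returns only the encoder chain, as the review comments point out):

using Metalhead, Flux

# in_channels = 3, inplanes = 32, outplanes = 3, per the docstring above
model = UNet(3, 32, 3)

x = rand(Float32, 256, 256, 3, 1)  # WHCN image batch
y = model(x)
size(y)  # expected (256, 256, 3, 1) once encoder and decoder are wired up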