Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding UNet Model #210

Merged
merged 30 commits into from
Jan 27, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
ba54cf0
model implemented
shivance Dec 27, 2022
11c50d9
adding documentation
shivance Dec 27, 2022
ca73586
ran juliaformatter
shivance Dec 28, 2022
552a8fd
removed custom forward pass using Parallel
shivance Jan 1, 2023
c577aed
removing _random_normal
shivance Jan 1, 2023
fb642c4
incorporating suggested changes
shivance Jan 2, 2023
7c7b1ee
Revert "ran juliaformatter"
shivance Jan 3, 2023
99f07ad
adapting to fastai's unet impl
shivance Jan 10, 2023
fc756d9
undoing utilities formatting
shivance Jan 10, 2023
60b082c
formatting + documentation + func signature
shivance Jan 10, 2023
2f1cc6d
adding unit tests for unet
shivance Jan 10, 2023
8d2ba2b
configuring CI
shivance Jan 10, 2023
77a3148
configuring CI
shivance Jan 10, 2023
8aebd14
Merge branch 'master' into unet
shivance Jan 10, 2023
429096b
Update convnets.jl
shivance Jan 10, 2023
d761126
Update convnets.jl
shivance Jan 10, 2023
1b5d2b7
updated test
shivance Jan 11, 2023
354e3c4
minor fixes
shivance Jan 12, 2023
6494be7
typing fix
shivance Jan 12, 2023
2d68f61
Update src/utilities.jl
shivance Jan 12, 2023
627480f
fixing ci
shivance Jan 12, 2023
4012fb2
renaming:
shivance Jan 16, 2023
016cef4
fixing test
shivance Jan 22, 2023
6097c57
Update .github/workflows/CI.yml
shivance Jan 22, 2023
98b4c30
Update src/convnets/unet.jl
shivance Jan 22, 2023
54c334f
Update src/convnets/unet.jl
shivance Jan 22, 2023
4fae8d6
incorporating suggestions
shivance Jan 22, 2023
4735dff
minor change
shivance Jan 22, 2023
3bebe5a
minor edit
shivance Jan 22, 2023
65aa5e8
Update src/convnets/unet.jl
shivance Jan 26, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ julia> ]add Metalhead
| [ViT](https://arxiv.org/abs/2010.11929) | [`ViT`](@ref) | N |
| [ConvNeXt](https://arxiv.org/abs/2201.03545) | [`ConvNeXt`](@ref) | N |
| [ConvMixer](https://arxiv.org/abs/2201.09792) | [`ConvMixer`](@ref) | N |
| [UNet](https://arxiv.org/abs/1505.04597v1) | [`UNet`](@ref) | N |

To contribute new models, see our [contributing docs](@ref Contributing-to-Metalhead.jl).

Expand Down
81 changes: 81 additions & 0 deletions src/convnets/unet.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
"""
    unet_block(in_chs::Int, out_chs::Int, kernel = (3, 3))

Return a basic UNet convolutional block: two `kernel`-sized convolutions,
each followed by batch normalisation with ReLU activation. Padding of
`(1, 1)` keeps the spatial size unchanged for the default 3×3 kernel.
The layers are named (`conv1`, `norm1`, `conv2`, `norm2`) so they can be
addressed individually on the returned `Chain`.
"""
function unet_block(in_chs::Int, out_chs::Int, kernel = (3, 3))
    first_conv = Conv(kernel, in_chs => out_chs, pad = (1, 1); init = _random_normal)
    second_conv = Conv(kernel, out_chs => out_chs, pad = (1, 1); init = _random_normal)
    return Chain(conv1 = first_conv,
                 norm1 = BatchNorm(out_chs, relu),
                 conv2 = second_conv,
                 norm2 = BatchNorm(out_chs, relu))
end

"""
    UpConvBlock(in_chs::Int, out_chs::Int, kernel = (2, 2))

Return an upsampling block for the UNet expanding path: a transposed
convolution with stride 2 (doubling the spatial resolution) followed by
batch normalisation with ReLU activation. The layers are named
(`convtranspose`, `norm`) on the returned `Chain`.
"""
function UpConvBlock(in_chs::Int, out_chs::Int, kernel = (2, 2))
    upsample = ConvTranspose(kernel, in_chs => out_chs, stride = (2, 2);
                             init = _random_normal)
    return Chain(convtranspose = upsample, norm = BatchNorm(out_chs, relu))
end

"""
UNet(inplanes::Integer = 3, outplanes::Integer = 1, init_features::Integer = 32)

Create a UNet model
([reference](https://arxiv.org/abs/1505.04597v1))

# Arguments
- `in_channels`: The number of input channels
- `inplanes`: The number of input planes to the network
- `outplanes`: The number of output features

!!! warning

`UNet` does not currently support pretrained weights.
"""
struct UNet
encoder::Any
decoder::Any
upconv::Any
pool::Any
bottleneck::Any
final_conv::Any
end
@functor UNet

function UNet(in_channels::Integer = 3, inplanes::Integer = 32, outplanes::Integer = 1)
    features = inplanes

    # Contracting path: 4 stages. The first maps the image channels to
    # `features`; each following stage doubles the channel count.
    encoder_layers = []
    append!(encoder_layers, [unet_block(in_channels, features)])
    append!(encoder_layers, [unet_block(features * 2^i, features * 2^(i + 1)) for i ∈ 0:2])
    encoder = Chain(encoder_layers)

    # Expanding path: each stage consumes the concatenation of the upsampled
    # features and the matching skip connection (hence 2^(i + 1) input channels)
    # and halves the channel count.
    decoder = Chain([unet_block(features * 2^(i + 1), features * 2^i) for i ∈ 0:3])

    # 2×2 max pooling after each encoder stage halves the spatial resolution.
    pool = Chain([MaxPool((2, 2), stride = (2, 2)) for _ ∈ 1:4])

    # Transposed convolutions restore resolution before each decoder stage,
    # ordered deepest-first to match the forward pass.
    upconv = Chain([UpConvBlock(features * 2^(i + 1), features * 2^i) for i ∈ 3:-1:0])

    # Fix: was `_block(...)`, an undefined name — the bottleneck is an
    # ordinary conv block like the encoder stages.
    bottleneck = unet_block(features * 8, features * 16)

    # 1×1 convolution mapping the final feature maps to the output channels.
    final_conv = Conv((1, 1), features => outplanes)

    return UNet(encoder, decoder, upconv, pool, bottleneck, final_conv)
end

"""
    (u::UNet)(x::AbstractArray)

Forward pass of the UNet. Runs `x` through the four encoder stages (saving
each pre-pool activation as a skip connection), the bottleneck, then the four
decoder stages — each preceded by upsampling and channel-wise concatenation
(`dims = 3`) with the matching skip — and finally applies a 1×1 convolution
with a sigmoid activation.
"""
function (u::UNet)(x::AbstractArray)
    skips = []
    feat = x

    # Contracting path: record each stage's output before downsampling.
    for stage ∈ 1:4
        feat = u.encoder[stage](feat)
        push!(skips, feat)
        feat = u.pool[stage](feat)
    end

    feat = u.bottleneck(feat)

    # Expanding path: upconv blocks are stored deepest-first, so the k-th
    # upconv pairs with encoder stage 5 - k.
    for (step, stage) ∈ enumerate(4:-1:1)
        feat = u.upconv[step](feat)
        feat = cat(feat, skips[stage], dims = 3)
        feat = u.decoder[stage](feat)
    end

    return σ(u.final_conv(feat))
end