Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved time to first gradient #151

Merged
merged 3 commits
May 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/convnets/convmixer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@ Creates a ConvMixer model.
function convmixer(planes, depth; inchannels = 3, kernel_size = (9, 9),
patch_size::Dims{2} = (7, 7), activation = gelu, nclasses = 1000)
stem = conv_bn(patch_size, inchannels, planes, activation; preact = true, stride = patch_size[1])
blocks = [Chain(SkipConnection(Chain(conv_bn(kernel_size, planes, planes, activation;
preact = true, groups = planes, pad = SamePad())...), +),
conv_bn((1, 1), planes, planes, activation; preact = true)...) for _ in 1:depth]
blocks = [Chain(SkipConnection(conv_bn(kernel_size, planes, planes, activation;
preact = true, groups = planes, pad = SamePad()), +),
conv_bn((1, 1), planes, planes, activation; preact = true)) for _ in 1:depth]
head = Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, Dense(planes, nclasses))
return Chain(Chain(stem..., blocks...), head)
return Chain(Chain(stem, Chain(blocks)), head)
end

convmixer_config = Dict(:base => Dict(:planes => 1536, :depth => 20, :kernel_size => (9, 9),
Expand Down
4 changes: 2 additions & 2 deletions src/convnets/convnext.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ Creates a single block of ConvNeXt.
- `λ`: Init value for LayerScale
"""
function convnextblock(planes, drop_path_rate = 0., λ = 1f-6)
layers = SkipConnection(Chain(DepthwiseConv((7, 7), planes => planes; pad = 3),
layers = SkipConnection(Chain(DepthwiseConv((7, 7), planes => planes; pad = 3),
swapdims((3, 1, 2, 4)),
LayerNorm(planes; ϵ = 1f-6),
mlp_block(planes, 4 * planes),
Expand Down Expand Up @@ -61,7 +61,7 @@ function convnext(depths, planes; inchannels = 3, drop_path_rate = 0., λ = 1f-6
LayerNorm(planes[end]),
Dense(planes[end], nclasses))

return Chain(Chain(backbone...), head)
return Chain(Chain(backbone), head)
end

# Configurations for ConvNeXt models
Expand Down
18 changes: 9 additions & 9 deletions src/convnets/densenet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ Create a Densenet bottleneck layer
"""
function dense_bottleneck(inplanes, outplanes)
inner_channels = 4 * outplanes
m = Chain(conv_bn((1, 1), inplanes, inner_channels; bias = false, rev = true)...,
conv_bn((3, 3), inner_channels, outplanes; pad = 1, bias = false, rev = true)...)
m = Chain(conv_bn((1, 1), inplanes, inner_channels; bias = false, rev = true),
conv_bn((3, 3), inner_channels, outplanes; pad = 1, bias = false, rev = true))

SkipConnection(m, (mx, x) -> cat(x, mx; dims = 3))
SkipConnection(m, cat_channels)
end

"""
Expand All @@ -28,8 +28,7 @@ Create a DenseNet transition sequence
- `outplanes`: number of output feature maps
"""
transition(inplanes, outplanes) =
[conv_bn((1, 1), inplanes, outplanes; bias = false, rev = true)...,
MeanPool((2, 2))]
Chain(conv_bn((1, 1), inplanes, outplanes; bias = false, rev = true), MeanPool((2, 2)))

"""
dense_block(inplanes, growth_rates)
Expand Down Expand Up @@ -60,20 +59,21 @@ Create a DenseNet model
- `nclasses`: the number of output classes
"""
function densenet(inplanes, growth_rates; reduction = 0.5, nclasses = 1000)
layers = conv_bn((7, 7), 3, inplanes; stride = 2, pad = (3, 3), bias = false)
layers = []
push!(layers, conv_bn((7, 7), 3, inplanes; stride = 2, pad = (3, 3), bias = false))
push!(layers, MaxPool((3, 3), stride = 2, pad = (1, 1)))

outplanes = 0
for (i, rates) in enumerate(growth_rates)
outplanes = inplanes + sum(rates)
append!(layers, dense_block(inplanes, rates))
(i != length(growth_rates)) &&
append!(layers, transition(outplanes, floor(Int, outplanes * reduction)))
(i != length(growth_rates)) &&
push!(layers, transition(outplanes, floor(Int, outplanes * reduction)))
inplanes = floor(Int, outplanes * reduction)
end
push!(layers, BatchNorm(outplanes, relu))

return Chain(Chain(layers...),
return Chain(Chain(layers),
Chain(AdaptiveMeanPool((1, 1)),
MLUtils.flatten,
Dense(outplanes, nclasses)))
Expand Down
8 changes: 4 additions & 4 deletions src/convnets/googlenet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,16 @@ Create an inception module for use in GoogLeNet
"""
function _inceptionblock(inplanes, out_1x1, red_3x3, out_3x3, red_5x5, out_5x5, pool_proj)
    # Four parallel branches of the GoogLeNet inception module; their outputs
    # are concatenated along the channel dimension.
    # Branch 1: plain 1x1 convolution.
    b1 = Chain(Conv((1, 1), inplanes => out_1x1))

    # Branch 2: 1x1 reduction followed by a padded 3x3 convolution.
    b2 = Chain(Conv((1, 1), inplanes => red_3x3),
               Conv((3, 3), red_3x3 => out_3x3; pad = 1))

    # Branch 3: 1x1 reduction followed by a padded 5x5 convolution.
    b3 = Chain(Conv((1, 1), inplanes => red_5x5),
               Conv((5, 5), red_5x5 => out_5x5; pad = 2))

    # Branch 4: 3x3 max-pool (stride 1, so spatial size is kept) then a
    # 1x1 projection to `pool_proj` channels.
    b4 = Chain(MaxPool((3, 3), stride = 1, pad = 1),
               Conv((1, 1), inplanes => pool_proj))

    return Parallel(cat_channels, b1, b2, b3, b4)
end
Expand Down
92 changes: 46 additions & 46 deletions src/convnets/inception.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,17 @@ Create an Inception-v3 style-A module
- `pool_proj`: the number of output feature maps for the pooling projection
"""
function inception_a(inplanes, pool_proj)
branch1x1 = Chain(conv_bn((1, 1), inplanes, 64)...)

branch5x5 = Chain(conv_bn((1, 1), inplanes, 48)...,
conv_bn((5, 5), 48, 64; pad = 2)...)
branch1x1 = conv_bn((1, 1), inplanes, 64)

branch3x3 = Chain(conv_bn((1, 1), inplanes, 64)...,
conv_bn((3, 3), 64, 96; pad = 1)...,
conv_bn((3, 3), 96, 96; pad = 1)...)
branch5x5 = Chain(conv_bn((1, 1), inplanes, 48),
conv_bn((5, 5), 48, 64; pad = 2))

branch3x3 = Chain(conv_bn((1, 1), inplanes, 64),
conv_bn((3, 3), 64, 96; pad = 1),
conv_bn((3, 3), 96, 96; pad = 1))

branch_pool = Chain(MeanPool((3, 3), pad = 1, stride = 1),
conv_bn((1, 1), inplanes, pool_proj)...)
conv_bn((1, 1), inplanes, pool_proj))

return Parallel(cat_channels,
branch1x1, branch5x5, branch3x3, branch_pool)
Expand All @@ -35,13 +35,13 @@ Create an Inception-v3 style-B module
- `inplanes`: number of input feature maps
"""
function inception_b(inplanes)
branch3x3_1 = Chain(conv_bn((3, 3), inplanes, 384; stride = 2)...)
branch3x3_1 = conv_bn((3, 3), inplanes, 384; stride = 2)

branch3x3_2 = Chain(conv_bn((1, 1), inplanes, 64)...,
conv_bn((3, 3), 64, 96; pad = 1)...,
conv_bn((3, 3), 96, 96; stride = 2)...)
branch3x3_2 = Chain(conv_bn((1, 1), inplanes, 64),
conv_bn((3, 3), 64, 96; pad = 1),
conv_bn((3, 3), 96, 96; stride = 2))

branch_pool = Chain(MaxPool((3, 3), stride = 2))
branch_pool = MaxPool((3, 3), stride = 2)

return Parallel(cat_channels,
branch3x3_1, branch3x3_2, branch_pool)
Expand All @@ -59,20 +59,20 @@ Create an Inception-v3 style-C module
- `n`: the "grid size" (kernel size) for the convolution layers
"""
function inception_c(inplanes, inner_planes, n = 7)
branch1x1 = Chain(conv_bn((1, 1), inplanes, 192)...)
branch1x1 = conv_bn((1, 1), inplanes, 192)

branch7x7_1 = Chain(conv_bn((1, 1), inplanes, inner_planes)...,
conv_bn((1, n), inner_planes, inner_planes; pad = (0, 3))...,
conv_bn((n, 1), inner_planes, 192; pad = (3, 0))...)
branch7x7_1 = Chain(conv_bn((1, 1), inplanes, inner_planes),
conv_bn((1, n), inner_planes, inner_planes; pad = (0, 3)),
conv_bn((n, 1), inner_planes, 192; pad = (3, 0)))

branch7x7_2 = Chain(conv_bn((1, 1), inplanes, inner_planes)...,
conv_bn((n, 1), inner_planes, inner_planes; pad = (3, 0))...,
conv_bn((1, n), inner_planes, inner_planes; pad = (0, 3))...,
conv_bn((n, 1), inner_planes, inner_planes; pad = (3, 0))...,
conv_bn((1, n), inner_planes, 192; pad = (0, 3))...)
branch7x7_2 = Chain(conv_bn((1, 1), inplanes, inner_planes),
conv_bn((n, 1), inner_planes, inner_planes; pad = (3, 0)),
conv_bn((1, n), inner_planes, inner_planes; pad = (0, 3)),
conv_bn((n, 1), inner_planes, inner_planes; pad = (3, 0)),
conv_bn((1, n), inner_planes, 192; pad = (0, 3)))

branch_pool = Chain(MeanPool((3, 3), pad = 1, stride=1),
conv_bn((1, 1), inplanes, 192)...)
branch_pool = Chain(MeanPool((3, 3), pad = 1, stride=1),
conv_bn((1, 1), inplanes, 192))

return Parallel(cat_channels,
branch1x1, branch7x7_1, branch7x7_2, branch_pool)
Expand All @@ -88,15 +88,15 @@ Create an Inception-v3 style-D module
- `inplanes`: number of input feature maps
"""
function inception_d(inplanes)
branch3x3 = Chain(conv_bn((1, 1), inplanes, 192)...,
conv_bn((3, 3), 192, 320; stride = 2)...)
branch3x3 = Chain(conv_bn((1, 1), inplanes, 192),
conv_bn((3, 3), 192, 320; stride = 2))

branch7x7x3 = Chain(conv_bn((1, 1), inplanes, 192)...,
conv_bn((1, 7), 192, 192; pad = (0, 3))...,
conv_bn((7, 1), 192, 192; pad = (3, 0))...,
conv_bn((3, 3), 192, 192; stride = 2)...)
branch7x7x3 = Chain(conv_bn((1, 1), inplanes, 192),
conv_bn((1, 7), 192, 192; pad = (0, 3)),
conv_bn((7, 1), 192, 192; pad = (3, 0)),
conv_bn((3, 3), 192, 192; stride = 2))

branch_pool = Chain(MaxPool((3, 3), stride=2))
branch_pool = MaxPool((3, 3), stride=2)

return Parallel(cat_channels,
branch3x3, branch7x7x3, branch_pool)
Expand All @@ -112,26 +112,26 @@ Create an Inception-v3 style-E module
- `inplanes`: number of input feature maps
"""
function inception_e(inplanes)
branch1x1 = Chain(conv_bn((1, 1), inplanes, 320)...)
branch1x1 = conv_bn((1, 1), inplanes, 320)

branch3x3_1 = Chain(conv_bn((1, 1), inplanes, 384)...)
branch3x3_1a = Chain(conv_bn((1, 3), 384, 384; pad = (0, 1))...)
branch3x3_1b = Chain(conv_bn((3, 1), 384, 384; pad = (1, 0))...)
branch3x3_1 = conv_bn((1, 1), inplanes, 384)
branch3x3_1a = conv_bn((1, 3), 384, 384; pad = (0, 1))
branch3x3_1b = conv_bn((3, 1), 384, 384; pad = (1, 0))

branch3x3_2 = Chain(conv_bn((1, 1), inplanes, 448)...,
conv_bn((3, 3), 448, 384; pad = 1)...)
branch3x3_2a = Chain(conv_bn((1, 3), 384, 384; pad = (0, 1))...)
branch3x3_2b = Chain(conv_bn((3, 1), 384, 384; pad = (1, 0))...)
branch3x3_2 = Chain(conv_bn((1, 1), inplanes, 448),
conv_bn((3, 3), 448, 384; pad = 1))
branch3x3_2a = conv_bn((1, 3), 384, 384; pad = (0, 1))
branch3x3_2b = conv_bn((3, 1), 384, 384; pad = (1, 0))

branch_pool = Chain(MeanPool((3, 3), pad = 1, stride = 1),
conv_bn((1, 1), inplanes, 192)...)
conv_bn((1, 1), inplanes, 192))

return Parallel(cat_channels,
branch1x1,
Chain(branch3x3_1,
Parallel(cat_channels,
branch3x3_1a, branch3x3_1b)),

Chain(branch3x3_2,
Parallel(cat_channels,
branch3x3_2a, branch3x3_2b)),
Expand All @@ -150,12 +150,12 @@ Create an Inception-v3 model ([reference](https://arxiv.org/abs/1512.00567v3)).
`inception3` does not currently support pretrained weights.
"""
function inception3(; nclasses = 1000)
layer = Chain(Chain(conv_bn((3, 3), 3, 32; stride = 2)...,
conv_bn((3, 3), 32, 32)...,
conv_bn((3, 3), 32, 64; pad = 1)...,
layer = Chain(Chain(conv_bn((3, 3), 3, 32; stride = 2),
conv_bn((3, 3), 32, 32),
conv_bn((3, 3), 32, 64; pad = 1),
MaxPool((3, 3), stride = 2),
conv_bn((1, 1), 64, 80)...,
conv_bn((3, 3), 80, 192)...,
conv_bn((1, 1), 64, 80),
conv_bn((3, 3), 80, 192),
MaxPool((3, 3), stride = 2),
inception_a(192, 32),
inception_a(256, 64),
Expand Down
34 changes: 15 additions & 19 deletions src/convnets/mobilenet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,15 @@ function mobilenetv1(width_mult, config;
for (dw, outch, stride, repeats) in config
outch = Int(outch * width_mult)
for _ in 1:repeats
layer = if dw
depthwise_sep_conv_bn((3, 3), inchannels, outch, activation; stride = stride, pad = 1)
else
conv_bn((3, 3), inchannels, outch, activation; stride = stride, pad = 1)
end
append!(layers, layer)
layer = dw ? depthwise_sep_conv_bn((3, 3), inchannels, outch, activation;
stride = stride, pad = 1) :
conv_bn((3, 3), inchannels, outch, activation; stride = stride, pad = 1)
push!(layers, layer)
inchannels = outch
end
end

return Chain(Chain(layers...),
return Chain(Chain(layers),
Chain(GlobalMeanPool(),
MLUtils.flatten,
Dense(inchannels, fcsize, activation),
Expand Down Expand Up @@ -120,7 +118,7 @@ function mobilenetv2(width_mult, configs; max_width = 1280, nclasses = 1000)
# building first layer
inplanes = _round_channels(32 * width_mult, width_mult == 0.1 ? 4 : 8)
layers = []
append!(layers, conv_bn((3, 3), 3, inplanes, stride = 2))
push!(layers, conv_bn((3, 3), 3, inplanes, stride = 2))

# building inverted residual blocks
for (t, c, n, s, a) in configs
Expand All @@ -136,8 +134,7 @@ function mobilenetv2(width_mult, configs; max_width = 1280, nclasses = 1000)
outplanes = (width_mult > 1) ? _round_channels(max_width * width_mult, width_mult == 0.1 ? 4 : 8) :
max_width

return Chain(Chain(layers...,
conv_bn((1, 1), inplanes, outplanes, relu6, bias = false)...),
return Chain(Chain(Chain(layers), conv_bn((1, 1), inplanes, outplanes, relu6, bias = false)),
Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, Dense(outplanes, nclasses)))
end

Expand Down Expand Up @@ -186,7 +183,7 @@ end
(m::MobileNetv2)(x) = m.layers(x)

backbone(m::MobileNetv2) = m.layers[1]
classifier(m::MobileNetv2) = m.layers[2:end]
classifier(m::MobileNetv2) = m.layers[2]

# MobileNetv3

Expand Down Expand Up @@ -214,7 +211,7 @@ function mobilenetv3(width_mult, configs; max_width = 1024, nclasses = 1000)
# building first layer
inplanes = _round_channels(16 * width_mult, 8)
layers = []
append!(layers, conv_bn((3, 3), 3, inplanes, hardswish; stride = 2))
push!(layers, conv_bn((3, 3), 3, inplanes, hardswish; stride = 2))
explanes = 0
# building inverted residual blocks
for (k, t, c, r, a, s) in configs
Expand All @@ -229,13 +226,12 @@ function mobilenetv3(width_mult, configs; max_width = 1024, nclasses = 1000)
# building last several layers
output_channel = max_width
output_channel = width_mult > 1.0 ? _round_channels(output_channel * width_mult, 8) : output_channel
classifier = (Dense(explanes, output_channel, hardswish),
Dropout(0.2),
Dense(output_channel, nclasses))
classifier = Chain(Dense(explanes, output_channel, hardswish),
Dropout(0.2),
Dense(output_channel, nclasses))

return Chain(Chain(layers...,
conv_bn((1, 1), inplanes, explanes, hardswish, bias = false)...),
Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, classifier...))
return Chain(Chain(Chain(layers), conv_bn((1, 1), inplanes, explanes, hardswish, bias = false)),
Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, classifier))
end

# Configurations for small and large mode for MobileNetv3
Expand Down Expand Up @@ -310,4 +306,4 @@ end
(m::MobileNetv3)(x) = m.layers(x)

backbone(m::MobileNetv3) = m.layers[1]
classifier(m::MobileNetv3) = m.layers[2:end]
classifier(m::MobileNetv3) = m.layers[2]
16 changes: 8 additions & 8 deletions src/convnets/resnet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ Create a basic residual block
"""
function basicblock(inplanes, outplanes, downsample = false)
stride = downsample ? 2 : 1
Chain(conv_bn((3, 3), inplanes, outplanes[1]; stride = stride, pad = 1, bias = false)...,
conv_bn((3, 3), outplanes[1], outplanes[2], identity; stride = 1, pad = 1, bias = false)...)
Chain(conv_bn((3, 3), inplanes, outplanes[1]; stride = stride, pad = 1, bias = false),
conv_bn((3, 3), outplanes[1], outplanes[2], identity; stride = 1, pad = 1, bias = false))
end

"""
Expand All @@ -36,9 +36,9 @@ The original paper uses `stride == [2, 1, 1]` when `downsample == true` instead.
"""
function bottleneck(inplanes, outplanes, downsample = false;
stride = [1, (downsample ? 2 : 1), 1])
Chain(conv_bn((1, 1), inplanes, outplanes[1]; stride = stride[1], bias = false)...,
conv_bn((3, 3), outplanes[1], outplanes[2]; stride = stride[2], pad = 1, bias = false)...,
conv_bn((1, 1), outplanes[2], outplanes[3], identity; stride = stride[3], bias = false)...)
Chain(conv_bn((1, 1), inplanes, outplanes[1]; stride = stride[1], bias = false),
conv_bn((3, 3), outplanes[1], outplanes[2]; stride = stride[2], pad = 1, bias = false),
conv_bn((1, 1), outplanes[2], outplanes[3], identity; stride = stride[3], bias = false))
end


Expand Down Expand Up @@ -82,7 +82,7 @@ function resnet(block, residuals::AbstractVector{<:NTuple{2, Any}}, connection =
inplanes = 64
baseplanes = 64
layers = []
append!(layers, conv_bn((7, 7), 3, inplanes; stride = 2, pad = 3, bias = false))
push!(layers, conv_bn((7, 7), 3, inplanes; stride = 2, pad = 3, bias = false))
push!(layers, MaxPool((3, 3), stride = (2, 2), pad = (1, 1)))
for (i, nrepeats) in enumerate(block_config)
# output planes within a block
Expand All @@ -102,7 +102,7 @@ function resnet(block, residuals::AbstractVector{<:NTuple{2, Any}}, connection =
baseplanes *= 2
end

return Chain(Chain(layers...),
return Chain(Chain(layers),
Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, Dense(inplanes, nclasses)))
end

Expand Down Expand Up @@ -246,7 +246,7 @@ function ResNet(depth::Int = 50; pretrain = false, nclasses = 1000)
model
end

# Compat with Methalhead 0.6; remove in 0.7
# Compat with Metalhead 0.6; remove in 0.7
@deprecate ResNet18(; kw...) ResNet(18; kw...)
@deprecate ResNet34(; kw...) ResNet(34; kw...)
@deprecate ResNet50(; kw...) ResNet(50; kw...)
Expand Down
Loading