Skip to content

Commit

Permalink
add an toy example for nnunetv2.utilities.get_network_from_plans
Browse files Browse the repository at this point in the history
  • Loading branch information
AIboy996 committed Sep 19, 2024
1 parent 491feab commit aa74c3a
Showing 1 changed file with 31 additions and 0 deletions.
31 changes: 31 additions & 0 deletions nnunetv2/utilities/get_network_from_plans.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,34 @@ def get_network_from_plans(arch_class_name, arch_kwargs, arch_kwargs_req_import,
network.apply(network.initialize)

return network

if __name__ == "__main__":
import torch

model = get_network_from_plans(
arch_class_name="dynamic_network_architectures.architectures.unet.ResidualEncoderUNet",
arch_kwargs={
"n_stages": 7,
"features_per_stage": [32, 64, 128, 256, 512, 512, 512],
"conv_op": "torch.nn.modules.conv.Conv2d",
"kernel_sizes": [[3, 3], [3, 3], [3, 3], [3, 3], [3, 3], [3, 3], [3, 3]],
"strides": [[1, 1], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2]],
"n_blocks_per_stage": [1, 3, 4, 6, 6, 6, 6],
"n_conv_per_stage_decoder": [1, 1, 1, 1, 1, 1],
"conv_bias": True,
"norm_op": "torch.nn.modules.instancenorm.InstanceNorm2d",
"norm_op_kwargs": {"eps": 1e-05, "affine": True},
"dropout_op": None,
"dropout_op_kwargs": None,
"nonlin": "torch.nn.LeakyReLU",
"nonlin_kwargs": {"inplace": True},
},
arch_kwargs_req_import=["conv_op", "norm_op", "dropout_op", "nonlin"],
input_channels=1,
output_channels=4,
allow_init=True,
deep_supervision=True,
)
data = torch.rand((8, 1, 256, 256))
target = torch.rand(size=(8, 1, 256, 256))
outputs = model(data) # this should be a list of torch.Tensor

0 comments on commit aa74c3a

Please sign in to comment.