From 8bb3407a2c2d454a542398ddd8f2234e32abfdb1 Mon Sep 17 00:00:00 2001 From: JiaDingCN Date: Tue, 26 Jul 2022 07:37:41 +0000 Subject: [PATCH] version1 --- LICENSE | 220 +++ README.md | 44 + benchmark.py | 80 + .../50eps/r50_deformable_detr.sh | 15 + ...ble_detr_plus_iterative_bbox_refinement.sh | 17 + .../50eps/r50_deformable_detr_single_scale.sh | 16 + .../r50_deformable_detr_single_scale_dc5.sh | 17 + ...ive_bbox_refinement_plus_plus_two_stage.sh | 18 + ...ive_bbox_refinement_plus_plus_two_stage.sh | 18 + ...ive_bbox_refinement_plus_plus_two_stage.sh | 18 + ...ive_bbox_refinement_plus_plus_two_stage.sh | 18 + ...ive_bbox_refinement_plus_plus_two_stage.sh | 18 + ...ive_bbox_refinement_plus_plus_two_stage.sh | 18 + ...ive_bbox_refinement_plus_plus_two_stage.sh | 18 + ...ive_bbox_refinement_plus_plus_two_stage.sh | 18 + ...ive_bbox_refinement_plus_plus_two_stage.sh | 23 + ...ive_bbox_refinement_plus_plus_two_stage.sh | 22 + ...ive_bbox_refinement_plus_plus_two_stage.sh | 22 + ...ive_bbox_refinement_plus_plus_two_stage.sh | 22 + ...ive_bbox_refinement_plus_plus_two_stage.sh | 22 + ...ive_bbox_refinement_plus_plus_two_stage.sh | 22 + ...ive_bbox_refinement_plus_plus_two_stage.sh | 22 + ...ive_bbox_refinement_plus_plus_two_stage.sh | 22 + ...ive_bbox_refinement_plus_plus_two_stage.sh | 22 + ...ive_bbox_refinement_plus_plus_two_stage.sh | 22 + ...ive_bbox_refinement_plus_plus_two_stage.sh | 19 + ...ive_bbox_refinement_plus_plus_two_stage.sh | 22 + ...ive_bbox_refinement_plus_plus_two_stage.sh | 22 + ...ive_bbox_refinement_plus_plus_two_stage.sh | 22 + ...ive_bbox_refinement_plus_plus_two_stage.sh | 22 + ...ive_bbox_refinement_plus_plus_two_stage.sh | 22 + ...ive_bbox_refinement_plus_plus_two_stage.sh | 22 + ...ive_bbox_refinement_plus_plus_two_stage.sh | 22 + ...ive_bbox_refinement_plus_plus_two_stage.sh | 24 + ...ive_bbox_refinement_plus_plus_two_stage.sh | 24 + ...ive_bbox_refinement_plus_plus_two_stage.sh | 24 + ...ive_bbox_refinement_plus_plus_two_stage.sh | 23 + ...ive_bbox_refinement_plus_plus_two_stage.sh | 24 + ...ive_bbox_refinement_plus_plus_two_stage.sh | 19 + ...ive_bbox_refinement_plus_plus_two_stage.sh | 22 + ...ive_bbox_refinement_plus_plus_two_stage.sh | 23 + ...ive_bbox_refinement_plus_plus_two_stage.sh | 19 + ...ive_bbox_refinement_plus_plus_two_stage.sh | 22 + ...ive_bbox_refinement_plus_plus_two_stage.sh | 25 + ...ive_bbox_refinement_plus_plus_two_stage.sh | 25 + ...ive_bbox_refinement_plus_plus_two_stage.sh | 24 + ...ive_bbox_refinement_plus_plus_two_stage.sh | 23 + ...ive_bbox_refinement_plus_plus_two_stage.sh | 24 + datasets/__init__.py | 34 + datasets/coco.py | 195 +++ datasets/coco_eval.py | 273 ++++ datasets/coco_panoptic.py | 107 ++ datasets/data_prefetcher.py | 70 + datasets/panoptic_eval.py | 52 + datasets/samplers.py | 139 ++ datasets/torchvision_datasets/__init__.py | 7 + datasets/torchvision_datasets/coco.py | 84 ++ datasets/transforms.py | 294 ++++ engine.py | 285 ++++ main.py | 532 +++++++ mmcv_custom/__init__.py | 6 + mmcv_custom/checkpoint.py | 508 +++++++ mmcv_custom/runner/__init__.py | 7 + mmcv_custom/runner/checkpoint.py | 81 + mmcv_custom/runner/epoch_based_runner.py | 103 ++ models/__init__.py | 15 + models/backbone.py | 269 ++++ models/deformable_detr.py | 656 ++++++++ models/deformable_transformer.py | 632 ++++++++ models/matcher.py | 124 ++ models/ops/functions/__init__.py | 10 + models/ops/functions/ms_deform_attn_func.py | 110 ++ models/ops/make.sh | 10 + models/ops/modules/__init__.py | 9 + models/ops/modules/ms_deform_attn.py | 162 ++ models/ops/setup.py | 71 + models/ops/src/cpu/ms_deform_attn_cpu.cpp | 41 + models/ops/src/cpu/ms_deform_attn_cpu.h | 33 + models/ops/src/cuda/ms_deform_attn_cuda.cu | 153 ++ models/ops/src/cuda/ms_deform_attn_cuda.h | 30 + models/ops/src/cuda/ms_deform_im2col_cuda.cuh | 1327 +++++++++++++++++ models/ops/src/ms_deform_attn.h | 62 + models/ops/src/vision.cpp | 16 + models/ops/test.py | 89 ++ models/position_encoding.py | 113 ++ models/segmentation.py | 427 ++++++ models/swin_transformer.py | 743 +++++++++ requirements.txt | 5 + tools/launch.py | 216 +++ tools/run_dist_launch.sh | 29 + tools/run_dist_slurm.sh | 33 + util/__init__.py | 8 + util/box_ops.py | 94 ++ util/misc.py | 514 +++++++ util/plot_utils.py | 111 ++ 95 files changed, 10181 insertions(+) create mode 100644 LICENSE create mode 100644 README.md create mode 100644 benchmark.py create mode 100644 configs/one_stage/deformable-detr-baseline/50eps/r50_deformable_detr.sh create mode 100644 configs/one_stage/deformable-detr-baseline/50eps/r50_deformable_detr_plus_iterative_bbox_refinement.sh create mode 100644 configs/one_stage/deformable-detr-baseline/50eps/r50_deformable_detr_single_scale.sh create mode 100644 configs/one_stage/deformable-detr-baseline/50eps/r50_deformable_detr_single_scale_dc5.sh create mode 100644 configs/two_stage/deformable-detr-baseline/12eps/r50_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh create mode 100644 configs/two_stage/deformable-detr-baseline/12eps/r50_n1800_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh create mode 100644 configs/two_stage/deformable-detr-baseline/24eps/r50_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh create mode 100644 configs/two_stage/deformable-detr-baseline/24eps/r50_n1800_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh create mode 100644 configs/two_stage/deformable-detr-baseline/36eps/r50_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh create mode 100644 configs/two_stage/deformable-detr-baseline/36eps/r50_n1800_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh create mode 100644 configs/two_stage/deformable-detr-baseline/50eps/r50_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh create mode 100644 configs/two_stage/deformable-detr-baseline/50eps/r50_n1800_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh create mode 100644 configs/two_stage/deformable-detr-hybrid-branch/12eps/r101_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh create mode 100644 configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda0.1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh create mode 100644 configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda0.2_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh create mode 100644 configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda0.5_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh create mode 100644 configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group1_t300_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh create mode 100644 configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group2_t600_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh create mode 100644 configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group3_t900_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh create mode 100644 configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group4_t1200_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh create mode 100644 configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group5_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh create mode 100644 configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group6_t1200_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh create mode 100644 configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group6_t1500_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh create mode 100644 configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh create mode 100644 configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group6_t1800_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh create mode 100644 configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group6_t300_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh create mode 100644 configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group6_t600_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh create mode 100644 configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group6_t900_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh create mode 100644 configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda2_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh create mode 100644 configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda5_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh create mode 100644 configs/two_stage/deformable-detr-hybrid-branch/12eps/swin/swin_large_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh create mode 100644 configs/two_stage/deformable-detr-hybrid-branch/12eps/swin/swin_large_hybrid_branch_lambda1_group6_t1500_n900_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh create mode 100644 configs/two_stage/deformable-detr-hybrid-branch/12eps/swin/swin_small_22k_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh create mode 100644 configs/two_stage/deformable-detr-hybrid-branch/12eps/swin/swin_small_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh create mode 100644 configs/two_stage/deformable-detr-hybrid-branch/12eps/swin/swin_tiny_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh create mode 100644 configs/two_stage/deformable-detr-hybrid-branch/24eps/r50_hybrid_branch_lambda1_group6_t1500_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh create mode 100644 configs/two_stage/deformable-detr-hybrid-branch/24eps/r50_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh create mode 100644 configs/two_stage/deformable-detr-hybrid-branch/36eps/r101_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh create mode 100644 configs/two_stage/deformable-detr-hybrid-branch/36eps/r50_hybrid_branch_lambda1_group6_t1500_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh create mode 100644 configs/two_stage/deformable-detr-hybrid-branch/36eps/r50_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh create mode 100644 configs/two_stage/deformable-detr-hybrid-branch/36eps/swin/drop_path0.5_swin_large_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh create mode 100644 configs/two_stage/deformable-detr-hybrid-branch/36eps/swin/drop_path0.5_swin_large_hybrid_branch_lambda1_group6_t1500_n900_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh create mode 100644 configs/two_stage/deformable-detr-hybrid-branch/36eps/swin/swin_small_22k_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh create mode 100644 configs/two_stage/deformable-detr-hybrid-branch/36eps/swin/swin_small_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh create mode 100644 configs/two_stage/deformable-detr-hybrid-branch/36eps/swin/swin_tiny_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh create mode 100644 datasets/__init__.py create mode 100644 datasets/coco.py create mode 100644 datasets/coco_eval.py create mode 100644 datasets/coco_panoptic.py create mode 100644 datasets/data_prefetcher.py create mode 100644 datasets/panoptic_eval.py create mode 100644 datasets/samplers.py create mode 100644 datasets/torchvision_datasets/__init__.py create mode 100644 datasets/torchvision_datasets/coco.py create mode 100644 datasets/transforms.py create mode 100644 engine.py create mode 100644 main.py create mode 100644 mmcv_custom/__init__.py create mode 100644 mmcv_custom/checkpoint.py create mode 100644 mmcv_custom/runner/__init__.py create mode 100644 mmcv_custom/runner/checkpoint.py create mode 100644 mmcv_custom/runner/epoch_based_runner.py create mode 100644 models/__init__.py create mode 100644 models/backbone.py create mode 100644 models/deformable_detr.py create mode 100644 models/deformable_transformer.py create mode 100644 models/matcher.py create mode 100644 models/ops/functions/__init__.py create mode 100644 models/ops/functions/ms_deform_attn_func.py create mode 100644 models/ops/make.sh create mode 100644 models/ops/modules/__init__.py create mode 100644 models/ops/modules/ms_deform_attn.py create mode 100644 models/ops/setup.py create mode 100644 models/ops/src/cpu/ms_deform_attn_cpu.cpp create mode 100644 models/ops/src/cpu/ms_deform_attn_cpu.h create mode 100644 models/ops/src/cuda/ms_deform_attn_cuda.cu create mode 100644 models/ops/src/cuda/ms_deform_attn_cuda.h create mode 100644 models/ops/src/cuda/ms_deform_im2col_cuda.cuh create mode 100644 models/ops/src/ms_deform_attn.h create mode 100644 models/ops/src/vision.cpp create mode 100644 models/ops/test.py create mode 100644 models/position_encoding.py create mode 100644 models/segmentation.py create mode 100644 models/swin_transformer.py create mode 100644 requirements.txt create mode 100644 tools/launch.py create mode 100644 tools/run_dist_launch.sh create mode 100644 tools/run_dist_slurm.sh create mode 100644 util/__init__.py create mode 100644 util/box_ops.py create mode 100644 util/misc.py create mode 100644 util/plot_utils.py diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..522e5bd --- /dev/null +++ b/LICENSE @@ -0,0 +1,220 @@ +Copyright (c) 2020 SenseTime. All Rights Reserved. + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2020 SenseTime + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + +DETR + +Copyright 2020 - present, Facebook, Inc + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/README.md b/README.md new file mode 100644 index 0000000..4c9526f --- /dev/null +++ b/README.md @@ -0,0 +1,44 @@ +# Modified files + +## To support swin backbones +* models/backbone.py +* models/swin_transformer.py +* mmcv_custom + +## To support eval in the training set +* datasets/coco.py +* datasets/\_\_init\_\_.py + +## To support Hybird-branch, tricks and checkpoint +* main.py +* engine.py +* models/deformable_detr.py +* models/deformable_transformer.py + +## To support fp16 +* models/ops/modules/ms_deform_attn.py +* models/ops/functions/ms_deform_attn_func.py + +## To fix a pytorch version bug +* util/misc.py + +# Addictional packages needed + +* wandb: for logging +* mmdet: for swin backbones +* mmcv: for swin backbones +* timm: for swin backbones + +# To train a model + +```Bash +GPUS_PER_NODE=8 ./tools/run_dist_launch.sh 8 \ + --coco_path +``` + +# To eval a model + +```Bash +GPUS_PER_NODE=8 ./tools/run_dist_launch.sh 8 \ + --coco_path --eval +``` \ No newline at end of file diff --git a/benchmark.py b/benchmark.py new file mode 100644 index 0000000..f510830 --- /dev/null +++ b/benchmark.py @@ -0,0 +1,80 @@ +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ + +""" +Benchmark inference speed of Deformable DETR. +""" +import os +import time +import argparse + +import torch + +from main import get_args_parser as get_main_args_parser +from models import build_model +from datasets import build_dataset +from util.misc import nested_tensor_from_tensor_list + + +def get_benckmark_arg_parser(): + parser = argparse.ArgumentParser("Benchmark inference speed of Deformable DETR.") + parser.add_argument( + "--num_iters", type=int, default=300, help="total iters to benchmark speed" + ) + parser.add_argument( + "--warm_iters", + type=int, + default=5, + help="ignore first several iters that are very slow", + ) + parser.add_argument( + "--batch_size", type=int, default=1, help="batch size in inference" + ) + parser.add_argument("--resume", type=str, help="load the pre-trained checkpoint") + return parser + + +@torch.no_grad() +def measure_average_inference_time(model, inputs, num_iters=100, warm_iters=5): + ts = [] + for iter_ in range(num_iters): + torch.cuda.synchronize() + t_ = time.perf_counter() + model(inputs) + torch.cuda.synchronize() + t = time.perf_counter() - t_ + if iter_ >= warm_iters: + ts.append(t) + print(ts) + return sum(ts) / len(ts) + + +def benchmark(): + args, _ = get_benckmark_arg_parser().parse_known_args() + main_args = get_main_args_parser().parse_args(_) + assert ( + args.warm_iters < args.num_iters and args.num_iters > 0 and args.warm_iters >= 0 + ) + assert args.batch_size > 0 + assert args.resume is None or os.path.exists(args.resume) + dataset = build_dataset("val", main_args) + model, _, _ = build_model(main_args) + model.cuda() + model.eval() + if args.resume is not None: + ckpt = torch.load(args.resume, map_location=lambda storage, loc: storage) + model.load_state_dict(ckpt["model"]) + inputs = nested_tensor_from_tensor_list( + [dataset.__getitem__(0)[0].cuda() for _ in range(args.batch_size)] + ) + t = measure_average_inference_time(model, inputs, args.num_iters, args.warm_iters) + return 1.0 / t * args.batch_size + + +if __name__ == "__main__": + fps = benchmark() + print(f"Inference Speed: {fps:.1f} FPS") + diff --git a/configs/one_stage/deformable-detr-baseline/50eps/r50_deformable_detr.sh b/configs/one_stage/deformable-detr-baseline/50eps/r50_deformable_detr.sh new file mode 100644 index 0000000..33d26ee --- /dev/null +++ b/configs/one_stage/deformable-detr-baseline/50eps/r50_deformable_detr.sh @@ -0,0 +1,15 @@ +#!/usr/bin/env bash + +set -x + +EXP_DIR=exps/one_stage/deformable-detr-baseline/12eps/r50_deformable_detr +PY_ARGS=${@:1} + +python -u main.py \ + --output_dir ${EXP_DIR} \ + --num_queries_one2one 300 \ + --num_queries_one2many 0 \ + --k_one2many 0 \ + --epochs 50 \ + --lr_drop 40 \ + ${PY_ARGS} diff --git a/configs/one_stage/deformable-detr-baseline/50eps/r50_deformable_detr_plus_iterative_bbox_refinement.sh b/configs/one_stage/deformable-detr-baseline/50eps/r50_deformable_detr_plus_iterative_bbox_refinement.sh new file mode 100644 index 0000000..05ece39 --- /dev/null +++ b/configs/one_stage/deformable-detr-baseline/50eps/r50_deformable_detr_plus_iterative_bbox_refinement.sh @@ -0,0 +1,17 @@ +#!/usr/bin/env bash + +set -x + +EXP_DIR=exps/one_stage/deformable-detr-baseline/12eps/r50_deformable_detr_plus_iterative_bbox_refinement +PY_ARGS=${@:1} + +python -u main.py \ + --output_dir ${EXP_DIR} \ + --with_box_refine \ + --dim_feedforward 2048 \ + --num_queries_one2one 300 \ + --num_queries_one2many 0 \ + --k_one2many 0 \ + --epochs 50 \ + --lr_drop 40 \ + ${PY_ARGS} diff --git a/configs/one_stage/deformable-detr-baseline/50eps/r50_deformable_detr_single_scale.sh b/configs/one_stage/deformable-detr-baseline/50eps/r50_deformable_detr_single_scale.sh new file mode 100644 index 0000000..3503f94 --- /dev/null +++ b/configs/one_stage/deformable-detr-baseline/50eps/r50_deformable_detr_single_scale.sh @@ -0,0 +1,16 @@ +#!/usr/bin/env bash + +set -x + +EXP_DIR=exps/one_stage/deformable-detr-baseline/12eps/r50_deformable_detr_single_scale +PY_ARGS=${@:1} + +python -u main.py \ + --num_feature_levels 1 \ + --output_dir ${EXP_DIR} \ + --num_queries_one2one 300 \ + --num_queries_one2many 0 \ + --k_one2many 0 \ + --epochs 50 \ + --lr_drop 40 \ + ${PY_ARGS} diff --git a/configs/one_stage/deformable-detr-baseline/50eps/r50_deformable_detr_single_scale_dc5.sh b/configs/one_stage/deformable-detr-baseline/50eps/r50_deformable_detr_single_scale_dc5.sh new file mode 100644 index 0000000..ac06233 --- /dev/null +++ b/configs/one_stage/deformable-detr-baseline/50eps/r50_deformable_detr_single_scale_dc5.sh @@ -0,0 +1,17 @@ +#!/usr/bin/env bash + +set -x + +EXP_DIR=exps/one_stage/deformable-detr-baseline/12eps/r50_deformable_detr_single_scale_dc5 +PY_ARGS=${@:1} + +python -u main.py \ + --num_feature_levels 1 \ + --dilation \ + --output_dir ${EXP_DIR} \ + --num_queries_one2one 300 \ + --num_queries_one2many 0 \ + --k_one2many 0 \ + --epochs 50 \ + --lr_drop 40 \ + ${PY_ARGS} diff --git a/configs/two_stage/deformable-detr-baseline/12eps/r50_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh b/configs/two_stage/deformable-detr-baseline/12eps/r50_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh new file mode 100644 index 0000000..6272f41 --- /dev/null +++ b/configs/two_stage/deformable-detr-baseline/12eps/r50_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh @@ -0,0 +1,18 @@ +#!/usr/bin/env bash + +set -x + +EXP_DIR=exps/two_stage/deformable-detr-baseline/12eps/r50_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage +PY_ARGS=${@:1} + +python -u main.py \ + --output_dir ${EXP_DIR} \ + --with_box_refine \ + --two_stage \ + --dim_feedforward 2048 \ + --num_queries_one2one 300 \ + --num_queries_one2many 0 \ + --k_one2many 0 \ + --epochs 12 \ + --lr_drop 11 \ + ${PY_ARGS} diff --git a/configs/two_stage/deformable-detr-baseline/12eps/r50_n1800_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh b/configs/two_stage/deformable-detr-baseline/12eps/r50_n1800_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh new file mode 100644 index 0000000..75b7ec7 --- /dev/null +++ b/configs/two_stage/deformable-detr-baseline/12eps/r50_n1800_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh @@ -0,0 +1,18 @@ +#!/usr/bin/env bash + +set -x + +EXP_DIR=exps/two_stage/deformable-detr-baseline/12eps/r50_n1800_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage +PY_ARGS=${@:1} + +python -u main.py \ + --output_dir ${EXP_DIR} \ + --with_box_refine \ + --two_stage \ + --dim_feedforward 2048 \ + --num_queries_one2one 1800 \ + --num_queries_one2many 0 \ + --k_one2many 0 \ + --epochs 12 \ + --lr_drop 11 \ + ${PY_ARGS} diff --git a/configs/two_stage/deformable-detr-baseline/24eps/r50_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh b/configs/two_stage/deformable-detr-baseline/24eps/r50_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh new file mode 100644 index 0000000..e02f091 --- /dev/null +++ b/configs/two_stage/deformable-detr-baseline/24eps/r50_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh @@ -0,0 +1,18 @@ +#!/usr/bin/env bash + +set -x + +EXP_DIR=exps/two_stage/deformable-detr-baseline/24eps/r50_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage +PY_ARGS=${@:1} + +python -u main.py \ + --output_dir ${EXP_DIR} \ + --with_box_refine \ + --two_stage \ + --dim_feedforward 2048 \ + --num_queries_one2one 300 \ + --num_queries_one2many 0 \ + --k_one2many 0 \ + --epochs 24 \ + --lr_drop 20 \ + ${PY_ARGS} diff --git a/configs/two_stage/deformable-detr-baseline/24eps/r50_n1800_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh b/configs/two_stage/deformable-detr-baseline/24eps/r50_n1800_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh new file mode 100644 index 0000000..42e3127 --- /dev/null +++ b/configs/two_stage/deformable-detr-baseline/24eps/r50_n1800_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh @@ -0,0 +1,18 @@ +#!/usr/bin/env bash + +set -x + +EXP_DIR=exps/two_stage/deformable-detr-baseline/24eps/r50_n1800_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage +PY_ARGS=${@:1} + +python -u main.py \ + --output_dir ${EXP_DIR} \ + --with_box_refine \ + --two_stage \ + --dim_feedforward 2048 \ + --num_queries_one2one 1800 \ + --num_queries_one2many 0 \ + --k_one2many 0 \ + --epochs 24 \ + --lr_drop 20 \ + ${PY_ARGS} diff --git a/configs/two_stage/deformable-detr-baseline/36eps/r50_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh b/configs/two_stage/deformable-detr-baseline/36eps/r50_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh new file mode 100644 index 0000000..415dc70 --- /dev/null +++ b/configs/two_stage/deformable-detr-baseline/36eps/r50_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh @@ -0,0 +1,18 @@ +#!/usr/bin/env bash + +set -x + +EXP_DIR=exps/two_stage/deformable-detr-baseline/36eps/r50_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage +PY_ARGS=${@:1} + +python -u main.py \ + --output_dir ${EXP_DIR} \ + --with_box_refine \ + --two_stage \ + --dim_feedforward 2048 \ + --num_queries_one2one 300 \ + --num_queries_one2many 0 \ + --k_one2many 0 \ + --epochs 36 \ + --lr_drop 30 \ + ${PY_ARGS} diff --git a/configs/two_stage/deformable-detr-baseline/36eps/r50_n1800_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh b/configs/two_stage/deformable-detr-baseline/36eps/r50_n1800_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh new file mode 100644 index 0000000..8d5892f --- /dev/null +++ b/configs/two_stage/deformable-detr-baseline/36eps/r50_n1800_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh @@ -0,0 +1,18 @@ +#!/usr/bin/env bash + +set -x + +EXP_DIR=exps/two_stage/deformable-detr-baseline/36eps/r50_n1800_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage +PY_ARGS=${@:1} + +python -u main.py \ + --output_dir ${EXP_DIR} \ + --with_box_refine \ + --two_stage \ + --dim_feedforward 2048 \ + --num_queries_one2one 1800 \ + --num_queries_one2many 0 \ + --k_one2many 0 \ + --epochs 36 \ + --lr_drop 30 \ + ${PY_ARGS} diff --git a/configs/two_stage/deformable-detr-baseline/50eps/r50_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh b/configs/two_stage/deformable-detr-baseline/50eps/r50_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh new file mode 100644 index 0000000..b8d9711 --- /dev/null +++ b/configs/two_stage/deformable-detr-baseline/50eps/r50_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh @@ -0,0 +1,18 @@ +#!/usr/bin/env bash + +set -x + +EXP_DIR=exps/two_stage/deformable-detr-baseline/50eps/r50_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage +PY_ARGS=${@:1} + +python -u main.py \ + --output_dir ${EXP_DIR} \ + --with_box_refine \ + --two_stage \ + --dim_feedforward 2048 \ + --num_queries_one2one 300 \ + --num_queries_one2many 0 \ + --k_one2many 0 \ + --epochs 50 \ + --lr_drop 40 \ + ${PY_ARGS} diff --git a/configs/two_stage/deformable-detr-baseline/50eps/r50_n1800_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh b/configs/two_stage/deformable-detr-baseline/50eps/r50_n1800_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh new file mode 100644 index 0000000..d9f4cbd --- /dev/null +++ b/configs/two_stage/deformable-detr-baseline/50eps/r50_n1800_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh @@ -0,0 +1,18 @@ +#!/usr/bin/env bash + +set -x + +EXP_DIR=exps/two_stage/deformable-detr-baseline/50eps/r50_n1800_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage +PY_ARGS=${@:1} + +python -u main.py \ + --output_dir ${EXP_DIR} \ + --with_box_refine \ + --two_stage \ + --dim_feedforward 2048 \ + --num_queries_one2one 1800 \ + --num_queries_one2many 0 \ + --k_one2many 0 \ + --epochs 50 \ + --lr_drop 40 \ + ${PY_ARGS} diff --git a/configs/two_stage/deformable-detr-hybrid-branch/12eps/r101_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh b/configs/two_stage/deformable-detr-hybrid-branch/12eps/r101_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh new file mode 100644 index 0000000..3f0f7ab --- /dev/null +++ b/configs/two_stage/deformable-detr-hybrid-branch/12eps/r101_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh @@ -0,0 +1,23 @@ +#!/usr/bin/env bash + +set -x + +EXP_DIR=exps/two_stage/deformable-detr-hybrid-branch/12eps/r101_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage +PY_ARGS=${@:1} + +python -u main.py \ + --output_dir ${EXP_DIR} \ + --with_box_refine \ + --two_stage \ + --dim_feedforward 2048 \ + --epochs 12 \ + --lr_drop 11 \ + --num_queries_one2one 300 \ + --num_queries_one2many 1500 \ + --k_one2many 6 \ + --lambda_one2many 1.0 \ + --dropout 0.0 \ + --mixed_selection \ + --look_forward_twice \ + --backbone resnet101 \ + ${PY_ARGS} diff --git a/configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda0.1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh b/configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda0.1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh new file mode 100644 index 0000000..36889e3 --- /dev/null +++ b/configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda0.1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash + +set -x + +EXP_DIR=exps/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda0.1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage +PY_ARGS=${@:1} + +python -u main.py \ + --output_dir ${EXP_DIR} \ + --with_box_refine \ + --two_stage \ + --dim_feedforward 2048 \ + --epochs 12 \ + --lr_drop 11 \ + --num_queries_one2one 300 \ + --num_queries_one2many 1500 \ + --k_one2many 6 \ + --lambda_one2many 0.1 \ + --dropout 0.0 \ + --mixed_selection \ + --look_forward_twice \ + ${PY_ARGS} diff --git a/configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda0.2_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh b/configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda0.2_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh new file mode 100644 index 0000000..04df85d --- /dev/null +++ b/configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda0.2_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash + +set -x + +EXP_DIR=exps/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda0.2_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage +PY_ARGS=${@:1} + +python -u main.py \ + --output_dir ${EXP_DIR} \ + --with_box_refine \ + --two_stage \ + --dim_feedforward 2048 \ + --epochs 12 \ + --lr_drop 11 \ + --num_queries_one2one 300 \ + --num_queries_one2many 1500 \ + --k_one2many 6 \ + --lambda_one2many 0.2 \ + --dropout 0.0 \ + --mixed_selection \ + --look_forward_twice \ + ${PY_ARGS} diff --git a/configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda0.5_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh b/configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda0.5_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh new file mode 100644 index 0000000..d051752 --- /dev/null +++ b/configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda0.5_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash + +set -x + +EXP_DIR=exps/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda0.5_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage +PY_ARGS=${@:1} + +python -u main.py \ + --output_dir ${EXP_DIR} \ + --with_box_refine \ + --two_stage \ + --dim_feedforward 2048 \ + --epochs 12 \ + --lr_drop 11 \ + --num_queries_one2one 300 \ + --num_queries_one2many 1500 \ + --k_one2many 6 \ + --lambda_one2many 0.5 \ + --dropout 0.0 \ + --mixed_selection \ + --look_forward_twice \ + ${PY_ARGS} diff --git a/configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group1_t300_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh b/configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group1_t300_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh new file mode 100644 index 0000000..79ac204 --- /dev/null +++ b/configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group1_t300_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash + +set -x + +EXP_DIR=exps/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group1_t300_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage +PY_ARGS=${@:1} + +python -u main.py \ + --output_dir ${EXP_DIR} \ + --with_box_refine \ + --two_stage \ + --dim_feedforward 2048 \ + --epochs 12 \ + --lr_drop 11 \ + --num_queries_one2one 300 \ + --num_queries_one2many 300 \ + --k_one2many 1 \ + --lambda_one2many 1.0 \ + --dropout 0.0 \ + --mixed_selection \ + --look_forward_twice \ + ${PY_ARGS} diff --git a/configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group2_t600_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh b/configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group2_t600_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh new file mode 100644 index 0000000..8795ef9 --- /dev/null +++ b/configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group2_t600_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash + +set -x + +EXP_DIR=exps/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group2_t600_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage +PY_ARGS=${@:1} + +python -u main.py \ + --output_dir ${EXP_DIR} \ + --with_box_refine \ + --two_stage \ + --dim_feedforward 2048 \ + --epochs 12 \ + --lr_drop 11 \ + --num_queries_one2one 300 \ + --num_queries_one2many 600 \ + --k_one2many 2 \ + --lambda_one2many 1.0 \ + --dropout 0.0 \ + --mixed_selection \ + --look_forward_twice \ + ${PY_ARGS} diff --git a/configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group3_t900_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh b/configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group3_t900_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh new file mode 100644 index 0000000..c73abf6 --- /dev/null +++ b/configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group3_t900_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash + +set -x + +EXP_DIR=exps/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group3_t900_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage +PY_ARGS=${@:1} + +python -u main.py \ + --output_dir ${EXP_DIR} \ + --with_box_refine \ + --two_stage \ + --dim_feedforward 2048 \ + --epochs 12 \ + --lr_drop 11 \ + --num_queries_one2one 300 \ + --num_queries_one2many 900 \ + --k_one2many 3 \ + --lambda_one2many 1.0 \ + --dropout 0.0 \ + --mixed_selection \ + --look_forward_twice \ + ${PY_ARGS} diff --git a/configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group4_t1200_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh b/configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group4_t1200_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh new file mode 100644 index 0000000..5e42904 --- /dev/null +++ b/configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group4_t1200_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash + +set -x + +EXP_DIR=exps/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group4_t1200_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage +PY_ARGS=${@:1} + +python -u main.py \ + --output_dir ${EXP_DIR} \ + --with_box_refine \ + --two_stage \ + --dim_feedforward 2048 \ + --epochs 12 \ + --lr_drop 11 \ + --num_queries_one2one 300 \ + --num_queries_one2many 1200 \ + --k_one2many 4 \ + --lambda_one2many 1.0 \ + --dropout 0.0 \ + --mixed_selection \ + --look_forward_twice \ + ${PY_ARGS} diff --git a/configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group5_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh b/configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group5_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh new file mode 100644 index 0000000..ac6a554 --- /dev/null +++ b/configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group5_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash + +set -x + +EXP_DIR=exps/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group5_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage +PY_ARGS=${@:1} + +python -u main.py \ + --output_dir ${EXP_DIR} \ + --with_box_refine \ + --two_stage \ + --dim_feedforward 2048 \ + --epochs 12 \ + --lr_drop 11 \ + --num_queries_one2one 300 \ + --num_queries_one2many 1500 \ + --k_one2many 5 \ + --lambda_one2many 1.0 \ + --dropout 0.0 \ + --mixed_selection \ + --look_forward_twice \ + ${PY_ARGS} diff --git a/configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group6_t1200_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh b/configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group6_t1200_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh new file mode 100644 index 0000000..e57cc93 --- /dev/null +++ b/configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group6_t1200_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash + +set -x + +EXP_DIR=exps/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group6_t1200_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage +PY_ARGS=${@:1} + +python -u main.py \ + --output_dir ${EXP_DIR} \ + --with_box_refine \ + --two_stage \ + --dim_feedforward 2048 \ + --epochs 12 \ + --lr_drop 11 \ + --num_queries_one2one 300 \ + --num_queries_one2many 1200 \ + --k_one2many 6 \ + --lambda_one2many 1.0 \ + --dropout 0.0 \ + --mixed_selection \ + --look_forward_twice \ + ${PY_ARGS} diff --git a/configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group6_t1500_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh b/configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group6_t1500_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh new file mode 100644 index 0000000..c2ce2c6 --- /dev/null +++ b/configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group6_t1500_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh @@ -0,0 +1,19 @@ +#!/usr/bin/env bash + +set -x + +EXP_DIR=exps/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group6_t1500_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage +PY_ARGS=${@:1} + +python -u main.py \ + --output_dir ${EXP_DIR} \ + --with_box_refine \ + --two_stage \ + --dim_feedforward 2048 \ + --epochs 12 \ + --lr_drop 11 \ + --num_queries_one2one 300 \ + --num_queries_one2many 1500 \ + --k_one2many 6 \ + --lambda_one2many 1.0 \ + ${PY_ARGS} diff --git a/configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh b/configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh new file mode 100644 index 0000000..0e805c0 --- /dev/null +++ b/configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash + +set -x + +EXP_DIR=exps/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage +PY_ARGS=${@:1} + +python -u main.py \ + --output_dir ${EXP_DIR} \ + --with_box_refine \ + --two_stage \ + --dim_feedforward 2048 \ + --epochs 12 \ + --lr_drop 11 \ + --num_queries_one2one 300 \ + --num_queries_one2many 1500 \ + --k_one2many 6 \ + --lambda_one2many 1.0 \ + --dropout 0.0 \ + --mixed_selection \ + --look_forward_twice \ + ${PY_ARGS} diff --git a/configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group6_t1800_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh b/configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group6_t1800_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh new file mode 100644 index 0000000..ff40542 --- /dev/null +++ b/configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group6_t1800_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash + +set -x + +EXP_DIR=exps/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group6_t1800_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage +PY_ARGS=${@:1} + +python -u main.py \ + --output_dir ${EXP_DIR} \ + --with_box_refine \ + --two_stage \ + --dim_feedforward 2048 \ + --epochs 12 \ + --lr_drop 11 \ + --num_queries_one2one 300 \ + --num_queries_one2many 1800 \ + --k_one2many 6 \ + --lambda_one2many 1.0 \ + --dropout 0.0 \ + --mixed_selection \ + --look_forward_twice \ + ${PY_ARGS} diff --git a/configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group6_t300_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh b/configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group6_t300_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh new file mode 100644 index 0000000..d786c6a --- /dev/null +++ b/configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group6_t300_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash + +set -x + +EXP_DIR=exps/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group6_t300_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage +PY_ARGS=${@:1} + +python -u main.py \ + --output_dir ${EXP_DIR} \ + --with_box_refine \ + --two_stage \ + --dim_feedforward 2048 \ + --epochs 12 \ + --lr_drop 11 \ + --num_queries_one2one 300 \ + --num_queries_one2many 300 \ + --k_one2many 6 \ + --lambda_one2many 1.0 \ + --dropout 0.0 \ + --mixed_selection \ + --look_forward_twice \ + ${PY_ARGS} diff --git a/configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group6_t600_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh b/configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group6_t600_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh new file mode 100644 index 0000000..6dce16a --- /dev/null +++ b/configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group6_t600_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash + +set -x + +EXP_DIR=exps/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group6_t600_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage +PY_ARGS=${@:1} + +python -u main.py \ + --output_dir ${EXP_DIR} \ + --with_box_refine \ + --two_stage \ + --dim_feedforward 2048 \ + --epochs 12 \ + --lr_drop 11 \ + --num_queries_one2one 300 \ + --num_queries_one2many 600 \ + --k_one2many 6 \ + --lambda_one2many 1.0 \ + --dropout 0.0 \ + --mixed_selection \ + --look_forward_twice \ + ${PY_ARGS} diff --git a/configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group6_t900_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh b/configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group6_t900_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh new file mode 100644 index 0000000..28aadb9 --- /dev/null +++ b/configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group6_t900_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash + +set -x + +EXP_DIR=exps/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda1_group6_t900_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage +PY_ARGS=${@:1} + +python -u main.py \ + --output_dir ${EXP_DIR} \ + --with_box_refine \ + --two_stage \ + --dim_feedforward 2048 \ + --epochs 12 \ + --lr_drop 11 \ + --num_queries_one2one 300 \ + --num_queries_one2many 900 \ + --k_one2many 6 \ + --lambda_one2many 1.0 \ + --dropout 0.0 \ + --mixed_selection \ + --look_forward_twice \ + ${PY_ARGS} diff --git a/configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda2_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh b/configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda2_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh new file mode 100644 index 0000000..edd3e9c --- /dev/null +++ b/configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda2_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash + +set -x + +EXP_DIR=exps/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda2_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage +PY_ARGS=${@:1} + +python -u main.py \ + --output_dir ${EXP_DIR} \ + --with_box_refine \ + --two_stage \ + --dim_feedforward 2048 \ + --epochs 12 \ + --lr_drop 11 \ + --num_queries_one2one 300 \ + --num_queries_one2many 1500 \ + --k_one2many 6 \ + --lambda_one2many 2.0 \ + --dropout 0.0 \ + --mixed_selection \ + --look_forward_twice \ + ${PY_ARGS} diff --git a/configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda5_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh b/configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda5_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh new file mode 100644 index 0000000..151ce5a --- /dev/null +++ b/configs/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda5_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash + +set -x + +EXP_DIR=exps/two_stage/deformable-detr-hybrid-branch/12eps/r50_hybrid_branch_lambda5_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage +PY_ARGS=${@:1} + +python -u main.py \ + --output_dir ${EXP_DIR} \ + --with_box_refine \ + --two_stage \ + --dim_feedforward 2048 \ + --epochs 12 \ + --lr_drop 11 \ + --num_queries_one2one 300 \ + --num_queries_one2many 1500 \ + --k_one2many 6 \ + --lambda_one2many 5.0 \ + --dropout 0.0 \ + --mixed_selection \ + --look_forward_twice \ + ${PY_ARGS} diff --git a/configs/two_stage/deformable-detr-hybrid-branch/12eps/swin/swin_large_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh b/configs/two_stage/deformable-detr-hybrid-branch/12eps/swin/swin_large_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh new file mode 100644 index 0000000..e75f70c --- /dev/null +++ b/configs/two_stage/deformable-detr-hybrid-branch/12eps/swin/swin_large_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh @@ -0,0 +1,24 @@ +#!/usr/bin/env bash + +set -x + +EXP_DIR=exps/two_stage/deformable-detr-hybrid-branch/12eps/swin/swin_large_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage +PY_ARGS=${@:1} + +python -u main.py \ + --output_dir ${EXP_DIR} \ + --with_box_refine \ + --two_stage \ + --dim_feedforward 2048 \ + --epochs 12 \ + --lr_drop 11 \ + --num_queries_one2one 300 \ + --num_queries_one2many 1500 \ + --k_one2many 6 \ + --lambda_one2many 1.0 \ + --dropout 0.0 \ + --mixed_selection \ + --look_forward_twice \ + --backbone swin_large \ + --pretrained_backbone_path /mnt/pretrained_backbone/swin_large_patch4_window7_224_22k.pth \ + ${PY_ARGS} diff --git a/configs/two_stage/deformable-detr-hybrid-branch/12eps/swin/swin_large_hybrid_branch_lambda1_group6_t1500_n900_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh b/configs/two_stage/deformable-detr-hybrid-branch/12eps/swin/swin_large_hybrid_branch_lambda1_group6_t1500_n900_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh new file mode 100644 index 0000000..2f47008 --- /dev/null +++ b/configs/two_stage/deformable-detr-hybrid-branch/12eps/swin/swin_large_hybrid_branch_lambda1_group6_t1500_n900_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh @@ -0,0 +1,24 @@ +#!/usr/bin/env bash + +set -x + +EXP_DIR=exps/two_stage/deformable-detr-hybrid-branch/12eps/swin/swin_large_hybrid_branch_lambda1_group6_t1500_n900_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage +PY_ARGS=${@:1} + +python -u main.py \ + --output_dir ${EXP_DIR} \ + --with_box_refine \ + --two_stage \ + --dim_feedforward 2048 \ + --epochs 12 \ + --lr_drop 11 \ + --num_queries_one2one 900 \ + --num_queries_one2many 1500 \ + --k_one2many 6 \ + --lambda_one2many 1.0 \ + --dropout 0.0 \ + --mixed_selection \ + --look_forward_twice \ + --backbone swin_large \ + --pretrained_backbone_path /mnt/pretrained_backbone/swin_large_patch4_window7_224_22k.pth \ + ${PY_ARGS} diff --git a/configs/two_stage/deformable-detr-hybrid-branch/12eps/swin/swin_small_22k_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh b/configs/two_stage/deformable-detr-hybrid-branch/12eps/swin/swin_small_22k_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh new file mode 100644 index 0000000..65aac09 --- /dev/null +++ b/configs/two_stage/deformable-detr-hybrid-branch/12eps/swin/swin_small_22k_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh @@ -0,0 +1,24 @@ +#!/usr/bin/env bash + +set -x + +EXP_DIR=exps/two_stage/deformable-detr-hybrid-branch/12eps/swin/swin_small_22k_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage +PY_ARGS=${@:1} + +python -u main.py \ + --output_dir ${EXP_DIR} \ + --with_box_refine \ + --two_stage \ + --dim_feedforward 2048 \ + --epochs 12 \ + --lr_drop 11 \ + --num_queries_one2one 300 \ + --num_queries_one2many 1500 \ + --k_one2many 6 \ + --lambda_one2many 1.0 \ + --dropout 0.0 \ + --mixed_selection \ + --look_forward_twice \ + --backbone swin_small \ + --pretrained_backbone_path /mnt/pretrained_backbone/swin_small_patch4_window7_224_22k.pth \ + ${PY_ARGS} diff --git a/configs/two_stage/deformable-detr-hybrid-branch/12eps/swin/swin_small_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh b/configs/two_stage/deformable-detr-hybrid-branch/12eps/swin/swin_small_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh new file mode 100644 index 0000000..7730734 --- /dev/null +++ b/configs/two_stage/deformable-detr-hybrid-branch/12eps/swin/swin_small_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh @@ -0,0 +1,23 @@ +#!/usr/bin/env bash + +set -x + +EXP_DIR=exps/two_stage/deformable-detr-hybrid-branch/12eps/swin/swin_small_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage + +python -u main.py \ + --output_dir ${EXP_DIR} \ + --with_box_refine \ + --two_stage \ + --dim_feedforward 2048 \ + --epochs 12 \ + --lr_drop 11 \ + --num_queries_one2one 300 \ + --num_queries_one2many 1500 \ + --k_one2many 6 \ + --lambda_one2many 1.0 \ + --dropout 0.0 \ + --mixed_selection \ + --look_forward_twice \ + --backbone swin_small \ + --pretrained_backbone_path /mnt/pretrained_backbone/swin_small_patch4_window7_224.pth \ + ${PY_ARGS} diff --git a/configs/two_stage/deformable-detr-hybrid-branch/12eps/swin/swin_tiny_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh b/configs/two_stage/deformable-detr-hybrid-branch/12eps/swin/swin_tiny_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh new file mode 100644 index 0000000..e4a96dc --- /dev/null +++ b/configs/two_stage/deformable-detr-hybrid-branch/12eps/swin/swin_tiny_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh @@ -0,0 +1,24 @@ +#!/usr/bin/env bash + +set -x + +EXP_DIR=exps/two_stage/deformable-detr-hybrid-branch/12eps/swin/swin_tiny_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage +PY_ARGS=${@:1} + +python -u main.py \ + --output_dir ${EXP_DIR} \ + --with_box_refine \ + --two_stage \ + --dim_feedforward 2048 \ + --epochs 12 \ + --lr_drop 11 \ + --num_queries_one2one 300 \ + --num_queries_one2many 1500 \ + --k_one2many 6 \ + --lambda_one2many 1.0 \ + --dropout 0.0 \ + --mixed_selection \ + --look_forward_twice \ + --backbone swin_tiny \ + --pretrained_backbone_path /mnt/pretrained_backbone/swin_tiny_patch4_window7_224.pth \ + ${PY_ARGS} diff --git a/configs/two_stage/deformable-detr-hybrid-branch/24eps/r50_hybrid_branch_lambda1_group6_t1500_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh b/configs/two_stage/deformable-detr-hybrid-branch/24eps/r50_hybrid_branch_lambda1_group6_t1500_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh new file mode 100644 index 0000000..a4e2327 --- /dev/null +++ b/configs/two_stage/deformable-detr-hybrid-branch/24eps/r50_hybrid_branch_lambda1_group6_t1500_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh @@ -0,0 +1,19 @@ +#!/usr/bin/env bash + +set -x + +EXP_DIR=exps/two_stage/deformable-detr-hybrid-branch/24eps/r50_hybrid_branch_lambda1_group6_t1500_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage +PY_ARGS=${@:1} + +python -u main.py \ + --output_dir ${EXP_DIR} \ + --with_box_refine \ + --two_stage \ + --dim_feedforward 2048 \ + --epochs 24 \ + --lr_drop 20 \ + --num_queries_one2one 300 \ + --num_queries_one2many 1500 \ + --k_one2many 6 \ + --lambda_one2many 1.0 \ + ${PY_ARGS} diff --git a/configs/two_stage/deformable-detr-hybrid-branch/24eps/r50_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh b/configs/two_stage/deformable-detr-hybrid-branch/24eps/r50_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh new file mode 100644 index 0000000..00484c7 --- /dev/null +++ b/configs/two_stage/deformable-detr-hybrid-branch/24eps/r50_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash + +set -x + +EXP_DIR=exps/two_stage/deformable-detr-hybrid-branch/24eps/r50_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh +PY_ARGS=${@:1} + +python -u main.py \ + --output_dir ${EXP_DIR} \ + --with_box_refine \ + --two_stage \ + --dim_feedforward 2048 \ + --epochs 24 \ + --lr_drop 20 \ + --num_queries_one2one 300 \ + --num_queries_one2many 1500 \ + --k_one2many 6 \ + --lambda_one2many 1.0 \ + --dropout 0.0 \ + --mixed_selection \ + --look_forward_twice \ + ${PY_ARGS} diff --git a/configs/two_stage/deformable-detr-hybrid-branch/36eps/r101_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh b/configs/two_stage/deformable-detr-hybrid-branch/36eps/r101_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh new file mode 100644 index 0000000..64d795a --- /dev/null +++ b/configs/two_stage/deformable-detr-hybrid-branch/36eps/r101_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh @@ -0,0 +1,23 @@ +#!/usr/bin/env bash + +set -x + +EXP_DIR=exps/configs/two_stage/deformable-detr-hybrid-branch/36eps/r101_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage +PY_ARGS=${@:1} + +python -u main.py \ + --output_dir ${EXP_DIR} \ + --with_box_refine \ + --two_stage \ + --dim_feedforward 2048 \ + --epochs 36 \ + --lr_drop 30 \ + --num_queries_one2one 300 \ + --num_queries_one2many 1500 \ + --k_one2many 6 \ + --lambda_one2many 1.0 \ + --dropout 0.0 \ + --mixed_selection \ + --look_forward_twice \ + --backbone resnet101 \ + ${PY_ARGS} diff --git a/configs/two_stage/deformable-detr-hybrid-branch/36eps/r50_hybrid_branch_lambda1_group6_t1500_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh b/configs/two_stage/deformable-detr-hybrid-branch/36eps/r50_hybrid_branch_lambda1_group6_t1500_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh new file mode 100644 index 0000000..132bb95 --- /dev/null +++ b/configs/two_stage/deformable-detr-hybrid-branch/36eps/r50_hybrid_branch_lambda1_group6_t1500_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh @@ -0,0 +1,19 @@ +#!/usr/bin/env bash + +set -x + +EXP_DIR=exps/two_stage/deformable-detr-hybrid-branch/24eps/r50_hybrid_branch_lambda1_group6_t1500_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage +PY_ARGS=${@:1} + +python -u main.py \ + --output_dir ${EXP_DIR} \ + --with_box_refine \ + --two_stage \ + --dim_feedforward 2048 \ + --epochs 36 \ + --lr_drop 30 \ + --num_queries_one2one 300 \ + --num_queries_one2many 1500 \ + --k_one2many 6 \ + --lambda_one2many 1.0 \ + ${PY_ARGS} diff --git a/configs/two_stage/deformable-detr-hybrid-branch/36eps/r50_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh b/configs/two_stage/deformable-detr-hybrid-branch/36eps/r50_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh new file mode 100644 index 0000000..84c9867 --- /dev/null +++ b/configs/two_stage/deformable-detr-hybrid-branch/36eps/r50_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash + +set -x + +EXP_DIR=exps/two_stage/deformable-detr-hybrid-branch/36eps/r50_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage +PY_ARGS=${@:1} + +python -u main.py \ + --output_dir ${EXP_DIR} \ + --with_box_refine \ + --two_stage \ + --dim_feedforward 2048 \ + --epochs 36 \ + --lr_drop 30 \ + --num_queries_one2one 300 \ + --num_queries_one2many 1500 \ + --k_one2many 6 \ + --lambda_one2many 1.0 \ + --dropout 0.0 \ + --mixed_selection \ + --look_forward_twice \ + ${PY_ARGS} diff --git a/configs/two_stage/deformable-detr-hybrid-branch/36eps/swin/drop_path0.5_swin_large_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh b/configs/two_stage/deformable-detr-hybrid-branch/36eps/swin/drop_path0.5_swin_large_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh new file mode 100644 index 0000000..85fc74e --- /dev/null +++ b/configs/two_stage/deformable-detr-hybrid-branch/36eps/swin/drop_path0.5_swin_large_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh @@ -0,0 +1,25 @@ +#!/usr/bin/env bash + +set -x + +EXP_DIR=exps/two_stage/deformable-detr-hybrid-branch/36eps/swin/drop_path0.5_swin_large_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage +PY_ARGS=${@:1} + +python -u main.py \ + --output_dir ${EXP_DIR} \ + --with_box_refine \ + --two_stage \ + --dim_feedforward 2048 \ + --epochs 36 \ + --lr_drop 30 \ + --num_queries_one2one 300 \ + --num_queries_one2many 1500 \ + --k_one2many 6 \ + --lambda_one2many 1.0 \ + --dropout 0.0 \ + --mixed_selection \ + --look_forward_twice \ + --backbone swin_large \ + --pretrained_backbone_path /mnt/pretrained_backbone/swin_large_patch4_window7_224_22k.pth \ + --drop_path_rate 0.5 \ + ${PY_ARGS} diff --git a/configs/two_stage/deformable-detr-hybrid-branch/36eps/swin/drop_path0.5_swin_large_hybrid_branch_lambda1_group6_t1500_n900_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh b/configs/two_stage/deformable-detr-hybrid-branch/36eps/swin/drop_path0.5_swin_large_hybrid_branch_lambda1_group6_t1500_n900_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh new file mode 100644 index 0000000..51c57f2 --- /dev/null +++ b/configs/two_stage/deformable-detr-hybrid-branch/36eps/swin/drop_path0.5_swin_large_hybrid_branch_lambda1_group6_t1500_n900_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh @@ -0,0 +1,25 @@ +#!/usr/bin/env bash + +set -x + +EXP_DIR=exps/two_stage/deformable-detr-hybrid-branch/36eps/swin/drop_path0.5_swin_large_hybrid_branch_lambda1_group6_t1500_n900_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage +PY_ARGS=${@:1} + +python -u main.py \ + --output_dir ${EXP_DIR} \ + --with_box_refine \ + --two_stage \ + --dim_feedforward 2048 \ + --epochs 36 \ + --lr_drop 30 \ + --num_queries_one2one 900 \ + --num_queries_one2many 1500 \ + --k_one2many 6 \ + --lambda_one2many 1.0 \ + --dropout 0.0 \ + --mixed_selection \ + --look_forward_twice \ + --backbone swin_large \ + --pretrained_backbone_path /mnt/pretrained_backbone/swin_large_patch4_window7_224_22k.pth \ + --drop_path_rate 0.5 \ + ${PY_ARGS} diff --git a/configs/two_stage/deformable-detr-hybrid-branch/36eps/swin/swin_small_22k_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh b/configs/two_stage/deformable-detr-hybrid-branch/36eps/swin/swin_small_22k_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh new file mode 100644 index 0000000..ff82e8b --- /dev/null +++ b/configs/two_stage/deformable-detr-hybrid-branch/36eps/swin/swin_small_22k_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh @@ -0,0 +1,24 @@ +#!/usr/bin/env bash + +set -x + +EXP_DIR=exps/two_stage/deformable-detr-hybrid-branch/36eps/swin/swin_small_22k_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage +PY_ARGS=${@:1} + +python -u main.py \ + --output_dir ${EXP_DIR} \ + --with_box_refine \ + --two_stage \ + --dim_feedforward 2048 \ + --epochs 36 \ + --lr_drop 30 \ + --num_queries_one2one 300 \ + --num_queries_one2many 1500 \ + --k_one2many 6 \ + --lambda_one2many 1.0 \ + --dropout 0.0 \ + --mixed_selection \ + --look_forward_twice \ + --backbone swin_small \ + --pretrained_backbone_path /mnt/pretrained_backbone/swin_small_patch4_window7_224_22k.pth \ + ${PY_ARGS} diff --git a/configs/two_stage/deformable-detr-hybrid-branch/36eps/swin/swin_small_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh b/configs/two_stage/deformable-detr-hybrid-branch/36eps/swin/swin_small_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh new file mode 100644 index 0000000..d5574c0 --- /dev/null +++ b/configs/two_stage/deformable-detr-hybrid-branch/36eps/swin/swin_small_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh @@ -0,0 +1,23 @@ +#!/usr/bin/env bash + +set -x + +EXP_DIR=exps/two_stage/deformable-detr-hybrid-branch/36eps/swin/swin_small_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage + +python -u main.py \ + --output_dir ${EXP_DIR} \ + --with_box_refine \ + --two_stage \ + --dim_feedforward 2048 \ + --epochs 36 \ + --lr_drop 30 \ + --num_queries_one2one 300 \ + --num_queries_one2many 1500 \ + --k_one2many 6 \ + --lambda_one2many 1.0 \ + --dropout 0.0 \ + --mixed_selection \ + --look_forward_twice \ + --backbone swin_small \ + --pretrained_backbone_path /mnt/pretrained_backbone/swin_small_patch4_window7_224.pth \ + ${PY_ARGS} diff --git a/configs/two_stage/deformable-detr-hybrid-branch/36eps/swin/swin_tiny_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh b/configs/two_stage/deformable-detr-hybrid-branch/36eps/swin/swin_tiny_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh new file mode 100644 index 0000000..635cb82 --- /dev/null +++ b/configs/two_stage/deformable-detr-hybrid-branch/36eps/swin/swin_tiny_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh @@ -0,0 +1,24 @@ +#!/usr/bin/env bash + +set -x + +EXP_DIR=exps/two_stage/deformable-detr-hybrid-branch/36eps/swin/swin_tiny_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage +PY_ARGS=${@:1} + +python -u main.py \ + --output_dir ${EXP_DIR} \ + --with_box_refine \ + --two_stage \ + --dim_feedforward 2048 \ + --epochs 36 \ + --lr_drop 30 \ + --num_queries_one2one 300 \ + --num_queries_one2many 1500 \ + --k_one2many 6 \ + --lambda_one2many 1.0 \ + --dropout 0.0 \ + --mixed_selection \ + --look_forward_twice \ + --backbone swin_tiny \ + --pretrained_backbone_path /mnt/pretrained_backbone/swin_tiny_patch4_window7_224.pth \ + ${PY_ARGS} diff --git a/datasets/__init__.py b/datasets/__init__.py new file mode 100644 index 0000000..c5ac24b --- /dev/null +++ b/datasets/__init__.py @@ -0,0 +1,34 @@ +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +import torch.utils.data +from .torchvision_datasets import CocoDetection + +from .coco import build as build_coco + + +def get_coco_api_from_dataset(dataset): + for _ in range(10): + # if isinstance(dataset, torchvision.datasets.CocoDetection): + # break + if isinstance(dataset, torch.utils.data.Subset): + dataset = dataset.dataset + if isinstance(dataset, CocoDetection): + return dataset.coco + + +def build_dataset(image_set, args, eval_in_training_set=False): + if args.dataset_file == "coco": + return build_coco(image_set, args, eval_in_training_set) + if args.dataset_file == "coco_panoptic": + # to avoid making panopticapi required for coco + from .coco_panoptic import build as build_coco_panoptic + + return build_coco_panoptic(image_set, args) + raise ValueError(f"dataset {args.dataset_file} not supported") diff --git a/datasets/coco.py b/datasets/coco.py new file mode 100644 index 0000000..c7fb7a1 --- /dev/null +++ b/datasets/coco.py @@ -0,0 +1,195 @@ +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +""" +COCO dataset which returns image_id for evaluation. + +Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py +""" +from pathlib import Path + +import torch +import torch.utils.data +from pycocotools import mask as coco_mask + +from .torchvision_datasets import CocoDetection as TvCocoDetection +from util.misc import get_local_rank, get_local_size +import datasets.transforms as T + + +class CocoDetection(TvCocoDetection): + def __init__( + self, + img_folder, + ann_file, + transforms, + return_masks, + cache_mode=False, + local_rank=0, + local_size=1, + ): + super(CocoDetection, self).__init__( + img_folder, + ann_file, + cache_mode=cache_mode, + local_rank=local_rank, + local_size=local_size, + ) + self._transforms = transforms + self.prepare = ConvertCocoPolysToMask(return_masks) + + def __getitem__(self, idx): + img, target = super(CocoDetection, self).__getitem__(idx) + image_id = self.ids[idx] + target = {"image_id": image_id, "annotations": target} + img, target = self.prepare(img, target) + if self._transforms is not None: + img, target = self._transforms(img, target) + return img, target + + +def convert_coco_poly_to_mask(segmentations, height, width): + masks = [] + for polygons in segmentations: + rles = coco_mask.frPyObjects(polygons, height, width) + mask = coco_mask.decode(rles) + if len(mask.shape) < 3: + mask = mask[..., None] + mask = torch.as_tensor(mask, dtype=torch.uint8) + mask = mask.any(dim=2) + masks.append(mask) + if masks: + masks = torch.stack(masks, dim=0) + else: + masks = torch.zeros((0, height, width), dtype=torch.uint8) + return masks + + +class ConvertCocoPolysToMask(object): + def __init__(self, return_masks=False): + self.return_masks = return_masks + + def __call__(self, image, target): + w, h = image.size + + image_id = target["image_id"] + image_id = torch.tensor([image_id]) + + anno = target["annotations"] + + anno = [obj for obj in anno if "iscrowd" not in obj or obj["iscrowd"] == 0] + + boxes = [obj["bbox"] for obj in anno] + # guard against no boxes via resizing + boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4) + boxes[:, 2:] += boxes[:, :2] + boxes[:, 0::2].clamp_(min=0, max=w) + boxes[:, 1::2].clamp_(min=0, max=h) + + classes = [obj["category_id"] for obj in anno] + classes = torch.tensor(classes, dtype=torch.int64) + + if self.return_masks: + segmentations = [obj["segmentation"] for obj in anno] + masks = convert_coco_poly_to_mask(segmentations, h, w) + + keypoints = None + if anno and "keypoints" in anno[0]: + keypoints = [obj["keypoints"] for obj in anno] + keypoints = torch.as_tensor(keypoints, dtype=torch.float32) + num_keypoints = keypoints.shape[0] + if num_keypoints: + keypoints = keypoints.view(num_keypoints, -1, 3) + + keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) + boxes = boxes[keep] + classes = classes[keep] + if self.return_masks: + masks = masks[keep] + if keypoints is not None: + keypoints = keypoints[keep] + + target = {} + target["boxes"] = boxes + target["labels"] = classes + if self.return_masks: + target["masks"] = masks + target["image_id"] = image_id + if keypoints is not None: + target["keypoints"] = keypoints + + # for conversion to coco api + area = torch.tensor([obj["area"] for obj in anno]) + iscrowd = torch.tensor( + [obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno] + ) + target["area"] = area[keep] + target["iscrowd"] = iscrowd[keep] + + target["orig_size"] = torch.as_tensor([int(h), int(w)]) + target["size"] = torch.as_tensor([int(h), int(w)]) + + return image, target + + +def make_coco_transforms(image_set): + + normalize = T.Compose( + [T.ToTensor(), T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])] + ) + + scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800] + + if image_set == "train": + return T.Compose( + [ + T.RandomHorizontalFlip(), + T.RandomSelect( + T.RandomResize(scales, max_size=1333), + T.Compose( + [ + T.RandomResize([400, 500, 600]), + T.RandomSizeCrop(384, 600), + T.RandomResize(scales, max_size=1333), + ] + ), + ), + normalize, + ] + ) + + if image_set == "val": + return T.Compose([T.RandomResize([800], max_size=1333), normalize,]) + + raise ValueError(f"unknown {image_set}") + + +def build(image_set, args, eval_in_training_set): + root = Path(args.coco_path) + assert root.exists(), f"provided COCO path {root} does not exist" + mode = "instances" + PATHS = { + "train": (root / "train2017", root / "annotations" / f"{mode}_train2017.json"), + "val": (root / "val2017", root / "annotations" / f"{mode}_val2017.json"), + } + + img_folder, ann_file = PATHS[image_set] + if eval_in_training_set: + image_set = "val" + print("use validation dataset transforms") + dataset = CocoDetection( + img_folder, + ann_file, + transforms=make_coco_transforms(image_set), + return_masks=args.masks, + cache_mode=args.cache_mode, + local_rank=get_local_rank(), + local_size=get_local_size(), + ) + return dataset diff --git a/datasets/coco_eval.py b/datasets/coco_eval.py new file mode 100644 index 0000000..15fdc08 --- /dev/null +++ b/datasets/coco_eval.py @@ -0,0 +1,273 @@ +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +""" +COCO evaluator that works in distributed mode. + +Mostly copy-paste from https://github.com/pytorch/vision/blob/edfd5a7/references/detection/coco_eval.py +The difference is that there is less copy-pasting from pycocotools +in the end of the file, as python3 can suppress prints with contextlib +""" +import os +import contextlib +import copy +import numpy as np +import torch + +from pycocotools.cocoeval import COCOeval +from pycocotools.coco import COCO +import pycocotools.mask as mask_util + +from util.misc import all_gather + + +class CocoEvaluator(object): + def __init__(self, coco_gt, iou_types): + assert isinstance(iou_types, (list, tuple)) + coco_gt = copy.deepcopy(coco_gt) + self.coco_gt = coco_gt + + self.iou_types = iou_types + self.coco_eval = {} + for iou_type in iou_types: + self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type) + + self.img_ids = [] + self.eval_imgs = {k: [] for k in iou_types} + + def update(self, predictions): + img_ids = list(np.unique(list(predictions.keys()))) + self.img_ids.extend(img_ids) + + for iou_type in self.iou_types: + results = self.prepare(predictions, iou_type) + + # suppress pycocotools prints + with open(os.devnull, "w") as devnull: + with contextlib.redirect_stdout(devnull): + coco_dt = COCO.loadRes(self.coco_gt, results) if results else COCO() + coco_eval = self.coco_eval[iou_type] + + coco_eval.cocoDt = coco_dt + coco_eval.params.imgIds = list(img_ids) + img_ids, eval_imgs = evaluate(coco_eval) + + self.eval_imgs[iou_type].append(eval_imgs) + + def synchronize_between_processes(self): + for iou_type in self.iou_types: + self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2) + create_common_coco_eval( + self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type] + ) + + def accumulate(self): + for coco_eval in self.coco_eval.values(): + coco_eval.accumulate() + + def summarize(self): + for iou_type, coco_eval in self.coco_eval.items(): + print("IoU metric: {}".format(iou_type)) + coco_eval.summarize() + + def prepare(self, predictions, iou_type): + if iou_type == "bbox": + return self.prepare_for_coco_detection(predictions) + elif iou_type == "segm": + return self.prepare_for_coco_segmentation(predictions) + elif iou_type == "keypoints": + return self.prepare_for_coco_keypoint(predictions) + else: + raise ValueError("Unknown iou type {}".format(iou_type)) + + def prepare_for_coco_detection(self, predictions): + coco_results = [] + for original_id, prediction in predictions.items(): + if len(prediction) == 0: + continue + + boxes = prediction["boxes"] + boxes = convert_to_xywh(boxes).tolist() + scores = prediction["scores"].tolist() + labels = prediction["labels"].tolist() + + coco_results.extend( + [ + { + "image_id": original_id, + "category_id": labels[k], + "bbox": box, + "score": scores[k], + } + for k, box in enumerate(boxes) + ] + ) + return coco_results + + def prepare_for_coco_segmentation(self, predictions): + coco_results = [] + for original_id, prediction in predictions.items(): + if len(prediction) == 0: + continue + + scores = prediction["scores"] + labels = prediction["labels"] + masks = prediction["masks"] + + masks = masks > 0.5 + + scores = prediction["scores"].tolist() + labels = prediction["labels"].tolist() + + rles = [ + mask_util.encode( + np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F") + )[0] + for mask in masks + ] + for rle in rles: + rle["counts"] = rle["counts"].decode("utf-8") + + coco_results.extend( + [ + { + "image_id": original_id, + "category_id": labels[k], + "segmentation": rle, + "score": scores[k], + } + for k, rle in enumerate(rles) + ] + ) + return coco_results + + def prepare_for_coco_keypoint(self, predictions): + coco_results = [] + for original_id, prediction in predictions.items(): + if len(prediction) == 0: + continue + + boxes = prediction["boxes"] + boxes = convert_to_xywh(boxes).tolist() + scores = prediction["scores"].tolist() + labels = prediction["labels"].tolist() + keypoints = prediction["keypoints"] + keypoints = keypoints.flatten(start_dim=1).tolist() + + coco_results.extend( + [ + { + "image_id": original_id, + "category_id": labels[k], + "keypoints": keypoint, + "score": scores[k], + } + for k, keypoint in enumerate(keypoints) + ] + ) + return coco_results + + +def convert_to_xywh(boxes): + xmin, ymin, xmax, ymax = boxes.unbind(1) + return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1) + + +def merge(img_ids, eval_imgs): + all_img_ids = all_gather(img_ids) + all_eval_imgs = all_gather(eval_imgs) + + merged_img_ids = [] + for p in all_img_ids: + merged_img_ids.extend(p) + + merged_eval_imgs = [] + for p in all_eval_imgs: + merged_eval_imgs.append(p) + + merged_img_ids = np.array(merged_img_ids) + merged_eval_imgs = np.concatenate(merged_eval_imgs, 2) + + # keep only unique (and in sorted order) images + merged_img_ids, idx = np.unique(merged_img_ids, return_index=True) + merged_eval_imgs = merged_eval_imgs[..., idx] + + return merged_img_ids, merged_eval_imgs + + +def create_common_coco_eval(coco_eval, img_ids, eval_imgs): + img_ids, eval_imgs = merge(img_ids, eval_imgs) + img_ids = list(img_ids) + eval_imgs = list(eval_imgs.flatten()) + + coco_eval.evalImgs = eval_imgs + coco_eval.params.imgIds = img_ids + coco_eval._paramsEval = copy.deepcopy(coco_eval.params) + + +################################################################# +# From pycocotools, just removed the prints and fixed +# a Python3 bug about unicode not defined +################################################################# + + +def evaluate(self): + """ + Run per image evaluation on given images and store results (a list of dict) in self.evalImgs + :return: None + """ + # tic = time.time() + # print('Running per image evaluation...') + p = self.params + # add backward compatibility if useSegm is specified in params + if p.useSegm is not None: + p.iouType = "segm" if p.useSegm == 1 else "bbox" + print( + "useSegm (deprecated) is not None. Running {} evaluation".format(p.iouType) + ) + # print('Evaluate annotation type *{}*'.format(p.iouType)) + p.imgIds = list(np.unique(p.imgIds)) + if p.useCats: + p.catIds = list(np.unique(p.catIds)) + p.maxDets = sorted(p.maxDets) + self.params = p + + self._prepare() + # loop through images, area range, max detection number + catIds = p.catIds if p.useCats else [-1] + + if p.iouType == "segm" or p.iouType == "bbox": + computeIoU = self.computeIoU + elif p.iouType == "keypoints": + computeIoU = self.computeOks + self.ious = { + (imgId, catId): computeIoU(imgId, catId) + for imgId in p.imgIds + for catId in catIds + } + + evaluateImg = self.evaluateImg + maxDet = p.maxDets[-1] + evalImgs = [ + evaluateImg(imgId, catId, areaRng, maxDet) + for catId in catIds + for areaRng in p.areaRng + for imgId in p.imgIds + ] + # this is NOT in the pycocotools code, but could be done outside + evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds)) + self._paramsEval = copy.deepcopy(self.params) + # toc = time.time() + # print('DONE (t={:0.2f}s).'.format(toc-tic)) + return p.imgIds, evalImgs + + +################################################################# +# end of straight copy from pycocotools, just removing the prints +################################################################# diff --git a/datasets/coco_panoptic.py b/datasets/coco_panoptic.py new file mode 100644 index 0000000..e856e49 --- /dev/null +++ b/datasets/coco_panoptic.py @@ -0,0 +1,107 @@ +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +import json +from pathlib import Path + +import numpy as np +import torch +from PIL import Image + +from panopticapi.utils import rgb2id +from util.box_ops import masks_to_boxes + +from .coco import make_coco_transforms + + +class CocoPanoptic: + def __init__(self, img_folder, ann_folder, ann_file, transforms=None, return_masks=True): + with open(ann_file, 'r') as f: + self.coco = json.load(f) + + # sort 'images' field so that they are aligned with 'annotations' + # i.e., in alphabetical order + self.coco['images'] = sorted(self.coco['images'], key=lambda x: x['id']) + # sanity check + if "annotations" in self.coco: + for img, ann in zip(self.coco['images'], self.coco['annotations']): + assert img['file_name'][:-4] == ann['file_name'][:-4] + + self.img_folder = img_folder + self.ann_folder = ann_folder + self.ann_file = ann_file + self.transforms = transforms + self.return_masks = return_masks + + def __getitem__(self, idx): + ann_info = self.coco['annotations'][idx] if "annotations" in self.coco else self.coco['images'][idx] + img_path = Path(self.img_folder) / ann_info['file_name'].replace('.png', '.jpg') + ann_path = Path(self.ann_folder) / ann_info['file_name'] + + img = Image.open(img_path).convert('RGB') + w, h = img.size + if "segments_info" in ann_info: + masks = np.asarray(Image.open(ann_path), dtype=np.uint32) + masks = rgb2id(masks) + + ids = np.array([ann['id'] for ann in ann_info['segments_info']]) + masks = masks == ids[:, None, None] + + masks = torch.as_tensor(masks, dtype=torch.uint8) + labels = torch.tensor([ann['category_id'] for ann in ann_info['segments_info']], dtype=torch.int64) + + target = {} + target['image_id'] = torch.tensor([ann_info['image_id'] if "image_id" in ann_info else ann_info["id"]]) + if self.return_masks: + target['masks'] = masks + target['labels'] = labels + + target["boxes"] = masks_to_boxes(masks) + + target['size'] = torch.as_tensor([int(h), int(w)]) + target['orig_size'] = torch.as_tensor([int(h), int(w)]) + if "segments_info" in ann_info: + for name in ['iscrowd', 'area']: + target[name] = torch.tensor([ann[name] for ann in ann_info['segments_info']]) + + if self.transforms is not None: + img, target = self.transforms(img, target) + + return img, target + + def __len__(self): + return len(self.coco['images']) + + def get_height_and_width(self, idx): + img_info = self.coco['images'][idx] + height = img_info['height'] + width = img_info['width'] + return height, width + + +def build(image_set, args): + img_folder_root = Path(args.coco_path) + ann_folder_root = Path(args.coco_panoptic_path) + assert img_folder_root.exists(), f'provided COCO path {img_folder_root} does not exist' + assert ann_folder_root.exists(), f'provided COCO path {ann_folder_root} does not exist' + mode = 'panoptic' + PATHS = { + "train": ("train2017", Path("annotations") / f'{mode}_train2017.json'), + "val": ("val2017", Path("annotations") / f'{mode}_val2017.json'), + } + + img_folder, ann_file = PATHS[image_set] + img_folder_path = img_folder_root / img_folder + ann_folder = ann_folder_root / f'{mode}_{img_folder}' + ann_file = ann_folder_root / ann_file + + dataset = CocoPanoptic(img_folder_path, ann_folder, ann_file, + transforms=make_coco_transforms(image_set), return_masks=args.masks) + + return dataset diff --git a/datasets/data_prefetcher.py b/datasets/data_prefetcher.py new file mode 100644 index 0000000..7d28d9f --- /dev/null +++ b/datasets/data_prefetcher.py @@ -0,0 +1,70 @@ +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ + +import torch + +def to_cuda(samples, targets, device): + samples = samples.to(device, non_blocking=True) + targets = [{k: v.to(device, non_blocking=True) for k, v in t.items()} for t in targets] + return samples, targets + +class data_prefetcher(): + def __init__(self, loader, device, prefetch=True): + self.loader = iter(loader) + self.prefetch = prefetch + self.device = device + if prefetch: + self.stream = torch.cuda.Stream() + self.preload() + + def preload(self): + try: + self.next_samples, self.next_targets = next(self.loader) + except StopIteration: + self.next_samples = None + self.next_targets = None + return + # if record_stream() doesn't work, another option is to make sure device inputs are created + # on the main stream. + # self.next_input_gpu = torch.empty_like(self.next_input, device='cuda') + # self.next_target_gpu = torch.empty_like(self.next_target, device='cuda') + # Need to make sure the memory allocated for next_* is not still in use by the main stream + # at the time we start copying to next_*: + # self.stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(self.stream): + self.next_samples, self.next_targets = to_cuda(self.next_samples, self.next_targets, self.device) + # more code for the alternative if record_stream() doesn't work: + # copy_ will record the use of the pinned source tensor in this side stream. + # self.next_input_gpu.copy_(self.next_input, non_blocking=True) + # self.next_target_gpu.copy_(self.next_target, non_blocking=True) + # self.next_input = self.next_input_gpu + # self.next_target = self.next_target_gpu + + # With Amp, it isn't necessary to manually convert data to half. + # if args.fp16: + # self.next_input = self.next_input.half() + # else: + + def next(self): + if self.prefetch: + torch.cuda.current_stream().wait_stream(self.stream) + samples = self.next_samples + targets = self.next_targets + if samples is not None: + samples.record_stream(torch.cuda.current_stream()) + if targets is not None: + for t in targets: + for k, v in t.items(): + v.record_stream(torch.cuda.current_stream()) + self.preload() + else: + try: + samples, targets = next(self.loader) + samples, targets = to_cuda(samples, targets, self.device) + except StopIteration: + samples = None + targets = None + return samples, targets diff --git a/datasets/panoptic_eval.py b/datasets/panoptic_eval.py new file mode 100644 index 0000000..0dabffd --- /dev/null +++ b/datasets/panoptic_eval.py @@ -0,0 +1,52 @@ +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +import json +import os + +import util.misc as utils + +try: + from panopticapi.evaluation import pq_compute +except ImportError: + pass + + +class PanopticEvaluator(object): + def __init__(self, ann_file, ann_folder, output_dir="panoptic_eval"): + self.gt_json = ann_file + self.gt_folder = ann_folder + if utils.is_main_process(): + if not os.path.exists(output_dir): + os.mkdir(output_dir) + self.output_dir = output_dir + self.predictions = [] + + def update(self, predictions): + for p in predictions: + with open(os.path.join(self.output_dir, p["file_name"]), "wb") as f: + f.write(p.pop("png_string")) + + self.predictions += predictions + + def synchronize_between_processes(self): + all_predictions = utils.all_gather(self.predictions) + merged_predictions = [] + for p in all_predictions: + merged_predictions += p + self.predictions = merged_predictions + + def summarize(self): + if utils.is_main_process(): + json_data = {"annotations": self.predictions} + predictions_json = os.path.join(self.output_dir, "predictions.json") + with open(predictions_json, "w") as f: + f.write(json.dumps(json_data)) + return pq_compute(self.gt_json, predictions_json, gt_folder=self.gt_folder, pred_folder=self.output_dir) + return None diff --git a/datasets/samplers.py b/datasets/samplers.py new file mode 100644 index 0000000..14c0af2 --- /dev/null +++ b/datasets/samplers.py @@ -0,0 +1,139 @@ +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from codes in torch.utils.data.distributed +# ------------------------------------------------------------------------ + +import os +import math +import torch +import torch.distributed as dist +from torch.utils.data.sampler import Sampler + + +class DistributedSampler(Sampler): + """Sampler that restricts data loading to a subset of the dataset. + It is especially useful in conjunction with + :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each + process can pass a DistributedSampler instance as a DataLoader sampler, + and load a subset of the original dataset that is exclusive to it. + .. note:: + Dataset is assumed to be of constant size. + Arguments: + dataset: Dataset used for sampling. + num_replicas (optional): Number of processes participating in + distributed training. + rank (optional): Rank of the current process within num_replicas. + """ + + def __init__(self, dataset, num_replicas=None, rank=None, local_rank=None, local_size=None, shuffle=True): + if num_replicas is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + num_replicas = dist.get_world_size() + if rank is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + rank = dist.get_rank() + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) + self.total_size = self.num_samples * self.num_replicas + self.shuffle = shuffle + + def __iter__(self): + if self.shuffle: + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch) + indices = torch.randperm(len(self.dataset), generator=g).tolist() + else: + indices = torch.arange(len(self.dataset)).tolist() + + # add extra samples to make it evenly divisible + indices += indices[: (self.total_size - len(indices))] + assert len(indices) == self.total_size + + # subsample + offset = self.num_samples * self.rank + indices = indices[offset : offset + self.num_samples] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self): + return self.num_samples + + def set_epoch(self, epoch): + self.epoch = epoch + + +class NodeDistributedSampler(Sampler): + """Sampler that restricts data loading to a subset of the dataset. + It is especially useful in conjunction with + :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each + process can pass a DistributedSampler instance as a DataLoader sampler, + and load a subset of the original dataset that is exclusive to it. + .. note:: + Dataset is assumed to be of constant size. + Arguments: + dataset: Dataset used for sampling. + num_replicas (optional): Number of processes participating in + distributed training. + rank (optional): Rank of the current process within num_replicas. + """ + + def __init__(self, dataset, num_replicas=None, rank=None, local_rank=None, local_size=None, shuffle=True): + if num_replicas is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + num_replicas = dist.get_world_size() + if rank is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + rank = dist.get_rank() + if local_rank is None: + local_rank = int(os.environ.get('LOCAL_RANK', 0)) + if local_size is None: + local_size = int(os.environ.get('LOCAL_SIZE', 1)) + self.dataset = dataset + self.shuffle = shuffle + self.num_replicas = num_replicas + self.num_parts = local_size + self.rank = rank + self.local_rank = local_rank + self.epoch = 0 + self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) + self.total_size = self.num_samples * self.num_replicas + + self.total_size_parts = self.num_samples * self.num_replicas // self.num_parts + + def __iter__(self): + if self.shuffle: + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch) + indices = torch.randperm(len(self.dataset), generator=g).tolist() + else: + indices = torch.arange(len(self.dataset)).tolist() + indices = [i for i in indices if i % self.num_parts == self.local_rank] + + # add extra samples to make it evenly divisible + indices += indices[:(self.total_size_parts - len(indices))] + assert len(indices) == self.total_size_parts + + # subsample + indices = indices[self.rank // self.num_parts:self.total_size_parts:self.num_replicas // self.num_parts] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self): + return self.num_samples + + def set_epoch(self, epoch): + self.epoch = epoch diff --git a/datasets/torchvision_datasets/__init__.py b/datasets/torchvision_datasets/__init__.py new file mode 100644 index 0000000..162303c --- /dev/null +++ b/datasets/torchvision_datasets/__init__.py @@ -0,0 +1,7 @@ +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ + +from .coco import CocoDetection diff --git a/datasets/torchvision_datasets/coco.py b/datasets/torchvision_datasets/coco.py new file mode 100644 index 0000000..45b5f52 --- /dev/null +++ b/datasets/torchvision_datasets/coco.py @@ -0,0 +1,84 @@ +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from torchvision +# ------------------------------------------------------------------------ + +""" +Copy-Paste from torchvision, but add utility of caching images on memory +""" +from torchvision.datasets.vision import VisionDataset +from PIL import Image +import os +import os.path +import tqdm +from io import BytesIO + + +class CocoDetection(VisionDataset): + """`MS Coco Detection `_ Dataset. + Args: + root (string): Root directory where images are downloaded to. + annFile (string): Path to json annotation file. + transform (callable, optional): A function/transform that takes in an PIL image + and returns a transformed version. E.g, ``transforms.ToTensor`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + transforms (callable, optional): A function/transform that takes input sample and its target as entry + and returns a transformed version. + """ + + def __init__(self, root, annFile, transform=None, target_transform=None, transforms=None, + cache_mode=False, local_rank=0, local_size=1): + super(CocoDetection, self).__init__(root, transforms, transform, target_transform) + from pycocotools.coco import COCO + self.coco = COCO(annFile) + self.ids = list(sorted(self.coco.imgs.keys())) + self.cache_mode = cache_mode + self.local_rank = local_rank + self.local_size = local_size + if cache_mode: + self.cache = {} + self.cache_images() + + def cache_images(self): + self.cache = {} + for index, img_id in zip(tqdm.trange(len(self.ids)), self.ids): + if index % self.local_size != self.local_rank: + continue + path = self.coco.loadImgs(img_id)[0]['file_name'] + with open(os.path.join(self.root, path), 'rb') as f: + self.cache[path] = f.read() + + def get_image(self, path): + if self.cache_mode: + if path not in self.cache.keys(): + with open(os.path.join(self.root, path), 'rb') as f: + self.cache[path] = f.read() + return Image.open(BytesIO(self.cache[path])).convert('RGB') + return Image.open(os.path.join(self.root, path)).convert('RGB') + + def __getitem__(self, index): + """ + Args: + index (int): Index + Returns: + tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``. + """ + coco = self.coco + img_id = self.ids[index] + ann_ids = coco.getAnnIds(imgIds=img_id) + target = coco.loadAnns(ann_ids) + + path = coco.loadImgs(img_id)[0]['file_name'] + + img = self.get_image(path) + if self.transforms is not None: + img, target = self.transforms(img, target) + + return img, target + + def __len__(self): + return len(self.ids) diff --git a/datasets/transforms.py b/datasets/transforms.py new file mode 100644 index 0000000..24fe477 --- /dev/null +++ b/datasets/transforms.py @@ -0,0 +1,294 @@ +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +""" +Transforms and data augmentation for both image + bbox. +""" +import random + +import PIL +import torch +import torchvision.transforms as T +import torchvision.transforms.functional as F + +from util.box_ops import box_xyxy_to_cxcywh +from util.misc import interpolate + + +def crop(image, target, region): + cropped_image = F.crop(image, *region) + + target = target.copy() + i, j, h, w = region + + # should we do something wrt the original size? + target["size"] = torch.tensor([h, w]) + + fields = ["labels", "area", "iscrowd"] + + if "boxes" in target: + boxes = target["boxes"] + max_size = torch.as_tensor([w, h], dtype=torch.float32) + cropped_boxes = boxes - torch.as_tensor([j, i, j, i]) + cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size) + cropped_boxes = cropped_boxes.clamp(min=0) + area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1) + target["boxes"] = cropped_boxes.reshape(-1, 4) + target["area"] = area + fields.append("boxes") + + if "masks" in target: + # FIXME should we update the area here if there are no boxes? + target["masks"] = target["masks"][:, i : i + h, j : j + w] + fields.append("masks") + + # remove elements for which the boxes or masks that have zero area + if "boxes" in target or "masks" in target: + # favor boxes selection when defining which elements to keep + # this is compatible with previous implementation + if "boxes" in target: + cropped_boxes = target["boxes"].reshape(-1, 2, 2) + keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1) + else: + keep = target["masks"].flatten(1).any(1) + + for field in fields: + target[field] = target[field][keep] + + return cropped_image, target + + +def hflip(image, target): + flipped_image = F.hflip(image) + + w, h = image.size + + target = target.copy() + if "boxes" in target: + boxes = target["boxes"] + boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor( + [-1, 1, -1, 1] + ) + torch.as_tensor([w, 0, w, 0]) + target["boxes"] = boxes + + if "masks" in target: + target["masks"] = target["masks"].flip(-1) + + return flipped_image, target + + +def resize(image, target, size, max_size=None): + # size can be min_size (scalar) or (w, h) tuple + + def get_size_with_aspect_ratio(image_size, size, max_size=None): + w, h = image_size + if max_size is not None: + min_original_size = float(min((w, h))) + max_original_size = float(max((w, h))) + if max_original_size / min_original_size * size > max_size: + size = int(round(max_size * min_original_size / max_original_size)) + + if (w <= h and w == size) or (h <= w and h == size): + return (h, w) + + if w < h: + ow = size + oh = int(size * h / w) + else: + oh = size + ow = int(size * w / h) + + return (oh, ow) + + def get_size(image_size, size, max_size=None): + if isinstance(size, (list, tuple)): + return size[::-1] + else: + return get_size_with_aspect_ratio(image_size, size, max_size) + + size = get_size(image.size, size, max_size) + rescaled_image = F.resize(image, size) + + if target is None: + return rescaled_image, None + + ratios = tuple( + float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size) + ) + ratio_width, ratio_height = ratios + + target = target.copy() + if "boxes" in target: + boxes = target["boxes"] + scaled_boxes = boxes * torch.as_tensor( + [ratio_width, ratio_height, ratio_width, ratio_height] + ) + target["boxes"] = scaled_boxes + + if "area" in target: + area = target["area"] + scaled_area = area * (ratio_width * ratio_height) + target["area"] = scaled_area + + h, w = size + target["size"] = torch.tensor([h, w]) + + if "masks" in target: + target["masks"] = ( + interpolate(target["masks"][:, None].float(), size, mode="nearest")[:, 0] + > 0.5 + ) + + return rescaled_image, target + + +def pad(image, target, padding): + # assumes that we only pad on the bottom right corners + padded_image = F.pad(image, (0, 0, padding[0], padding[1])) + if target is None: + return padded_image, None + target = target.copy() + # should we do something wrt the original size? + target["size"] = torch.tensor(padded_image[::-1]) + if "masks" in target: + target["masks"] = torch.nn.functional.pad( + target["masks"], (0, padding[0], 0, padding[1]) + ) + return padded_image, target + + +class RandomCrop(object): + def __init__(self, size): + self.size = size + + def __call__(self, img, target): + region = T.RandomCrop.get_params(img, self.size) + return crop(img, target, region) + + +class RandomSizeCrop(object): + def __init__(self, min_size: int, max_size: int): + self.min_size = min_size + self.max_size = max_size + + def __call__(self, img: PIL.Image.Image, target: dict): + w = random.randint(self.min_size, min(img.width, self.max_size)) + h = random.randint(self.min_size, min(img.height, self.max_size)) + region = T.RandomCrop.get_params(img, [h, w]) + return crop(img, target, region) + + +class CenterCrop(object): + def __init__(self, size): + self.size = size + + def __call__(self, img, target): + image_width, image_height = img.size + crop_height, crop_width = self.size + crop_top = int(round((image_height - crop_height) / 2.0)) + crop_left = int(round((image_width - crop_width) / 2.0)) + return crop(img, target, (crop_top, crop_left, crop_height, crop_width)) + + +class RandomHorizontalFlip(object): + def __init__(self, p=0.5): + self.p = p + + def __call__(self, img, target): + if random.random() < self.p: + return hflip(img, target) + return img, target + + +class RandomResize(object): + def __init__(self, sizes, max_size=None): + assert isinstance(sizes, (list, tuple)) + self.sizes = sizes + self.max_size = max_size + + def __call__(self, img, target=None): + size = random.choice(self.sizes) + return resize(img, target, size, self.max_size) + + +class RandomPad(object): + def __init__(self, max_pad): + self.max_pad = max_pad + + def __call__(self, img, target): + pad_x = random.randint(0, self.max_pad) + pad_y = random.randint(0, self.max_pad) + return pad(img, target, (pad_x, pad_y)) + + +class RandomSelect(object): + """ + Randomly selects between transforms1 and transforms2, + with probability p for transforms1 and (1 - p) for transforms2 + """ + + def __init__(self, transforms1, transforms2, p=0.5): + self.transforms1 = transforms1 + self.transforms2 = transforms2 + self.p = p + + def __call__(self, img, target): + if random.random() < self.p: + return self.transforms1(img, target) + return self.transforms2(img, target) + + +class ToTensor(object): + def __call__(self, img, target): + return F.to_tensor(img), target + + +class RandomErasing(object): + def __init__(self, *args, **kwargs): + self.eraser = T.RandomErasing(*args, **kwargs) + + def __call__(self, img, target): + return self.eraser(img), target + + +class Normalize(object): + def __init__(self, mean, std): + self.mean = mean + self.std = std + + def __call__(self, image, target=None): + image = F.normalize(image, mean=self.mean, std=self.std) + if target is None: + return image, None + target = target.copy() + h, w = image.shape[-2:] + if "boxes" in target: + boxes = target["boxes"] + boxes = box_xyxy_to_cxcywh(boxes) + boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32) + target["boxes"] = boxes + return image, target + + +class Compose(object): + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, image, target): + for t in self.transforms: + image, target = t(image, target) + return image, target + + def __repr__(self): + format_string = self.__class__.__name__ + "(" + for t in self.transforms: + format_string += "\n" + format_string += " {0}".format(t) + format_string += "\n)" + return format_string diff --git a/engine.py b/engine.py new file mode 100644 index 0000000..7121d14 --- /dev/null +++ b/engine.py @@ -0,0 +1,285 @@ +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +""" +Train and eval functions used in main.py +""" +import math +import os +import sys +from typing import Iterable +import copy + +import wandb +import torch +import util.misc as utils +from datasets.coco_eval import CocoEvaluator +from datasets.panoptic_eval import PanopticEvaluator +from datasets.data_prefetcher import data_prefetcher + +scaler = torch.cuda.amp.GradScaler() + + +def train_hybrid(outputs, targets, k_one2many, criterion, lambda_one2many): + # one-to-one-loss + loss_dict = criterion(outputs, targets) + multi_targets = copy.deepcopy(targets) + # repeat the targets + for target in multi_targets: + target["boxes"] = target["boxes"].repeat(k_one2many, 1) + target["labels"] = target["labels"].repeat(k_one2many) + + outputs_one2many = dict() + outputs_one2many["pred_logits"] = outputs["pred_logits_one2many"] + outputs_one2many["pred_boxes"] = outputs["pred_boxes_one2many"] + outputs_one2many["aux_outputs"] = outputs["aux_outputs_one2many"] + + # one-to-many loss + loss_dict_one2many = criterion(outputs_one2many, multi_targets) + for key, value in loss_dict_one2many.items(): + if key + "_one2many" in loss_dict.keys(): + loss_dict[key + "_one2many"] += value * lambda_one2many + else: + loss_dict[key + "_one2many"] = value * lambda_one2many + return loss_dict + + +def train_one_epoch( + model: torch.nn.Module, + criterion: torch.nn.Module, + data_loader: Iterable, + optimizer: torch.optim.Optimizer, + device: torch.device, + epoch: int, + max_norm: float = 0, + k_one2many=1, + lambda_one2many=1.0, + use_wandb=False, + use_fp16=False, +): + model.train() + criterion.train() + metric_logger = utils.MetricLogger(delimiter=" ") + metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}")) + metric_logger.add_meter( + "class_error", utils.SmoothedValue(window_size=1, fmt="{value:.2f}") + ) + metric_logger.add_meter( + "grad_norm", utils.SmoothedValue(window_size=1, fmt="{value:.2f}") + ) + header = "Epoch: [{}]".format(epoch) + print_freq = 10 + + prefetcher = data_prefetcher(data_loader, device, prefetch=True) + samples, targets = prefetcher.next() + + # for samples, targets in metric_logger.log_every(data_loader, print_freq, header): + for _ in metric_logger.log_every(range(len(data_loader)), print_freq, header): + with torch.cuda.amp.autocast() if use_fp16 else torch.cuda.amp.autocast( + enabled=False + ): + if use_fp16: + optimizer.zero_grad() + outputs = model(samples) + + if k_one2many > 0: + loss_dict = train_hybrid( + outputs, targets, k_one2many, criterion, lambda_one2many + ) + else: + loss_dict = criterion(outputs, targets) + weight_dict = criterion.weight_dict + losses = sum( + loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict + ) + + # reduce losses over all GPUs for logging purposes + loss_dict_reduced = utils.reduce_dict(loss_dict) + loss_dict_reduced_unscaled = { + f"{k}_unscaled": v for k, v in loss_dict_reduced.items() + } + loss_dict_reduced_scaled = { + k: v * weight_dict[k] + for k, v in loss_dict_reduced.items() + if k in weight_dict + } + losses_reduced_scaled = sum(loss_dict_reduced_scaled.values()) + + loss_value = losses_reduced_scaled.item() + + if not math.isfinite(loss_value): + print("Loss is {}, stopping training".format(loss_value)) + print(loss_dict_reduced) + sys.exit(1) + + if use_fp16: + scaler.scale(losses).backward() + scaler.unscale_(optimizer) + else: + optimizer.zero_grad() + losses.backward() + if max_norm > 0: + grad_total_norm = torch.nn.utils.clip_grad_norm_( + model.parameters(), max_norm + ) + else: + grad_total_norm = utils.get_total_grad_norm(model.parameters(), max_norm) + + if use_fp16: + scaler.step(optimizer) + scaler.update() + else: + optimizer.step() + + metric_logger.update( + loss=loss_value, **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled + ) + metric_logger.update(class_error=loss_dict_reduced["class_error"]) + metric_logger.update(lr=optimizer.param_groups[0]["lr"]) + metric_logger.update(grad_norm=grad_total_norm) + + samples, targets = prefetcher.next() + + if use_wandb: + try: + wandb.log(loss_dict) + except: + pass + # gather the stats from all processes + metric_logger.synchronize_between_processes() + print("Averaged stats:", metric_logger) + return {k: meter.global_avg for k, meter in metric_logger.meters.items()} + + +@torch.no_grad() +def evaluate( + model, + criterion, + postprocessors, + data_loader, + base_ds, + device, + output_dir, + use_wandb=False, +): + # disable the one-to-many branch queries + # save them frist + save_num_queries = model.module.num_queries + save_two_stage_num_proposals = model.module.transformer.two_stage_num_proposals + model.module.num_queries = model.module.num_queries_one2one + model.module.transformer.two_stage_num_proposals = model.module.num_queries + + model.eval() + criterion.eval() + + metric_logger = utils.MetricLogger(delimiter=" ") + metric_logger.add_meter( + "class_error", utils.SmoothedValue(window_size=1, fmt="{value:.2f}") + ) + header = "Test:" + + iou_types = tuple(k for k in ("segm", "bbox") if k in postprocessors.keys()) + coco_evaluator = CocoEvaluator(base_ds, iou_types) + # coco_evaluator.coco_eval[iou_types[0]].params.iouThrs = [0, 0.1, 0.5, 0.75] + + panoptic_evaluator = None + if "panoptic" in postprocessors.keys(): + panoptic_evaluator = PanopticEvaluator( + data_loader.dataset.ann_file, + data_loader.dataset.ann_folder, + output_dir=os.path.join(output_dir, "panoptic_eval"), + ) + + for samples, targets in metric_logger.log_every(data_loader, 10, header): + samples = samples.to(device) + targets = [{k: v.to(device) for k, v in t.items()} for t in targets] + + outputs = model(samples) + loss_dict = criterion(outputs, targets) + weight_dict = criterion.weight_dict + + # reduce losses over all GPUs for logging purposes + loss_dict_reduced = utils.reduce_dict(loss_dict) + loss_dict_reduced_scaled = { + k: v * weight_dict[k] + for k, v in loss_dict_reduced.items() + if k in weight_dict + } + loss_dict_reduced_unscaled = { + f"{k}_unscaled": v for k, v in loss_dict_reduced.items() + } + metric_logger.update( + loss=sum(loss_dict_reduced_scaled.values()), + **loss_dict_reduced_scaled, + **loss_dict_reduced_unscaled, + ) + metric_logger.update(class_error=loss_dict_reduced["class_error"]) + + orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0) + results = postprocessors["bbox"](outputs, orig_target_sizes) + if "segm" in postprocessors.keys(): + target_sizes = torch.stack([t["size"] for t in targets], dim=0) + results = postprocessors["segm"]( + results, outputs, orig_target_sizes, target_sizes + ) + res = { + target["image_id"].item(): output + for target, output in zip(targets, results) + } + if coco_evaluator is not None: + coco_evaluator.update(res) + + if panoptic_evaluator is not None: + res_pano = postprocessors["panoptic"]( + outputs, target_sizes, orig_target_sizes + ) + for i, target in enumerate(targets): + image_id = target["image_id"].item() + file_name = f"{image_id:012d}.png" + res_pano[i]["image_id"] = image_id + res_pano[i]["file_name"] = file_name + + panoptic_evaluator.update(res_pano) + + # gather the stats from all processes + metric_logger.synchronize_between_processes() + print("Averaged stats:", metric_logger) + if coco_evaluator is not None: + coco_evaluator.synchronize_between_processes() + if panoptic_evaluator is not None: + panoptic_evaluator.synchronize_between_processes() + + # accumulate predictions from all images + if coco_evaluator is not None: + coco_evaluator.accumulate() + coco_evaluator.summarize() + panoptic_res = None + if panoptic_evaluator is not None: + panoptic_res = panoptic_evaluator.summarize() + stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()} + if coco_evaluator is not None: + if "bbox" in postprocessors.keys(): + stats["coco_eval_bbox"] = coco_evaluator.coco_eval["bbox"].stats.tolist() + if "segm" in postprocessors.keys(): + stats["coco_eval_masks"] = coco_evaluator.coco_eval["segm"].stats.tolist() + if panoptic_res is not None: + stats["PQ_all"] = panoptic_res["All"] + stats["PQ_th"] = panoptic_res["Things"] + stats["PQ_st"] = panoptic_res["Stuff"] + if use_wandb: + try: + wandb.log({"AP": stats["coco_eval_bbox"][0]}) + wandb.log(stats) + except: + pass + + # recover the model parameters for next training epoch + model.module.num_queries = save_num_queries + model.module.transformer.two_stage_num_proposals = save_two_stage_num_proposals + return stats, coco_evaluator diff --git a/main.py b/main.py new file mode 100644 index 0000000..4602476 --- /dev/null +++ b/main.py @@ -0,0 +1,532 @@ +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + + +import argparse +import datetime +import json +import random +import time +from pathlib import Path +import os + +import numpy as np +import torch +from torch.utils.data import DataLoader +import datasets +import util.misc as utils +import datasets.samplers as samplers +from datasets import build_dataset, get_coco_api_from_dataset +from engine import evaluate, train_one_epoch +from models import build_model + + +def get_args_parser(): + parser = argparse.ArgumentParser("Deformable DETR Detector", add_help=False) + parser.add_argument("--lr", default=2e-4, type=float) + parser.add_argument( + "--lr_backbone_names", default=["backbone.0"], type=str, nargs="+" + ) + parser.add_argument("--lr_backbone", default=2e-5, type=float) + parser.add_argument( + "--lr_linear_proj_names", + default=["reference_points", "sampling_offsets"], + type=str, + nargs="+", + ) + parser.add_argument("--lr_linear_proj_mult", default=0.1, type=float) + parser.add_argument("--batch_size", default=2, type=int) + parser.add_argument("--weight_decay", default=1e-4, type=float) + parser.add_argument("--epochs", default=50, type=int) + parser.add_argument("--lr_drop", default=40, type=int) + parser.add_argument("--lr_drop_epochs", default=None, type=int, nargs="+") + parser.add_argument( + "--clip_max_norm", default=0.1, type=float, help="gradient clipping max norm" + ) + + parser.add_argument("--sgd", action="store_true") + + # Variants of Deformable DETR + parser.add_argument("--with_box_refine", default=False, action="store_true") + parser.add_argument("--two_stage", default=False, action="store_true") + + # Model parameters + parser.add_argument( + "--frozen_weights", + type=str, + default=None, + help="Path to the pretrained model. If set, only the mask head will be trained", + ) + + # * Backbone + parser.add_argument( + "--backbone", + default="resnet50", + type=str, + help="Name of the convolutional backbone to use", + ) + parser.add_argument( + "--dilation", + action="store_true", + help="If true, we replace stride with dilation in the last convolutional block (DC5)", + ) + parser.add_argument( + "--position_embedding", + default="sine", + type=str, + choices=("sine", "learned"), + help="Type of positional embedding to use on top of the image features", + ) + parser.add_argument( + "--position_embedding_scale", + default=2 * np.pi, + type=float, + help="position / size * scale", + ) + parser.add_argument( + "--num_feature_levels", default=4, type=int, help="number of feature levels" + ) + # swin backbone + parser.add_argument( + "--pretrained_backbone_path", + default="./swin_tiny_patch4_window7_224.pkl", + type=str, + ) + parser.add_argument("--drop_path_rate", default=0.2, type=float) + + # * Transformer + parser.add_argument( + "--enc_layers", + default=6, + type=int, + help="Number of encoding layers in the transformer", + ) + parser.add_argument( + "--dec_layers", + default=6, + type=int, + help="Number of decoding layers in the transformer", + ) + parser.add_argument( + "--dim_feedforward", + default=2048, + type=int, + help="Intermediate size of the feedforward layers in the transformer blocks", + ) + parser.add_argument( + "--hidden_dim", + default=256, + type=int, + help="Size of the embeddings (dimension of the transformer)", + ) + parser.add_argument( + "--dropout", default=0.1, type=float, help="Dropout applied in the transformer" + ) + parser.add_argument( + "--nheads", + default=8, + type=int, + help="Number of attention heads inside the transformer's attentions", + ) + parser.add_argument( + "--num_queries_one2one", + default=300, + type=int, + help="Number of query slots for one-to-one matching", + ) + parser.add_argument( + "--num_queries_one2many", + default=0, + type=int, + help="Number of query slots for one-to-many matchining", + ) + parser.add_argument("--dec_n_points", default=4, type=int) + parser.add_argument("--enc_n_points", default=4, type=int) + # Deformable DETR tricks + parser.add_argument("--mixed_selection", action="store_true", default=False) + parser.add_argument("--look_forward_twice", action="store_true", default=False) + # hybrid branch + parser.add_argument("--k_one2many", default=5, type=int) + parser.add_argument("--lambda_one2many", default=1.0, type=float) + + # * Segmentation + parser.add_argument( + "--masks", + action="store_true", + help="Train segmentation head if the flag is provided", + ) + + # Loss + parser.add_argument( + "--no_aux_loss", + dest="aux_loss", + action="store_false", + help="Disables auxiliary decoding losses (loss at each layer)", + ) + + # * Matcher + parser.add_argument( + "--set_cost_class", + default=2, + type=float, + help="Class coefficient in the matching cost", + ) + parser.add_argument( + "--set_cost_bbox", + default=5, + type=float, + help="L1 box coefficient in the matching cost", + ) + parser.add_argument( + "--set_cost_giou", + default=2, + type=float, + help="giou box coefficient in the matching cost", + ) + + # * Loss coefficients + parser.add_argument("--mask_loss_coef", default=1, type=float) + parser.add_argument("--dice_loss_coef", default=1, type=float) + parser.add_argument("--cls_loss_coef", default=2, type=float) + parser.add_argument("--bbox_loss_coef", default=5, type=float) + parser.add_argument("--giou_loss_coef", default=2, type=float) + parser.add_argument("--focal_alpha", default=0.25, type=float) + + # dataset parameters + parser.add_argument("--dataset_file", default="coco") + parser.add_argument("--coco_path", default="./data/coco", type=str) + parser.add_argument("--coco_panoptic_path", type=str) + parser.add_argument("--remove_difficult", action="store_true") + + parser.add_argument( + "--output_dir", default="", help="path where to save, empty for no saving" + ) + parser.add_argument( + "--device", default="cuda", help="device to use for training / testing" + ) + parser.add_argument("--seed", default=42, type=int) + parser.add_argument("--resume", default="", help="resume from checkpoint") + parser.add_argument( + "--start_epoch", default=0, type=int, metavar="N", help="start epoch" + ) + parser.add_argument("--num_workers", default=2, type=int) + parser.add_argument( + "--cache_mode", + default=False, + action="store_true", + help="whether to cache images on memory", + ) + + # * eval technologies + parser.add_argument("--eval", action="store_true") + # eval in training set + parser.add_argument("--eval_in_training_set", default=False, action="store_true") + # topk for eval + parser.add_argument("--topk", default=100, type=int) + + # * training technologies + parser.add_argument("--use_fp16", default=False, action="store_true") + parser.add_argument("--use_checkpoint", default=False, action="store_true") + + # * logging technologies + parser.add_argument("--use_wandb", action="store_true", default=False) + return parser + + +def main(args): + utils.init_distributed_mode(args) + print("git:\n {}\n".format(utils.get_sha())) + + if args.frozen_weights is not None: + assert args.masks, "Frozen training is meant for segmentation only" + print(args) + + device = torch.device(args.device) + + # fix the seed for reproducibility + seed = args.seed + utils.get_rank() + torch.manual_seed(seed) + np.random.seed(seed) + random.seed(seed) + + model, criterion, postprocessors = build_model(args) + model.to(device) + + model_without_ddp = model + n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) + print("number of params:", n_parameters) + + dataset_train = build_dataset(image_set="train", args=args) + if not args.eval_in_training_set: + dataset_val = build_dataset( + image_set="val", args=args, eval_in_training_set=False, + ) + else: + print("eval in the training set") + dataset_val = build_dataset( + image_set="train", args=args, eval_in_training_set=True, + ) + + if args.distributed: + if args.cache_mode: + sampler_train = samplers.NodeDistributedSampler(dataset_train) + sampler_val = samplers.NodeDistributedSampler(dataset_val, shuffle=False) + else: + sampler_train = samplers.DistributedSampler(dataset_train) + sampler_val = samplers.DistributedSampler(dataset_val, shuffle=False) + else: + sampler_train = torch.utils.data.RandomSampler(dataset_train) + sampler_val = torch.utils.data.SequentialSampler(dataset_val) + + batch_sampler_train = torch.utils.data.BatchSampler( + sampler_train, args.batch_size, drop_last=True + ) + + data_loader_train = DataLoader( + dataset_train, + batch_sampler=batch_sampler_train, + collate_fn=utils.collate_fn, + num_workers=args.num_workers, + pin_memory=True, + ) + data_loader_val = DataLoader( + dataset_val, + args.batch_size, + sampler=sampler_val, + drop_last=False, + collate_fn=utils.collate_fn, + num_workers=args.num_workers, + pin_memory=True, + ) + + # lr_backbone_names = ["backbone.0", "backbone.neck", "input_proj", "transformer.encoder"] + def match_name_keywords(n, name_keywords): + out = False + for b in name_keywords: + if b in n: + out = True + break + return out + + for n, p in model_without_ddp.named_parameters(): + print(n) + + param_dicts = [ + { + "params": [ + p + for n, p in model_without_ddp.named_parameters() + if not match_name_keywords(n, args.lr_backbone_names) + and not match_name_keywords(n, args.lr_linear_proj_names) + and p.requires_grad + ], + "lr": args.lr, + }, + { + "params": [ + p + for n, p in model_without_ddp.named_parameters() + if match_name_keywords(n, args.lr_backbone_names) and p.requires_grad + ], + "lr": args.lr_backbone, + }, + { + "params": [ + p + for n, p in model_without_ddp.named_parameters() + if match_name_keywords(n, args.lr_linear_proj_names) and p.requires_grad + ], + "lr": args.lr * args.lr_linear_proj_mult, + }, + ] + if args.sgd: + optimizer = torch.optim.SGD( + param_dicts, lr=args.lr, momentum=0.9, weight_decay=args.weight_decay + ) + else: + optimizer = torch.optim.AdamW( + param_dicts, lr=args.lr, weight_decay=args.weight_decay + ) + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop) + + if args.distributed: + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) + model_without_ddp = model.module + + if args.dataset_file == "coco_panoptic": + # We also evaluate AP during panoptic training, on original coco DS + coco_val = datasets.coco.build("val", args) + base_ds = get_coco_api_from_dataset(coco_val) + else: + base_ds = get_coco_api_from_dataset(dataset_val) + + if args.frozen_weights is not None: + checkpoint = torch.load(args.frozen_weights, map_location="cpu") + model_without_ddp.detr.load_state_dict(checkpoint["model"]) + + output_dir = Path(args.output_dir) + if args.resume and os.path.exists(args.resume): + if args.resume.startswith("https"): + checkpoint = torch.hub.load_state_dict_from_url( + args.resume, map_location="cpu", check_hash=True + ) + else: + checkpoint = torch.load(args.resume, map_location="cpu") + missing_keys, unexpected_keys = model_without_ddp.load_state_dict( + checkpoint["model"], strict=False + ) + unexpected_keys = [ + k + for k in unexpected_keys + if not (k.endswith("total_params") or k.endswith("total_ops")) + ] + if len(missing_keys) > 0: + print("Missing Keys: {}".format(missing_keys)) + if len(unexpected_keys) > 0: + print("Unexpected Keys: {}".format(unexpected_keys)) + if ( + not args.eval + and "optimizer" in checkpoint + and "lr_scheduler" in checkpoint + and "epoch" in checkpoint + ): + import copy + + p_groups = copy.deepcopy(optimizer.param_groups) + optimizer.load_state_dict(checkpoint["optimizer"]) + for pg, pg_old in zip(optimizer.param_groups, p_groups): + pg["lr"] = pg_old["lr"] + pg["initial_lr"] = pg_old["initial_lr"] + print(optimizer.param_groups) + lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) + # todo: this is a hack for doing experiment that resume from checkpoint and also modify lr scheduler (e.g., decrease lr in advance). + args.override_resumed_lr_drop = True + if args.override_resumed_lr_drop: + print( + "Warning: (hack) args.override_resumed_lr_drop is set to True, so args.lr_drop would override lr_drop in resumed lr_scheduler." + ) + lr_scheduler.step_size = args.lr_drop + lr_scheduler.base_lrs = list( + map(lambda group: group["initial_lr"], optimizer.param_groups) + ) + lr_scheduler.step(lr_scheduler.last_epoch) + args.start_epoch = checkpoint["epoch"] + 1 + # check the resumed model + if not args.eval: + test_stats, coco_evaluator = evaluate( + model, + criterion, + postprocessors, + data_loader_val, + base_ds, + device, + args.output_dir, + use_wandb=args.use_wandb, + ) + + if args.eval: + test_stats, coco_evaluator = evaluate( + model, + criterion, + postprocessors, + data_loader_val, + base_ds, + device, + args.output_dir, + use_wandb=args.use_wandb, + ) + if args.output_dir: + utils.save_on_master( + coco_evaluator.coco_eval["bbox"].eval, output_dir / "eval.pth" + ) + return + + print("Start training") + start_time = time.time() + for epoch in range(args.start_epoch, args.epochs): + if args.distributed: + sampler_train.set_epoch(epoch) + train_stats = train_one_epoch( + model, + criterion, + data_loader_train, + optimizer, + device, + epoch, + args.clip_max_norm, + k_one2many=args.k_one2many, + lambda_one2many=args.lambda_one2many, + use_wandb=args.use_wandb, + use_fp16=args.use_fp16, + ) + lr_scheduler.step() + if args.output_dir: + checkpoint_paths = [output_dir / "checkpoint.pth"] + # extra checkpoint before LR drop and every 5 epochs + checkpoint_paths.append(output_dir / f"checkpoint{epoch:04}.pth") + for checkpoint_path in checkpoint_paths: + utils.save_on_master( + { + "model": model_without_ddp.state_dict(), + "optimizer": optimizer.state_dict(), + "lr_scheduler": lr_scheduler.state_dict(), + "epoch": epoch, + "args": args, + }, + checkpoint_path, + ) + + test_stats, coco_evaluator = evaluate( + model, + criterion, + postprocessors, + data_loader_val, + base_ds, + device, + args.output_dir, + use_wandb=args.use_wandb, + ) + + log_stats = { + **{f"train_{k}": v for k, v in train_stats.items()}, + **{f"test_{k}": v for k, v in test_stats.items()}, + "epoch": epoch, + "n_parameters": n_parameters, + } + + if args.output_dir and utils.is_main_process(): + with (output_dir / "log.txt").open("a") as f: + f.write(json.dumps(log_stats) + "\n") + + # for evaluation logs + if coco_evaluator is not None: + (output_dir / "eval").mkdir(exist_ok=True) + if "bbox" in coco_evaluator.coco_eval: + filenames = ["latest.pth"] + if epoch % 50 == 0: + filenames.append(f"{epoch:03}.pth") + for name in filenames: + torch.save( + coco_evaluator.coco_eval["bbox"].eval, + output_dir / "eval" / name, + ) + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print("Training time {}".format(total_time_str)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + "Deformable DETR training and evaluation script", parents=[get_args_parser()] + ) + args = parser.parse_args() + if args.output_dir: + Path(args.output_dir).mkdir(parents=True, exist_ok=True) + main(args) diff --git a/mmcv_custom/__init__.py b/mmcv_custom/__init__.py new file mode 100644 index 0000000..6812a6a --- /dev/null +++ b/mmcv_custom/__init__.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- + +from .checkpoint import load_checkpoint + +__all__ = ["load_checkpoint"] + diff --git a/mmcv_custom/checkpoint.py b/mmcv_custom/checkpoint.py new file mode 100644 index 0000000..326713b --- /dev/null +++ b/mmcv_custom/checkpoint.py @@ -0,0 +1,508 @@ +# Copyright (c) Open-MMLab. All rights reserved. +import io +import os +import os.path as osp +import pkgutil +import time +import warnings +from collections import OrderedDict +from importlib import import_module +from tempfile import TemporaryDirectory + +import torch +import torchvision +from torch.optim import Optimizer +from torch.utils import model_zoo +from torch.nn import functional as F + +import mmcv +from mmcv.fileio import FileClient +from mmcv.fileio import load as load_file +from mmcv.parallel import is_module_wrapper +from mmcv.utils import mkdir_or_exist +from mmcv.runner import get_dist_info + +ENV_MMCV_HOME = "MMCV_HOME" +ENV_XDG_CACHE_HOME = "XDG_CACHE_HOME" +DEFAULT_CACHE_DIR = "~/.cache" + + +def _get_mmcv_home(): + mmcv_home = os.path.expanduser( + os.getenv( + ENV_MMCV_HOME, + os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), "mmcv"), + ) + ) + + mkdir_or_exist(mmcv_home) + return mmcv_home + + +def load_state_dict(module, state_dict, strict=False, logger=None): + """Load state_dict to a module. + + This method is modified from :meth:`torch.nn.Module.load_state_dict`. + Default value for ``strict`` is set to ``False`` and the message for + param mismatch will be shown even if strict is False. + + Args: + module (Module): Module that receives the state_dict. + state_dict (OrderedDict): Weights. + strict (bool): whether to strictly enforce that the keys + in :attr:`state_dict` match the keys returned by this module's + :meth:`~torch.nn.Module.state_dict` function. Default: ``False``. + logger (:obj:`logging.Logger`, optional): Logger to log the error + message. If not specified, print function will be used. + """ + unexpected_keys = [] + all_missing_keys = [] + err_msg = [] + + metadata = getattr(state_dict, "_metadata", None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + # use _load_from_state_dict to enable checkpoint version control + def load(module, prefix=""): + # recursively check parallel module in case that the model has a + # complicated structure, e.g., nn.Module(nn.Module(DDP)) + if is_module_wrapper(module): + module = module.module + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + module._load_from_state_dict( + state_dict, + prefix, + local_metadata, + True, + all_missing_keys, + unexpected_keys, + err_msg, + ) + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + ".") + + load(module) + load = None # break load->load reference cycle + + # ignore "num_batches_tracked" of BN layers + missing_keys = [key for key in all_missing_keys if "num_batches_tracked" not in key] + + if unexpected_keys: + err_msg.append( + "unexpected key in source " f'state_dict: {", ".join(unexpected_keys)}\n' + ) + if missing_keys: + err_msg.append( + f'missing keys in source state_dict: {", ".join(missing_keys)}\n' + ) + + rank, _ = get_dist_info() + if len(err_msg) > 0 and rank == 0: + err_msg.insert(0, "The model and loaded state dict do not match exactly\n") + err_msg = "\n".join(err_msg) + if strict: + raise RuntimeError(err_msg) + elif logger is not None: + logger.warning(err_msg) + else: + print(err_msg) + + +def load_url_dist(url, model_dir=None): + """In distributed setting, this function only download checkpoint at local + rank 0.""" + rank, world_size = get_dist_info() + rank = int(os.environ.get("LOCAL_RANK", rank)) + if rank == 0: + checkpoint = model_zoo.load_url(url, model_dir=model_dir) + if world_size > 1: + torch.distributed.barrier() + if rank > 0: + checkpoint = model_zoo.load_url(url, model_dir=model_dir) + return checkpoint + + +def load_pavimodel_dist(model_path, map_location=None): + """In distributed setting, this function only download checkpoint at local + rank 0.""" + try: + from pavi import modelcloud + except ImportError: + raise ImportError("Please install pavi to load checkpoint from modelcloud.") + rank, world_size = get_dist_info() + rank = int(os.environ.get("LOCAL_RANK", rank)) + if rank == 0: + model = modelcloud.get(model_path) + with TemporaryDirectory() as tmp_dir: + downloaded_file = osp.join(tmp_dir, model.name) + model.download(downloaded_file) + checkpoint = torch.load(downloaded_file, map_location=map_location) + if world_size > 1: + torch.distributed.barrier() + if rank > 0: + model = modelcloud.get(model_path) + with TemporaryDirectory() as tmp_dir: + downloaded_file = osp.join(tmp_dir, model.name) + model.download(downloaded_file) + checkpoint = torch.load(downloaded_file, map_location=map_location) + return checkpoint + + +def load_fileclient_dist(filename, backend, map_location): + """In distributed setting, this function only download checkpoint at local + rank 0.""" + rank, world_size = get_dist_info() + rank = int(os.environ.get("LOCAL_RANK", rank)) + allowed_backends = ["ceph"] + if backend not in allowed_backends: + raise ValueError(f"Load from Backend {backend} is not supported.") + if rank == 0: + fileclient = FileClient(backend=backend) + buffer = io.BytesIO(fileclient.get(filename)) + checkpoint = torch.load(buffer, map_location=map_location) + if world_size > 1: + torch.distributed.barrier() + if rank > 0: + fileclient = FileClient(backend=backend) + buffer = io.BytesIO(fileclient.get(filename)) + checkpoint = torch.load(buffer, map_location=map_location) + return checkpoint + + +def get_torchvision_models(): + model_urls = dict() + for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__): + if ispkg: + continue + _zoo = import_module(f"torchvision.models.{name}") + if hasattr(_zoo, "model_urls"): + _urls = getattr(_zoo, "model_urls") + model_urls.update(_urls) + return model_urls + + +def get_external_models(): + mmcv_home = _get_mmcv_home() + default_json_path = osp.join(mmcv.__path__[0], "model_zoo/open_mmlab.json") + default_urls = load_file(default_json_path) + assert isinstance(default_urls, dict) + external_json_path = osp.join(mmcv_home, "open_mmlab.json") + if osp.exists(external_json_path): + external_urls = load_file(external_json_path) + assert isinstance(external_urls, dict) + default_urls.update(external_urls) + + return default_urls + + +def get_mmcls_models(): + mmcls_json_path = osp.join(mmcv.__path__[0], "model_zoo/mmcls.json") + mmcls_urls = load_file(mmcls_json_path) + + return mmcls_urls + + +def get_deprecated_model_names(): + deprecate_json_path = osp.join(mmcv.__path__[0], "model_zoo/deprecated.json") + deprecate_urls = load_file(deprecate_json_path) + assert isinstance(deprecate_urls, dict) + + return deprecate_urls + + +def _process_mmcls_checkpoint(checkpoint): + state_dict = checkpoint["state_dict"] + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + if k.startswith("backbone."): + new_state_dict[k[9:]] = v + new_checkpoint = dict(state_dict=new_state_dict) + + return new_checkpoint + + +def _load_checkpoint(filename, map_location=None): + """Load checkpoint from somewhere (modelzoo, file, url). + + Args: + filename (str): Accept local filepath, URL, ``torchvision://xxx``, + ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for + details. + map_location (str | None): Same as :func:`torch.load`. Default: None. + + Returns: + dict | OrderedDict: The loaded checkpoint. It can be either an + OrderedDict storing model weights or a dict containing other + information, which depends on the checkpoint. + """ + if filename.startswith("modelzoo://"): + warnings.warn( + 'The URL scheme of "modelzoo://" is deprecated, please ' + 'use "torchvision://" instead' + ) + model_urls = get_torchvision_models() + model_name = filename[11:] + checkpoint = load_url_dist(model_urls[model_name]) + elif filename.startswith("torchvision://"): + model_urls = get_torchvision_models() + model_name = filename[14:] + checkpoint = load_url_dist(model_urls[model_name]) + elif filename.startswith("open-mmlab://"): + model_urls = get_external_models() + model_name = filename[13:] + deprecated_urls = get_deprecated_model_names() + if model_name in deprecated_urls: + warnings.warn( + f"open-mmlab://{model_name} is deprecated in favor " + f"of open-mmlab://{deprecated_urls[model_name]}" + ) + model_name = deprecated_urls[model_name] + model_url = model_urls[model_name] + # check if is url + if model_url.startswith(("http://", "https://")): + checkpoint = load_url_dist(model_url) + else: + filename = osp.join(_get_mmcv_home(), model_url) + if not osp.isfile(filename): + raise IOError(f"{filename} is not a checkpoint file") + checkpoint = torch.load(filename, map_location=map_location) + elif filename.startswith("mmcls://"): + model_urls = get_mmcls_models() + model_name = filename[8:] + checkpoint = load_url_dist(model_urls[model_name]) + checkpoint = _process_mmcls_checkpoint(checkpoint) + elif filename.startswith(("http://", "https://")): + checkpoint = load_url_dist(filename) + elif filename.startswith("pavi://"): + model_path = filename[7:] + checkpoint = load_pavimodel_dist(model_path, map_location=map_location) + elif filename.startswith("s3://"): + checkpoint = load_fileclient_dist( + filename, backend="ceph", map_location=map_location + ) + else: + if not osp.isfile(filename): + raise IOError(f"{filename} is not a checkpoint file") + checkpoint = torch.load(filename, map_location=map_location) + return checkpoint + + +def load_checkpoint(model, filename, map_location="cpu", strict=False, logger=None): + """Load checkpoint from a file or URI. + + Args: + model (Module): Module to load checkpoint. + filename (str): Accept local filepath, URL, ``torchvision://xxx``, + ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for + details. + map_location (str): Same as :func:`torch.load`. + strict (bool): Whether to allow different params for the model and + checkpoint. + logger (:mod:`logging.Logger` or None): The logger for error message. + + Returns: + dict or OrderedDict: The loaded checkpoint. + """ + checkpoint = _load_checkpoint(filename, map_location) + # OrderedDict is a subclass of dict + if not isinstance(checkpoint, dict): + raise RuntimeError(f"No state_dict found in checkpoint file {filename}") + # get state_dict from checkpoint + if "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + elif "model" in checkpoint: + state_dict = checkpoint["model"] + else: + state_dict = checkpoint + # strip prefix of state_dict + if list(state_dict.keys())[0].startswith("module."): + state_dict = {k[7:]: v for k, v in state_dict.items()} + + # for MoBY, load model of online branch + if sorted(list(state_dict.keys()))[0].startswith("encoder"): + state_dict = { + k.replace("encoder.", ""): v + for k, v in state_dict.items() + if k.startswith("encoder.") + } + + # reshape absolute position embedding + if state_dict.get("absolute_pos_embed") is not None: + absolute_pos_embed = state_dict["absolute_pos_embed"] + N1, L, C1 = absolute_pos_embed.size() + N2, C2, H, W = model.absolute_pos_embed.size() + if N1 != N2 or C1 != C2 or L != H * W: + logger.warning("Error in loading absolute_pos_embed, pass") + else: + state_dict["absolute_pos_embed"] = absolute_pos_embed.view( + N2, H, W, C2 + ).permute(0, 3, 1, 2) + + # interpolate position bias table if needed + relative_position_bias_table_keys = [ + k for k in state_dict.keys() if "relative_position_bias_table" in k + ] + for table_key in relative_position_bias_table_keys: + table_pretrained = state_dict[table_key] + table_current = model.state_dict()[table_key] + L1, nH1 = table_pretrained.size() + L2, nH2 = table_current.size() + if nH1 != nH2: + logger.warning(f"Error in loading {table_key}, pass") + else: + if L1 != L2: + S1 = int(L1 ** 0.5) + S2 = int(L2 ** 0.5) + table_pretrained_resized = F.interpolate( + table_pretrained.permute(1, 0).view(1, nH1, S1, S1), + size=(S2, S2), + mode="bicubic", + ) + state_dict[table_key] = table_pretrained_resized.view(nH2, L2).permute( + 1, 0 + ) + + # load state_dict + load_state_dict(model, state_dict, strict, logger) + return checkpoint + + +def weights_to_cpu(state_dict): + """Copy a model state_dict to cpu. + + Args: + state_dict (OrderedDict): Model weights on GPU. + + Returns: + OrderedDict: Model weights on GPU. + """ + state_dict_cpu = OrderedDict() + for key, val in state_dict.items(): + state_dict_cpu[key] = val.cpu() + return state_dict_cpu + + +def _save_to_state_dict(module, destination, prefix, keep_vars): + """Saves module state to `destination` dictionary. + + This method is modified from :meth:`torch.nn.Module._save_to_state_dict`. + + Args: + module (nn.Module): The module to generate state_dict. + destination (dict): A dict where state will be stored. + prefix (str): The prefix for parameters and buffers used in this + module. + """ + for name, param in module._parameters.items(): + if param is not None: + destination[prefix + name] = param if keep_vars else param.detach() + for name, buf in module._buffers.items(): + # remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d + if buf is not None: + destination[prefix + name] = buf if keep_vars else buf.detach() + + +def get_state_dict(module, destination=None, prefix="", keep_vars=False): + """Returns a dictionary containing a whole state of the module. + + Both parameters and persistent buffers (e.g. running averages) are + included. Keys are corresponding parameter and buffer names. + + This method is modified from :meth:`torch.nn.Module.state_dict` to + recursively check parallel module in case that the model has a complicated + structure, e.g., nn.Module(nn.Module(DDP)). + + Args: + module (nn.Module): The module to generate state_dict. + destination (OrderedDict): Returned dict for the state of the + module. + prefix (str): Prefix of the key. + keep_vars (bool): Whether to keep the variable property of the + parameters. Default: False. + + Returns: + dict: A dictionary containing a whole state of the module. + """ + # recursively check parallel module in case that the model has a + # complicated structure, e.g., nn.Module(nn.Module(DDP)) + if is_module_wrapper(module): + module = module.module + + # below is the same as torch.nn.Module.state_dict() + if destination is None: + destination = OrderedDict() + destination._metadata = OrderedDict() + destination._metadata[prefix[:-1]] = local_metadata = dict(version=module._version) + _save_to_state_dict(module, destination, prefix, keep_vars) + for name, child in module._modules.items(): + if child is not None: + get_state_dict(child, destination, prefix + name + ".", keep_vars=keep_vars) + for hook in module._state_dict_hooks.values(): + hook_result = hook(module, destination, prefix, local_metadata) + if hook_result is not None: + destination = hook_result + return destination + + +def save_checkpoint(model, filename, optimizer=None, meta=None): + """Save checkpoint to file. + + The checkpoint will have 3 fields: ``meta``, ``state_dict`` and + ``optimizer``. By default ``meta`` will contain version and time info. + + Args: + model (Module): Module whose params are to be saved. + filename (str): Checkpoint filename. + optimizer (:obj:`Optimizer`, optional): Optimizer to be saved. + meta (dict, optional): Metadata to be saved in checkpoint. + """ + if meta is None: + meta = {} + elif not isinstance(meta, dict): + raise TypeError(f"meta must be a dict or None, but got {type(meta)}") + meta.update(mmcv_version=mmcv.__version__, time=time.asctime()) + + if is_module_wrapper(model): + model = model.module + + if hasattr(model, "CLASSES") and model.CLASSES is not None: + # save class name to the meta + meta.update(CLASSES=model.CLASSES) + + checkpoint = {"meta": meta, "state_dict": weights_to_cpu(get_state_dict(model))} + # save optimizer state dict in the checkpoint + if isinstance(optimizer, Optimizer): + checkpoint["optimizer"] = optimizer.state_dict() + elif isinstance(optimizer, dict): + checkpoint["optimizer"] = {} + for name, optim in optimizer.items(): + checkpoint["optimizer"][name] = optim.state_dict() + + if filename.startswith("pavi://"): + try: + from pavi import modelcloud + from pavi.exception import NodeNotFoundError + except ImportError: + raise ImportError("Please install pavi to load checkpoint from modelcloud.") + model_path = filename[7:] + root = modelcloud.Folder() + model_dir, model_name = osp.split(model_path) + try: + model = modelcloud.get(model_dir) + except NodeNotFoundError: + model = root.create_training_model(model_dir) + with TemporaryDirectory() as tmp_dir: + checkpoint_file = osp.join(tmp_dir, model_name) + with open(checkpoint_file, "wb") as f: + torch.save(checkpoint, f) + f.flush() + model.create_file(checkpoint_file, name=model_name) + else: + mmcv.mkdir_or_exist(osp.dirname(filename)) + # immediately flush buffer + with open(filename, "wb") as f: + torch.save(checkpoint, f) + f.flush() diff --git a/mmcv_custom/runner/__init__.py b/mmcv_custom/runner/__init__.py new file mode 100644 index 0000000..837dc4c --- /dev/null +++ b/mmcv_custom/runner/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Open-MMLab. All rights reserved. +from .checkpoint import save_checkpoint +from .epoch_based_runner import EpochBasedRunnerAmp + + +__all__ = ["EpochBasedRunnerAmp", "save_checkpoint"] + diff --git a/mmcv_custom/runner/checkpoint.py b/mmcv_custom/runner/checkpoint.py new file mode 100644 index 0000000..ae02b4a --- /dev/null +++ b/mmcv_custom/runner/checkpoint.py @@ -0,0 +1,81 @@ +# Copyright (c) Open-MMLab. All rights reserved. +import os.path as osp +import time +from tempfile import TemporaryDirectory + +import torch +from torch.optim import Optimizer + +import mmcv +from mmcv.parallel import is_module_wrapper +from mmcv.runner.checkpoint import weights_to_cpu, get_state_dict + +try: + import apex +except: + print("apex is not installed") + + +def save_checkpoint(model, filename, optimizer=None, meta=None): + """Save checkpoint to file. + + The checkpoint will have 4 fields: ``meta``, ``state_dict`` and + ``optimizer``, ``amp``. By default ``meta`` will contain version + and time info. + + Args: + model (Module): Module whose params are to be saved. + filename (str): Checkpoint filename. + optimizer (:obj:`Optimizer`, optional): Optimizer to be saved. + meta (dict, optional): Metadata to be saved in checkpoint. + """ + if meta is None: + meta = {} + elif not isinstance(meta, dict): + raise TypeError(f"meta must be a dict or None, but got {type(meta)}") + meta.update(mmcv_version=mmcv.__version__, time=time.asctime()) + + if is_module_wrapper(model): + model = model.module + + if hasattr(model, "CLASSES") and model.CLASSES is not None: + # save class name to the meta + meta.update(CLASSES=model.CLASSES) + + checkpoint = {"meta": meta, "state_dict": weights_to_cpu(get_state_dict(model))} + # save optimizer state dict in the checkpoint + if isinstance(optimizer, Optimizer): + checkpoint["optimizer"] = optimizer.state_dict() + elif isinstance(optimizer, dict): + checkpoint["optimizer"] = {} + for name, optim in optimizer.items(): + checkpoint["optimizer"][name] = optim.state_dict() + + # save amp state dict in the checkpoint + checkpoint["amp"] = apex.amp.state_dict() + + if filename.startswith("pavi://"): + try: + from pavi import modelcloud + from pavi.exception import NodeNotFoundError + except ImportError: + raise ImportError("Please install pavi to load checkpoint from modelcloud.") + model_path = filename[7:] + root = modelcloud.Folder() + model_dir, model_name = osp.split(model_path) + try: + model = modelcloud.get(model_dir) + except NodeNotFoundError: + model = root.create_training_model(model_dir) + with TemporaryDirectory() as tmp_dir: + checkpoint_file = osp.join(tmp_dir, model_name) + with open(checkpoint_file, "wb") as f: + torch.save(checkpoint, f) + f.flush() + model.create_file(checkpoint_file, name=model_name) + else: + mmcv.mkdir_or_exist(osp.dirname(filename)) + # immediately flush buffer + with open(filename, "wb") as f: + torch.save(checkpoint, f) + f.flush() diff --git a/mmcv_custom/runner/epoch_based_runner.py b/mmcv_custom/runner/epoch_based_runner.py new file mode 100644 index 0000000..9d504ad --- /dev/null +++ b/mmcv_custom/runner/epoch_based_runner.py @@ -0,0 +1,103 @@ +# Copyright (c) Open-MMLab. All rights reserved. +import os.path as osp +import platform +import shutil + +import torch +from torch.optim import Optimizer + +import mmcv +from mmcv.runner import RUNNERS, EpochBasedRunner +from .checkpoint import save_checkpoint + +try: + import apex +except: + print("apex is not installed") + + +@RUNNERS.register_module() +class EpochBasedRunnerAmp(EpochBasedRunner): + """Epoch-based Runner with AMP support. + + This runner train models epoch by epoch. + """ + + def save_checkpoint( + self, + out_dir, + filename_tmpl="epoch_{}.pth", + save_optimizer=True, + meta=None, + create_symlink=True, + ): + """Save the checkpoint. + + Args: + out_dir (str): The directory that checkpoints are saved. + filename_tmpl (str, optional): The checkpoint filename template, + which contains a placeholder for the epoch number. + Defaults to 'epoch_{}.pth'. + save_optimizer (bool, optional): Whether to save the optimizer to + the checkpoint. Defaults to True. + meta (dict, optional): The meta information to be saved in the + checkpoint. Defaults to None. + create_symlink (bool, optional): Whether to create a symlink + "latest.pth" to point to the latest checkpoint. + Defaults to True. + """ + if meta is None: + meta = dict(epoch=self.epoch + 1, iter=self.iter) + elif isinstance(meta, dict): + meta.update(epoch=self.epoch + 1, iter=self.iter) + else: + raise TypeError(f"meta should be a dict or None, but got {type(meta)}") + if self.meta is not None: + meta.update(self.meta) + + filename = filename_tmpl.format(self.epoch + 1) + filepath = osp.join(out_dir, filename) + optimizer = self.optimizer if save_optimizer else None + save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta) + # in some environments, `os.symlink` is not supported, you may need to + # set `create_symlink` to False + if create_symlink: + dst_file = osp.join(out_dir, "latest.pth") + if platform.system() != "Windows": + mmcv.symlink(filename, dst_file) + else: + shutil.copy(filepath, dst_file) + + def resume(self, checkpoint, resume_optimizer=True, map_location="default"): + if map_location == "default": + if torch.cuda.is_available(): + device_id = torch.cuda.current_device() + checkpoint = self.load_checkpoint( + checkpoint, + map_location=lambda storage, loc: storage.cuda(device_id), + ) + else: + checkpoint = self.load_checkpoint(checkpoint) + else: + checkpoint = self.load_checkpoint(checkpoint, map_location=map_location) + + self._epoch = checkpoint["meta"]["epoch"] + self._iter = checkpoint["meta"]["iter"] + if "optimizer" in checkpoint and resume_optimizer: + if isinstance(self.optimizer, Optimizer): + self.optimizer.load_state_dict(checkpoint["optimizer"]) + elif isinstance(self.optimizer, dict): + for k in self.optimizer.keys(): + self.optimizer[k].load_state_dict(checkpoint["optimizer"][k]) + else: + raise TypeError( + "Optimizer should be dict or torch.optim.Optimizer " + f"but got {type(self.optimizer)}" + ) + + if "amp" in checkpoint: + apex.amp.load_state_dict(checkpoint["amp"]) + self.logger.info("load amp state dict") + + self.logger.info("resumed epoch %d, iter %d", self.epoch, self.iter) + diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..9a59c33 --- /dev/null +++ b/models/__init__.py @@ -0,0 +1,15 @@ +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +from .deformable_detr import build + + +def build_model(args): + return build(args) + diff --git a/models/backbone.py b/models/backbone.py new file mode 100644 index 0000000..98aa9e7 --- /dev/null +++ b/models/backbone.py @@ -0,0 +1,269 @@ +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +""" +Backbone modules. +""" +from collections import OrderedDict + +import torch +import torch.nn.functional as F +import torchvision +from torch import nn +from torchvision.models._utils import IntermediateLayerGetter +from typing import Dict, List + +from util.misc import NestedTensor, is_main_process + +from .position_encoding import build_position_encoding +from .swin_transformer import SwinTransformer + + +class FrozenBatchNorm2d(torch.nn.Module): + """ + BatchNorm2d where the batch statistics and the affine parameters are fixed. + + Copy-paste from torchvision.misc.ops with added eps before rqsrt, + without which any other models than torchvision.models.resnet[18,34,50,101] + produce nans. + """ + + def __init__(self, n, eps=1e-5): + super(FrozenBatchNorm2d, self).__init__() + self.register_buffer("weight", torch.ones(n)) + self.register_buffer("bias", torch.zeros(n)) + self.register_buffer("running_mean", torch.zeros(n)) + self.register_buffer("running_var", torch.ones(n)) + self.eps = eps + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + num_batches_tracked_key = prefix + "num_batches_tracked" + if num_batches_tracked_key in state_dict: + del state_dict[num_batches_tracked_key] + + super(FrozenBatchNorm2d, self)._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + + def forward(self, x): + # move reshapes to the beginning + # to make it fuser-friendly + w = self.weight.reshape(1, -1, 1, 1) + b = self.bias.reshape(1, -1, 1, 1) + rv = self.running_var.reshape(1, -1, 1, 1) + rm = self.running_mean.reshape(1, -1, 1, 1) + eps = self.eps + scale = w * (rv + eps).rsqrt() + bias = b - rm * scale + return x * scale + bias + + +class BackboneBase(nn.Module): + def __init__( + self, backbone: nn.Module, train_backbone: bool, return_interm_layers: bool + ): + super().__init__() + for name, parameter in backbone.named_parameters(): + if ( + not train_backbone + or "layer2" not in name + and "layer3" not in name + and "layer4" not in name + ): + parameter.requires_grad_(False) + if return_interm_layers: + # return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"} + return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"} + self.strides = [8, 16, 32] + self.num_channels = [512, 1024, 2048] + else: + return_layers = {"layer4": "0"} + self.strides = [32] + self.num_channels = [2048] + self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) + + def forward(self, tensor_list: NestedTensor): + xs = self.body(tensor_list.tensors) + out: Dict[str, NestedTensor] = {} + for name, x in xs.items(): + m = tensor_list.mask + assert m is not None + mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] + out[name] = NestedTensor(x, mask) + return out + + +class Backbone(BackboneBase): + """ResNet backbone with frozen BatchNorm.""" + + def __init__( + self, + name: str, + train_backbone: bool, + return_interm_layers: bool, + dilation: bool, + ): + norm_layer = FrozenBatchNorm2d + backbone = getattr(torchvision.models, name)( + replace_stride_with_dilation=[False, False, dilation], + pretrained=is_main_process(), + norm_layer=norm_layer, + ) + assert name not in ("resnet18", "resnet34"), "number of channels are hard coded" + super().__init__(backbone, train_backbone, return_interm_layers) + if dilation: + self.strides[-1] = self.strides[-1] // 2 + + +class TransformerBackbone(nn.Module): + def __init__( + self, backbone: str, train_backbone: bool, return_interm_layers: bool, args + ): + super().__init__() + out_indices = (1, 2, 3) + if backbone == "swin_tiny": + backbone = SwinTransformer( + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + ape=False, + drop_path_rate=args.drop_path_rate, + patch_norm=True, + use_checkpoint=True, + out_indices=out_indices, + ) + embed_dim = 96 + backbone.init_weights(args.pretrained_backbone_path) + elif backbone == "swin_small": + backbone = SwinTransformer( + embed_dim=96, + depths=[2, 2, 18, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + ape=False, + drop_path_rate=args.drop_path_rate, + patch_norm=True, + use_checkpoint=True, + out_indices=out_indices, + ) + embed_dim = 96 + backbone.init_weights(args.pretrained_backbone_path) + elif backbone == "swin_large": + backbone = SwinTransformer( + embed_dim=192, + depths=[2, 2, 18, 2], + num_heads=[6, 12, 24, 48], + window_size=7, + ape=False, + drop_path_rate=args.drop_path_rate, + patch_norm=True, + use_checkpoint=True, + out_indices=out_indices, + ) + embed_dim = 192 + backbone.init_weights(args.pretrained_backbone_path) + elif backbone == "swin_large_window12": + backbone = SwinTransformer( + pretrain_img_size=384, + embed_dim=192, + depths=[2, 2, 18, 2], + num_heads=[6, 12, 24, 48], + window_size=12, + ape=False, + drop_path_rate=args.drop_path_rate, + patch_norm=True, + use_checkpoint=True, + out_indices=out_indices, + ) + embed_dim = 192 + backbone.init_weights(args.pretrained_backbone_path) + else: + raise NotImplementedError + + for name, parameter in backbone.named_parameters(): + # TODO: freeze some layers? + if not train_backbone: + parameter.requires_grad_(False) + + if return_interm_layers: + + self.strides = [8, 16, 32] + self.num_channels = [ + embed_dim * 2, + embed_dim * 4, + embed_dim * 8, + ] + else: + self.strides = [32] + self.num_channels = [embed_dim * 8] + + self.body = backbone + + def forward(self, tensor_list: NestedTensor): + xs = self.body(tensor_list.tensors) + + out: Dict[str, NestedTensor] = {} + for name, x in xs.items(): + m = tensor_list.mask + assert m is not None + mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] + out[name] = NestedTensor(x, mask) + return out + + +class Joiner(nn.Sequential): + def __init__(self, backbone, position_embedding): + super().__init__(backbone, position_embedding) + self.strides = backbone.strides + self.num_channels = backbone.num_channels + + def forward(self, tensor_list: NestedTensor): + xs = self[0](tensor_list) + out: List[NestedTensor] = [] + pos = [] + for name, x in sorted(xs.items()): + out.append(x) + + # position encoding + for x in out: + pos.append(self[1](x).to(x.tensors.dtype)) + + return out, pos + + +def build_backbone(args): + position_embedding = build_position_encoding(args) + train_backbone = args.lr_backbone > 0 + return_interm_layers = args.masks or (args.num_feature_levels > 1) + if "resnet" in args.backbone: + backbone = Backbone( + args.backbone, train_backbone, return_interm_layers, args.dilation, + ) + else: + backbone = TransformerBackbone( + args.backbone, train_backbone, return_interm_layers, args + ) + model = Joiner(backbone, position_embedding) + return model diff --git a/models/deformable_detr.py b/models/deformable_detr.py new file mode 100644 index 0000000..974d57f --- /dev/null +++ b/models/deformable_detr.py @@ -0,0 +1,656 @@ +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +""" +Deformable DETR model and criterion classes. +""" +import torch +import torch.nn.functional as F +from torch import nn +import math + +from util import box_ops +from util.misc import ( + NestedTensor, + nested_tensor_from_tensor_list, + accuracy, + get_world_size, + interpolate, + is_dist_avail_and_initialized, + inverse_sigmoid, +) + +from .backbone import build_backbone +from .matcher import build_matcher +from .segmentation import ( + DETRsegm, + PostProcessPanoptic, + PostProcessSegm, + dice_loss, + sigmoid_focal_loss, +) +from .deformable_transformer import build_deforamble_transformer +import copy + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +class DeformableDETR(nn.Module): + """ This is the Deformable DETR module that performs object detection """ + + def __init__( + self, + backbone, + transformer, + num_classes, + num_feature_levels, + aux_loss=True, + with_box_refine=False, + two_stage=False, + num_queries_one2one=300, + num_queries_one2many=0, + mixed_selection=False, + ): + """ Initializes the model. + Parameters: + backbone: torch module of the backbone to be used. See backbone.py + transformer: torch module of the transformer architecture. See transformer.py + num_classes: number of object classes + aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. + with_box_refine: iterative bounding box refinement + two_stage: two-stage Deformable DETR + num_queries_one2one: number of object queries for one-to-one matching part + num_queries_one2many: number of object queries for one-to-many matching part + mixed_selection: a trick for Deformable DETR two stage + + """ + super().__init__() + num_queries = num_queries_one2one + num_queries_one2many + self.num_queries = num_queries + self.transformer = transformer + hidden_dim = transformer.d_model + self.class_embed = nn.Linear(hidden_dim, num_classes) + self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3) + self.num_feature_levels = num_feature_levels + if not two_stage: + self.query_embed = nn.Embedding(num_queries, hidden_dim * 2) + elif mixed_selection: + self.query_embed = nn.Embedding(num_queries, hidden_dim) + if num_feature_levels > 1: + num_backbone_outs = len(backbone.strides) + input_proj_list = [] + for _ in range(num_backbone_outs): + in_channels = backbone.num_channels[_] + input_proj_list.append( + nn.Sequential( + nn.Conv2d(in_channels, hidden_dim, kernel_size=1), + nn.GroupNorm(32, hidden_dim), + ) + ) + for _ in range(num_feature_levels - num_backbone_outs): + input_proj_list.append( + nn.Sequential( + nn.Conv2d( + in_channels, hidden_dim, kernel_size=3, stride=2, padding=1 + ), + nn.GroupNorm(32, hidden_dim), + ) + ) + in_channels = hidden_dim + self.input_proj = nn.ModuleList(input_proj_list) + else: + self.input_proj = nn.ModuleList( + [ + nn.Sequential( + nn.Conv2d(backbone.num_channels[0], hidden_dim, kernel_size=1), + nn.GroupNorm(32, hidden_dim), + ) + ] + ) + self.backbone = backbone + self.aux_loss = aux_loss + self.with_box_refine = with_box_refine + self.two_stage = two_stage + + prior_prob = 0.01 + bias_value = -math.log((1 - prior_prob) / prior_prob) + self.class_embed.bias.data = torch.ones(num_classes) * bias_value + nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0) + nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0) + for proj in self.input_proj: + nn.init.xavier_uniform_(proj[0].weight, gain=1) + nn.init.constant_(proj[0].bias, 0) + + # if two-stage, the last class_embed and bbox_embed is for region proposal generation + num_pred = ( + (transformer.decoder.num_layers + 1) + if two_stage + else transformer.decoder.num_layers + ) + if with_box_refine: + self.class_embed = _get_clones(self.class_embed, num_pred) + self.bbox_embed = _get_clones(self.bbox_embed, num_pred) + nn.init.constant_(self.bbox_embed[0].layers[-1].bias.data[2:], -2.0) + # hack implementation for iterative bounding box refinement + self.transformer.decoder.bbox_embed = self.bbox_embed + else: + nn.init.constant_(self.bbox_embed.layers[-1].bias.data[2:], -2.0) + self.class_embed = nn.ModuleList( + [self.class_embed for _ in range(num_pred)] + ) + self.bbox_embed = nn.ModuleList([self.bbox_embed for _ in range(num_pred)]) + self.transformer.decoder.bbox_embed = None + if two_stage: + # hack implementation for two-stage + self.transformer.decoder.class_embed = self.class_embed + for box_embed in self.bbox_embed: + nn.init.constant_(box_embed.layers[-1].bias.data[2:], 0.0) + self.num_queries_one2one = num_queries_one2one + self.mixed_selection = mixed_selection + + def forward(self, samples: NestedTensor): + """ The forward expects a NestedTensor, which consists of: + - samples.tensor: batched images, of shape [batch_size x 3 x H x W] + - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels + + It returns a dict with the following elements: + - "pred_logits": the classification logits (including no-object) for all queries. + Shape= [batch_size x num_queries x (num_classes + 1)] + - "pred_boxes": The normalized boxes coordinates for all queries, represented as + (center_x, center_y, height, width). These values are normalized in [0, 1], + relative to the size of each individual image (disregarding possible padding). + See PostProcess for information on how to retrieve the unnormalized bounding box. + - "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of + dictionnaries containing the two above keys for each decoder layer. + """ + if not isinstance(samples, NestedTensor): + samples = nested_tensor_from_tensor_list(samples) + features, pos = self.backbone(samples) + + srcs = [] + masks = [] + for l, feat in enumerate(features): + src, mask = feat.decompose() + srcs.append(self.input_proj[l](src)) + masks.append(mask) + assert mask is not None + if self.num_feature_levels > len(srcs): + _len_srcs = len(srcs) + for l in range(_len_srcs, self.num_feature_levels): + if l == _len_srcs: + src = self.input_proj[l](features[-1].tensors) + else: + src = self.input_proj[l](srcs[-1]) + m = samples.mask + mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to( + torch.bool + )[0] + pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype) + srcs.append(src) + masks.append(mask) + pos.append(pos_l) + + query_embeds = None + if not self.two_stage or self.mixed_selection: + query_embeds = self.query_embed.weight[0 : self.num_queries, :] + + # make attn mask + """ attention mask to prevent information leakage + """ + self_attn_mask = ( + torch.zeros([self.num_queries, self.num_queries,]).bool().to(src.device) + ) + self_attn_mask[self.num_queries_one2one :, 0 : self.num_queries_one2one,] = True + self_attn_mask[0 : self.num_queries_one2one, self.num_queries_one2one :,] = True + + ( + hs, + init_reference, + inter_references, + enc_outputs_class, + enc_outputs_coord_unact, + ) = self.transformer(srcs, masks, pos, query_embeds, self_attn_mask) + + outputs_classes_one2one = [] + outputs_coords_one2one = [] + outputs_classes_one2many = [] + outputs_coords_one2many = [] + for lvl in range(hs.shape[0]): + if lvl == 0: + reference = init_reference + else: + reference = inter_references[lvl - 1] + reference = inverse_sigmoid(reference) + outputs_class = self.class_embed[lvl](hs[lvl]) + tmp = self.bbox_embed[lvl](hs[lvl]) + if reference.shape[-1] == 4: + tmp += reference + else: + assert reference.shape[-1] == 2 + tmp[..., :2] += reference + outputs_coord = tmp.sigmoid() + + outputs_classes_one2one.append( + outputs_class[:, 0 : self.num_queries_one2one] + ) + outputs_classes_one2many.append( + outputs_class[:, self.num_queries_one2one :] + ) + outputs_coords_one2one.append( + outputs_coord[:, 0 : self.num_queries_one2one] + ) + outputs_coords_one2many.append(outputs_coord[:, self.num_queries_one2one :]) + outputs_classes_one2one = torch.stack(outputs_classes_one2one) + outputs_coords_one2one = torch.stack(outputs_coords_one2one) + outputs_classes_one2many = torch.stack(outputs_classes_one2many) + outputs_coords_one2many = torch.stack(outputs_coords_one2many) + + out = { + "pred_logits": outputs_classes_one2one[-1], + "pred_boxes": outputs_coords_one2one[-1], + "pred_logits_one2many": outputs_classes_one2many[-1], + "pred_boxes_one2many": outputs_coords_one2many[-1], + } + if self.aux_loss: + out["aux_outputs"] = self._set_aux_loss( + outputs_classes_one2one, outputs_coords_one2one + ) + out["aux_outputs_one2many"] = self._set_aux_loss( + outputs_classes_one2many, outputs_coords_one2many + ) + + if self.two_stage: + enc_outputs_coord = enc_outputs_coord_unact.sigmoid() + out["enc_outputs"] = { + "pred_logits": enc_outputs_class, + "pred_boxes": enc_outputs_coord, + } + return out + + @torch.jit.unused + def _set_aux_loss(self, outputs_class, outputs_coord): + # this is a workaround to make torchscript happy, as torchscript + # doesn't support dictionary with non-homogeneous values, such + # as a dict having both a Tensor and a list. + return [ + {"pred_logits": a, "pred_boxes": b} + for a, b in zip(outputs_class[:-1], outputs_coord[:-1]) + ] + + +class SetCriterion(nn.Module): + """ This class computes the loss for DETR. + The process happens in two steps: + 1) we compute hungarian assignment between ground truth boxes and the outputs of the model + 2) we supervise each pair of matched ground-truth / prediction (supervise class and box) + """ + + def __init__(self, num_classes, matcher, weight_dict, losses, focal_alpha=0.25): + """ Create the criterion. + Parameters: + num_classes: number of object categories, omitting the special no-object category + matcher: module able to compute a matching between targets and proposals + weight_dict: dict containing as key the names of the losses and as values their relative weight. + losses: list of all the losses to be applied. See get_loss for list of available losses. + focal_alpha: alpha in Focal Loss + """ + super().__init__() + self.num_classes = num_classes + self.matcher = matcher + self.weight_dict = weight_dict + self.losses = losses + self.focal_alpha = focal_alpha + + def loss_labels(self, outputs, targets, indices, num_boxes, log=True): + """Classification loss (NLL) + targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes] + """ + assert "pred_logits" in outputs + src_logits = outputs["pred_logits"] + + idx = self._get_src_permutation_idx(indices) + target_classes_o = torch.cat( + [t["labels"][J] for t, (_, J) in zip(targets, indices)] + ) + target_classes = torch.full( + src_logits.shape[:2], + self.num_classes, + dtype=torch.int64, + device=src_logits.device, + ) + target_classes[idx] = target_classes_o + + target_classes_onehot = torch.zeros( + [src_logits.shape[0], src_logits.shape[1], src_logits.shape[2] + 1], + dtype=src_logits.dtype, + layout=src_logits.layout, + device=src_logits.device, + ) + target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1) + + target_classes_onehot = target_classes_onehot[:, :, :-1] + loss_ce = ( + sigmoid_focal_loss( + src_logits, + target_classes_onehot, + num_boxes, + alpha=self.focal_alpha, + gamma=2, + ) + * src_logits.shape[1] + ) + losses = {"loss_ce": loss_ce} + + if log: + # TODO this should probably be a separate loss, not hacked in this one here + losses["class_error"] = 100 - accuracy(src_logits[idx], target_classes_o)[0] + return losses + + @torch.no_grad() + def loss_cardinality(self, outputs, targets, indices, num_boxes): + """ Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes + This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients + """ + pred_logits = outputs["pred_logits"] + device = pred_logits.device + tgt_lengths = torch.as_tensor( + [len(v["labels"]) for v in targets], device=device + ) + # Count the number of predictions that are NOT "no-object" (which is the last class) + card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1) + card_err = F.l1_loss(card_pred.float(), tgt_lengths.float()) + losses = {"cardinality_error": card_err} + return losses + + def loss_boxes(self, outputs, targets, indices, num_boxes): + """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss + targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4] + The target boxes are expected in format (center_x, center_y, h, w), normalized by the image size. + """ + assert "pred_boxes" in outputs + idx = self._get_src_permutation_idx(indices) + src_boxes = outputs["pred_boxes"][idx] + target_boxes = torch.cat( + [t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0 + ) + + loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction="none") + + losses = {} + losses["loss_bbox"] = loss_bbox.sum() / num_boxes + + loss_giou = 1 - torch.diag( + box_ops.generalized_box_iou( + box_ops.box_cxcywh_to_xyxy(src_boxes), + box_ops.box_cxcywh_to_xyxy(target_boxes), + ) + ) + losses["loss_giou"] = loss_giou.sum() / num_boxes + return losses + + def loss_masks(self, outputs, targets, indices, num_boxes): + """Compute the losses related to the masks: the focal loss and the dice loss. + targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w] + """ + assert "pred_masks" in outputs + + src_idx = self._get_src_permutation_idx(indices) + tgt_idx = self._get_tgt_permutation_idx(indices) + + src_masks = outputs["pred_masks"] + + # TODO use valid to mask invalid areas due to padding in loss + target_masks, valid = nested_tensor_from_tensor_list( + [t["masks"] for t in targets] + ).decompose() + target_masks = target_masks.to(src_masks) + + src_masks = src_masks[src_idx] + # upsample predictions to the target size + src_masks = interpolate( + src_masks[:, None], + size=target_masks.shape[-2:], + mode="bilinear", + align_corners=False, + ) + src_masks = src_masks[:, 0].flatten(1) + + target_masks = target_masks[tgt_idx].flatten(1) + + losses = { + "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes), + "loss_dice": dice_loss(src_masks, target_masks, num_boxes), + } + return losses + + def _get_src_permutation_idx(self, indices): + # permute predictions following indices + batch_idx = torch.cat( + [torch.full_like(src, i) for i, (src, _) in enumerate(indices)] + ) + src_idx = torch.cat([src for (src, _) in indices]) + return batch_idx, src_idx + + def _get_tgt_permutation_idx(self, indices): + # permute targets following indices + batch_idx = torch.cat( + [torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)] + ) + tgt_idx = torch.cat([tgt for (_, tgt) in indices]) + return batch_idx, tgt_idx + + def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs): + loss_map = { + "labels": self.loss_labels, + "cardinality": self.loss_cardinality, + "boxes": self.loss_boxes, + "masks": self.loss_masks, + } + assert loss in loss_map, f"do you really want to compute {loss} loss?" + return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs) + + def forward(self, outputs, targets): + """ This performs the loss computation. + Parameters: + outputs: dict of tensors, see the output specification of the model for the format + targets: list of dicts, such that len(targets) == batch_size. + The expected keys in each dict depends on the losses applied, see each loss' doc + """ + outputs_without_aux = { + k: v + for k, v in outputs.items() + if k != "aux_outputs" and k != "enc_outputs" + } + + # Retrieve the matching between the outputs of the last layer and the targets + indices = self.matcher(outputs_without_aux, targets) + + # Compute the average number of target boxes accross all nodes, for normalization purposes + num_boxes = sum(len(t["labels"]) for t in targets) + num_boxes = torch.as_tensor( + [num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device + ) + if is_dist_avail_and_initialized(): + torch.distributed.all_reduce(num_boxes) + num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item() + + # Compute all the requested losses + losses = {} + for loss in self.losses: + kwargs = {} + losses.update( + self.get_loss(loss, outputs, targets, indices, num_boxes, **kwargs) + ) + + # In case of auxiliary losses, we repeat this process with the output of each intermediate layer. + if "aux_outputs" in outputs: + for i, aux_outputs in enumerate(outputs["aux_outputs"]): + indices = self.matcher(aux_outputs, targets) + for loss in self.losses: + if loss == "masks": + # Intermediate masks losses are too costly to compute, we ignore them. + continue + kwargs = {} + if loss == "labels": + # Logging is enabled only for the last layer + kwargs["log"] = False + l_dict = self.get_loss( + loss, aux_outputs, targets, indices, num_boxes, **kwargs + ) + l_dict = {k + f"_{i}": v for k, v in l_dict.items()} + losses.update(l_dict) + + if "enc_outputs" in outputs: + enc_outputs = outputs["enc_outputs"] + bin_targets = copy.deepcopy(targets) + for bt in bin_targets: + bt["labels"] = torch.zeros_like(bt["labels"]) + indices = self.matcher(enc_outputs, bin_targets) + for loss in self.losses: + if loss == "masks": + # Intermediate masks losses are too costly to compute, we ignore them. + continue + kwargs = {} + if loss == "labels": + # Logging is enabled only for the last layer + kwargs["log"] = False + l_dict = self.get_loss( + loss, enc_outputs, bin_targets, indices, num_boxes, **kwargs + ) + l_dict = {k + f"_enc": v for k, v in l_dict.items()} + losses.update(l_dict) + + return losses + + +class PostProcess(nn.Module): + """ This module converts the model's output into the format expected by the coco api""" + + def __init__(self, topk=100): + super().__init__() + self.topk = topk + print("topk for eval:", self.topk) + + @torch.no_grad() + def forward(self, outputs, target_sizes): + """ Perform the computation + Parameters: + outputs: raw outputs of the model + target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch + For evaluation, this must be the original image size (before any data augmentation) + For visualization, this should be the image size after data augment, but before padding + """ + out_logits, out_bbox = outputs["pred_logits"], outputs["pred_boxes"] + + assert len(out_logits) == len(target_sizes) + assert target_sizes.shape[1] == 2 + + prob = out_logits.sigmoid() + topk_values, topk_indexes = torch.topk( + prob.view(out_logits.shape[0], -1), self.topk, dim=1 + ) + scores = topk_values + topk_boxes = topk_indexes // out_logits.shape[2] + labels = topk_indexes % out_logits.shape[2] + boxes = box_ops.box_cxcywh_to_xyxy(out_bbox) + boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4)) + + # and from relative [0, 1] to absolute [0, height] coordinates + img_h, img_w = target_sizes.unbind(1) + scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1) + boxes = boxes * scale_fct[:, None, :] + + results = [ + {"scores": s, "labels": l, "boxes": b} + for s, l, b in zip(scores, labels, boxes) + ] + + return results + + +class MLP(nn.Module): + """ Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) + ) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + + +def build(args): + num_classes = 20 if args.dataset_file != "coco" else 91 + if args.dataset_file == "coco_panoptic": + num_classes = 250 + device = torch.device(args.device) + + backbone = build_backbone(args) + + transformer = build_deforamble_transformer(args) + model = DeformableDETR( + backbone, + transformer, + num_classes=num_classes, + num_feature_levels=args.num_feature_levels, + aux_loss=args.aux_loss, + with_box_refine=args.with_box_refine, + two_stage=args.two_stage, + num_queries_one2one=args.num_queries_one2one, + num_queries_one2many=args.num_queries_one2many, + mixed_selection=args.mixed_selection, + ) + if args.masks: + model = DETRsegm(model, freeze_detr=(args.frozen_weights is not None)) + matcher = build_matcher(args) + weight_dict = {"loss_ce": args.cls_loss_coef, "loss_bbox": args.bbox_loss_coef} + weight_dict["loss_giou"] = args.giou_loss_coef + if args.masks: + weight_dict["loss_mask"] = args.mask_loss_coef + weight_dict["loss_dice"] = args.dice_loss_coef + # TODO this is a hack + if args.aux_loss: + aux_weight_dict = {} + for i in range(args.dec_layers - 1): + aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()}) + aux_weight_dict.update({k + f"_enc": v for k, v in weight_dict.items()}) + weight_dict.update(aux_weight_dict) + + new_dict = dict() + for key, value in weight_dict.items(): + new_dict[key] = value + new_dict[key + "_one2many"] = value + weight_dict = new_dict + + losses = ["labels", "boxes", "cardinality"] + if args.masks: + losses += ["masks"] + # num_classes, matcher, weight_dict, losses, focal_alpha=0.25 + criterion = SetCriterion( + num_classes, matcher, weight_dict, losses, focal_alpha=args.focal_alpha + ) + criterion.to(device) + postprocessors = {"bbox": PostProcess(topk=args.topk)} + if args.masks: + postprocessors["segm"] = PostProcessSegm() + if args.dataset_file == "coco_panoptic": + is_thing_map = {i: i <= 90 for i in range(201)} + postprocessors["panoptic"] = PostProcessPanoptic( + is_thing_map, threshold=0.85 + ) + + return model, criterion, postprocessors diff --git a/models/deformable_transformer.py b/models/deformable_transformer.py new file mode 100644 index 0000000..5d92afa --- /dev/null +++ b/models/deformable_transformer.py @@ -0,0 +1,632 @@ +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +import copy +from typing import Optional, List +import math + +import torch +import torch.nn.functional as F +from torch import nn, Tensor +import torch.utils.checkpoint as checkpoint +from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_ + +from util.misc import inverse_sigmoid +from models.ops.modules import MSDeformAttn + + +class DeformableTransformer(nn.Module): + def __init__( + self, + d_model=256, + nhead=8, + num_encoder_layers=6, + num_decoder_layers=6, + dim_feedforward=1024, + dropout=0.1, + activation="relu", + return_intermediate_dec=False, + num_feature_levels=4, + dec_n_points=4, + enc_n_points=4, + two_stage=False, + two_stage_num_proposals=300, + look_forward_twice=False, + mixed_selection=False, + use_checkpoint=False, + ): + super().__init__() + + self.d_model = d_model + self.nhead = nhead + self.two_stage = two_stage + self.two_stage_num_proposals = two_stage_num_proposals + + encoder_layer = DeformableTransformerEncoderLayer( + d_model, + dim_feedforward, + dropout, + activation, + num_feature_levels, + nhead, + enc_n_points, + ) + self.encoder = DeformableTransformerEncoder( + encoder_layer, num_encoder_layers, use_checkpoint + ) + + decoder_layer = DeformableTransformerDecoderLayer( + d_model, + dim_feedforward, + dropout, + activation, + num_feature_levels, + nhead, + dec_n_points, + ) + self.decoder = DeformableTransformerDecoder( + decoder_layer, + num_decoder_layers, + return_intermediate_dec, + look_forward_twice, + use_checkpoint, + ) + + self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model)) + + if two_stage: + self.enc_output = nn.Linear(d_model, d_model) + self.enc_output_norm = nn.LayerNorm(d_model) + self.pos_trans = nn.Linear(d_model * 2, d_model * 2) + self.pos_trans_norm = nn.LayerNorm(d_model * 2) + else: + self.reference_points = nn.Linear(d_model, 2) + + self.mixed_selection = mixed_selection + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + for m in self.modules(): + if isinstance(m, MSDeformAttn): + m._reset_parameters() + if not self.two_stage: + xavier_uniform_(self.reference_points.weight.data, gain=1.0) + constant_(self.reference_points.bias.data, 0.0) + normal_(self.level_embed) + + def get_proposal_pos_embed(self, proposals): + num_pos_feats = 128 + temperature = 10000 + scale = 2 * math.pi + + dim_t = torch.arange( + num_pos_feats, dtype=torch.float32, device=proposals.device + ) + dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats) + # N, L, 4 + proposals = proposals.sigmoid() * scale + # N, L, 4, 128 + pos = proposals[:, :, :, None] / dim_t + # N, L, 4, 64, 2 + pos = torch.stack( + (pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), dim=4 + ).flatten(2) + return pos + + def gen_encoder_output_proposals(self, memory, memory_padding_mask, spatial_shapes): + N_, S_, C_ = memory.shape + base_scale = 4.0 + proposals = [] + _cur = 0 + for lvl, (H_, W_) in enumerate(spatial_shapes): + mask_flatten_ = memory_padding_mask[:, _cur : (_cur + H_ * W_)].view( + N_, H_, W_, 1 + ) + valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1) + valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1) + + grid_y, grid_x = torch.meshgrid( + torch.linspace( + 0, H_ - 1, H_, dtype=torch.float32, device=memory.device + ), + torch.linspace( + 0, W_ - 1, W_, dtype=torch.float32, device=memory.device + ), + ) + grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) + + scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view( + N_, 1, 1, 2 + ) + grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale + wh = torch.ones_like(grid) * 0.05 * (2.0 ** lvl) + proposal = torch.cat((grid, wh), -1).view(N_, -1, 4) + proposals.append(proposal) + _cur += H_ * W_ + output_proposals = torch.cat(proposals, 1) + output_proposals_valid = ( + (output_proposals > 0.01) & (output_proposals < 0.99) + ).all(-1, keepdim=True) + output_proposals = torch.log(output_proposals / (1 - output_proposals)) + output_proposals = output_proposals.masked_fill( + memory_padding_mask.unsqueeze(-1), float("inf") + ) + output_proposals = output_proposals.masked_fill( + ~output_proposals_valid, float("inf") + ) + + output_memory = memory + output_memory = output_memory.masked_fill( + memory_padding_mask.unsqueeze(-1), float(0) + ) + output_memory = output_memory.masked_fill(~output_proposals_valid, float(0)) + output_memory = self.enc_output_norm(self.enc_output(output_memory)) + return output_memory, output_proposals + + def get_valid_ratio(self, mask): + _, H, W = mask.shape + valid_H = torch.sum(~mask[:, :, 0], 1) + valid_W = torch.sum(~mask[:, 0, :], 1) + valid_ratio_h = valid_H.float() / H + valid_ratio_w = valid_W.float() / W + valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1) + return valid_ratio + + @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) + def forward(self, srcs, masks, pos_embeds, query_embed=None, self_attn_mask=None): + + # prepare input for encoder + src_flatten = [] + mask_flatten = [] + lvl_pos_embed_flatten = [] + spatial_shapes = [] + for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)): + bs, c, h, w = src.shape + spatial_shape = (h, w) + spatial_shapes.append(spatial_shape) + src = src.flatten(2).transpose(1, 2) + mask = mask.flatten(1) + pos_embed = pos_embed.flatten(2).transpose(1, 2) + lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1) + lvl_pos_embed_flatten.append(lvl_pos_embed) + src_flatten.append(src) + mask_flatten.append(mask) + src_flatten = torch.cat(src_flatten, 1) + mask_flatten = torch.cat(mask_flatten, 1) + lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) + spatial_shapes = torch.as_tensor( + spatial_shapes, dtype=torch.long, device=src_flatten.device + ) + level_start_index = torch.cat( + (spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]) + ) + valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1) + + # encoder + memory = self.encoder( + src_flatten, + spatial_shapes, + level_start_index, + valid_ratios, + lvl_pos_embed_flatten, + mask_flatten, + ) + + # prepare input for decoder + bs, _, c = memory.shape + if self.two_stage: + output_memory, output_proposals = self.gen_encoder_output_proposals( + memory, mask_flatten, spatial_shapes + ) + + # hack implementation for two-stage Deformable DETR + enc_outputs_class = self.decoder.class_embed[self.decoder.num_layers]( + output_memory + ) + enc_outputs_coord_unact = ( + self.decoder.bbox_embed[self.decoder.num_layers](output_memory) + + output_proposals + ) + + topk = self.two_stage_num_proposals + topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1] + topk_coords_unact = torch.gather( + enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4) + ) + topk_coords_unact = topk_coords_unact.detach() + reference_points = topk_coords_unact.sigmoid() + init_reference_out = reference_points + pos_trans_out = self.pos_trans_norm( + self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact)) + ) + + if not self.mixed_selection: + query_embed, tgt = torch.split(pos_trans_out, c, dim=2) + else: + # query_embed here is the content embed for deformable DETR + tgt = query_embed.unsqueeze(0).expand(bs, -1, -1) + query_embed, _ = torch.split(pos_trans_out, c, dim=2) + else: + query_embed, tgt = torch.split(query_embed, c, dim=1) + query_embed = query_embed.unsqueeze(0).expand(bs, -1, -1) + tgt = tgt.unsqueeze(0).expand(bs, -1, -1) + reference_points = self.reference_points(query_embed).sigmoid() + init_reference_out = reference_points + + # decoder + hs, inter_references = self.decoder( + tgt, + reference_points, + memory, + spatial_shapes, + level_start_index, + valid_ratios, + query_embed, + mask_flatten, + self_attn_mask, + ) + + inter_references_out = inter_references + if self.two_stage: + return ( + hs, + init_reference_out, + inter_references_out, + enc_outputs_class, + enc_outputs_coord_unact, + ) + return hs, init_reference_out, inter_references_out, None, None + + +class DeformableTransformerEncoderLayer(nn.Module): + def __init__( + self, + d_model=256, + d_ffn=1024, + dropout=0.1, + activation="relu", + n_levels=4, + n_heads=8, + n_points=4, + ): + super().__init__() + + # self attention + self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points) + self.dropout1 = nn.Dropout(dropout) + self.norm1 = nn.LayerNorm(d_model) + + # ffn + self.linear1 = nn.Linear(d_model, d_ffn) + self.activation = _get_activation_fn(activation) + self.dropout2 = nn.Dropout(dropout) + self.linear2 = nn.Linear(d_ffn, d_model) + self.dropout3 = nn.Dropout(dropout) + self.norm2 = nn.LayerNorm(d_model) + + @staticmethod + def with_pos_embed(tensor, pos): + return tensor if pos is None else tensor + pos + + def forward_ffn(self, src): + src2 = self.linear2(self.dropout2(self.activation(self.linear1(src)))) + src = src + self.dropout3(src2) + src = self.norm2(src) + return src + + @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) + def forward( + self, + src, + pos, + reference_points, + spatial_shapes, + level_start_index, + padding_mask=None, + ): + # self attention + src2 = self.self_attn( + self.with_pos_embed(src, pos), + reference_points, + src, + spatial_shapes, + level_start_index, + padding_mask, + ) + src = src + self.dropout1(src2) + src = self.norm1(src) + + # ffn + src = self.forward_ffn(src) + + return src + + +class DeformableTransformerEncoder(nn.Module): + def __init__(self, encoder_layer, num_layers, use_checkpoint=False): + super().__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + self.use_checkpoint = use_checkpoint + + @staticmethod + def get_reference_points(spatial_shapes, valid_ratios, device): + reference_points_list = [] + for lvl, (H_, W_) in enumerate(spatial_shapes): + + ref_y, ref_x = torch.meshgrid( + torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device), + torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device), + ) + ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_) + ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_) + ref = torch.stack((ref_x, ref_y), -1) + reference_points_list.append(ref) + reference_points = torch.cat(reference_points_list, 1) + reference_points = reference_points[:, :, None] * valid_ratios[:, None] + return reference_points + + @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) + def forward( + self, + src, + spatial_shapes, + level_start_index, + valid_ratios, + pos=None, + padding_mask=None, + ): + output = src + reference_points = self.get_reference_points( + spatial_shapes, valid_ratios, device=src.device + ) + for _, layer in enumerate(self.layers): + if self.use_checkpoint: + output = checkpoint.checkpoint( + layer, + output, + pos, + reference_points, + spatial_shapes, + level_start_index, + padding_mask, + ) + else: + output = layer( + output, + pos, + reference_points, + spatial_shapes, + level_start_index, + padding_mask, + ) + + return output + + +class DeformableTransformerDecoderLayer(nn.Module): + def __init__( + self, + d_model=256, + d_ffn=1024, + dropout=0.1, + activation="relu", + n_levels=4, + n_heads=8, + n_points=4, + ): + super().__init__() + + # cross attention + self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points) + self.dropout1 = nn.Dropout(dropout) + self.norm1 = nn.LayerNorm(d_model) + + # self attention + self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout) + self.dropout2 = nn.Dropout(dropout) + self.norm2 = nn.LayerNorm(d_model) + + # ffn + self.linear1 = nn.Linear(d_model, d_ffn) + self.activation = _get_activation_fn(activation) + self.dropout3 = nn.Dropout(dropout) + self.linear2 = nn.Linear(d_ffn, d_model) + self.dropout4 = nn.Dropout(dropout) + self.norm3 = nn.LayerNorm(d_model) + + @staticmethod + def with_pos_embed(tensor, pos): + return tensor if pos is None else tensor + pos + + def forward_ffn(self, tgt): + tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout4(tgt2) + tgt = self.norm3(tgt) + return tgt + + @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) + def forward( + self, + tgt, + query_pos, + reference_points, + src, + src_spatial_shapes, + level_start_index, + src_padding_mask=None, + self_attn_mask=None, + ): + # self attention + q = k = self.with_pos_embed(tgt, query_pos) + tgt2 = self.self_attn( + q.transpose(0, 1), + k.transpose(0, 1), + tgt.transpose(0, 1), + attn_mask=self_attn_mask, + )[0].transpose(0, 1) + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + + # cross attention + tgt2 = self.cross_attn( + self.with_pos_embed(tgt, query_pos), + reference_points, + src, + src_spatial_shapes, + level_start_index, + src_padding_mask, + ) + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + + # ffn + tgt = self.forward_ffn(tgt) + + return tgt + + +class DeformableTransformerDecoder(nn.Module): + def __init__( + self, + decoder_layer, + num_layers, + return_intermediate=False, + look_forward_twice=False, + use_checkpoint=False, + ): + super().__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.return_intermediate = return_intermediate + self.look_forward_twice = look_forward_twice + self.use_checkpoint = use_checkpoint + # hack implementation for iterative bounding box refinement and two-stage Deformable DETR + self.bbox_embed = None + self.class_embed = None + + @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) + def forward( + self, + tgt, + reference_points, + src, + src_spatial_shapes, + src_level_start_index, + src_valid_ratios, + query_pos=None, + src_padding_mask=None, + self_attn_mask=None, + ): + output = tgt + + intermediate = [] + intermediate_reference_points = [] + for lid, layer in enumerate(self.layers): + if reference_points.shape[-1] == 4: + reference_points_input = ( + reference_points[:, :, None] + * torch.cat([src_valid_ratios, src_valid_ratios], -1)[:, None] + ) + else: + assert reference_points.shape[-1] == 2 + reference_points_input = ( + reference_points[:, :, None] * src_valid_ratios[:, None] + ) + if self.use_checkpoint: + output = checkpoint.checkpoint( + layer, + output, + query_pos, + reference_points_input, + src, + src_spatial_shapes, + src_level_start_index, + src_padding_mask, + self_attn_mask, + ) + else: + output = layer( + output, + query_pos, + reference_points_input, + src, + src_spatial_shapes, + src_level_start_index, + src_padding_mask, + self_attn_mask, + ) + + # hack implementation for iterative bounding box refinement + if self.bbox_embed is not None: + tmp = self.bbox_embed[lid](output) + if reference_points.shape[-1] == 4: + new_reference_points = tmp + inverse_sigmoid(reference_points) + new_reference_points = new_reference_points.sigmoid() + else: + assert reference_points.shape[-1] == 2 + new_reference_points = tmp + new_reference_points[..., :2] = tmp[..., :2] + inverse_sigmoid( + reference_points + ) + new_reference_points = new_reference_points.sigmoid() + reference_points = new_reference_points.detach() + + if self.return_intermediate: + intermediate.append(output) + intermediate_reference_points.append( + new_reference_points + if self.look_forward_twice + else reference_points + ) + + if self.return_intermediate: + return torch.stack(intermediate), torch.stack(intermediate_reference_points) + + return output, reference_points + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(f"activation should be relu/gelu, not {activation}.") + + +def build_deforamble_transformer(args): + return DeformableTransformer( + d_model=args.hidden_dim, + nhead=args.nheads, + num_encoder_layers=args.enc_layers, + num_decoder_layers=args.dec_layers, + dim_feedforward=args.dim_feedforward, + dropout=args.dropout, + activation="relu", + return_intermediate_dec=True, + num_feature_levels=args.num_feature_levels, + dec_n_points=args.dec_n_points, + enc_n_points=args.enc_n_points, + two_stage=args.two_stage, + two_stage_num_proposals=args.num_queries_one2one + args.num_queries_one2many, + mixed_selection=args.mixed_selection, + look_forward_twice=args.look_forward_twice, + use_checkpoint=args.use_checkpoint, + ) + diff --git a/models/matcher.py b/models/matcher.py new file mode 100644 index 0000000..7a4df7c --- /dev/null +++ b/models/matcher.py @@ -0,0 +1,124 @@ +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +""" +Modules to compute the matching cost and solve the corresponding LSAP. +""" +import torch +from scipy.optimize import linear_sum_assignment +from torch import nn + +from util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou + + +class HungarianMatcher(nn.Module): + """This class computes an assignment between the targets and the predictions of the network + + For efficiency reasons, the targets don't include the no_object. Because of this, in general, + there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, + while the others are un-matched (and thus treated as non-objects). + """ + + def __init__( + self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1 + ): + """Creates the matcher + + Params: + cost_class: This is the relative weight of the classification error in the matching cost + cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost + cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost + """ + super().__init__() + self.cost_class = cost_class + self.cost_bbox = cost_bbox + self.cost_giou = cost_giou + assert ( + cost_class != 0 or cost_bbox != 0 or cost_giou != 0 + ), "all costs cant be 0" + + def forward(self, outputs, targets): + """ Performs the matching + + Params: + outputs: This is a dict that contains at least these entries: + "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits + "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates + + targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing: + "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth + objects in the target) containing the class labels + "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates + + Returns: + A list of size batch_size, containing tuples of (index_i, index_j) where: + - index_i is the indices of the selected predictions (in order) + - index_j is the indices of the corresponding selected targets (in order) + For each batch element, it holds: + len(index_i) = len(index_j) = min(num_queries, num_target_boxes) + """ + with torch.no_grad(): + bs, num_queries = outputs["pred_logits"].shape[:2] + + # We flatten to compute the cost matrices in a batch + out_prob = outputs["pred_logits"].flatten(0, 1).sigmoid() + out_bbox = outputs["pred_boxes"].flatten( + 0, 1 + ) # [batch_size * num_queries, 4] + + # Also concat the target labels and boxes + tgt_ids = torch.cat([v["labels"] for v in targets]) + tgt_bbox = torch.cat([v["boxes"] for v in targets]) + + # Compute the classification cost. + alpha = 0.25 + gamma = 2.0 + neg_cost_class = ( + (1 - alpha) * (out_prob ** gamma) * (-(1 - out_prob + 1e-8).log()) + ) + pos_cost_class = ( + alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log()) + ) + cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids] + + # Compute the L1 cost between boxes + cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1) + + # Compute the giou cost betwen boxes + cost_giou = -generalized_box_iou( + box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox) + ) + + # Final cost matrix + C = ( + self.cost_bbox * cost_bbox + + self.cost_class * cost_class + + self.cost_giou * cost_giou + ) + C = C.view(bs, num_queries, -1).cpu() + + sizes = [len(v["boxes"]) for v in targets] + indices = [ + linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1)) + ] + return [ + ( + torch.as_tensor(i, dtype=torch.int64), + torch.as_tensor(j, dtype=torch.int64), + ) + for i, j in indices + ] + + +def build_matcher(args): + return HungarianMatcher( + cost_class=args.set_cost_class, + cost_bbox=args.set_cost_bbox, + cost_giou=args.set_cost_giou, + ) diff --git a/models/ops/functions/__init__.py b/models/ops/functions/__init__.py new file mode 100644 index 0000000..8a2197b --- /dev/null +++ b/models/ops/functions/__init__.py @@ -0,0 +1,10 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +from .ms_deform_attn_func import MSDeformAttnFunction + diff --git a/models/ops/functions/ms_deform_attn_func.py b/models/ops/functions/ms_deform_attn_func.py new file mode 100644 index 0000000..7cb2e44 --- /dev/null +++ b/models/ops/functions/ms_deform_attn_func.py @@ -0,0 +1,110 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import torch +import torch.nn.functional as F +from torch.autograd import Function +from torch.autograd.function import once_differentiable + +import MultiScaleDeformableAttention as MSDA + + +class MSDeformAttnFunction(Function): + @staticmethod + @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) + def forward( + ctx, + value, + value_spatial_shapes, + value_level_start_index, + sampling_locations, + attention_weights, + im2col_step, + ): + ctx.im2col_step = im2col_step + output = MSDA.ms_deform_attn_forward( + value, + value_spatial_shapes, + value_level_start_index, + sampling_locations, + attention_weights, + ctx.im2col_step, + ) + ctx.save_for_backward( + value, + value_spatial_shapes, + value_level_start_index, + sampling_locations, + attention_weights, + ) + return output + + @staticmethod + @once_differentiable + @torch.cuda.amp.custom_bwd + def backward(ctx, grad_output): + ( + value, + value_spatial_shapes, + value_level_start_index, + sampling_locations, + attention_weights, + ) = ctx.saved_tensors + grad_value, grad_sampling_loc, grad_attn_weight = MSDA.ms_deform_attn_backward( + value, + value_spatial_shapes, + value_level_start_index, + sampling_locations, + attention_weights, + grad_output, + ctx.im2col_step, + ) + + return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None + + +def ms_deform_attn_core_pytorch( + value, value_spatial_shapes, sampling_locations, attention_weights +): + # for debug and test only, + # need to use cuda version instead + N_, S_, M_, D_ = value.shape + _, Lq_, M_, L_, P_, _ = sampling_locations.shape + value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) + sampling_grids = 2 * sampling_locations - 1 + sampling_value_list = [] + for lid_, (H_, W_) in enumerate(value_spatial_shapes): + # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ + value_l_ = ( + value_list[lid_].flatten(2).transpose(1, 2).reshape(N_ * M_, D_, H_, W_) + ) + # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 + sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1) + # N_*M_, D_, Lq_, P_ + sampling_value_l_ = F.grid_sample( + value_l_, + sampling_grid_l_, + mode="bilinear", + padding_mode="zeros", + align_corners=False, + ) + sampling_value_list.append(sampling_value_l_) + # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) + attention_weights = attention_weights.transpose(1, 2).reshape( + N_ * M_, 1, Lq_, L_ * P_ + ) + output = ( + (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights) + .sum(-1) + .view(N_, M_ * D_, Lq_) + ) + return output.transpose(1, 2).contiguous() diff --git a/models/ops/make.sh b/models/ops/make.sh new file mode 100644 index 0000000..106b685 --- /dev/null +++ b/models/ops/make.sh @@ -0,0 +1,10 @@ +#!/usr/bin/env bash +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +python setup.py build install diff --git a/models/ops/modules/__init__.py b/models/ops/modules/__init__.py new file mode 100644 index 0000000..f82cb1a --- /dev/null +++ b/models/ops/modules/__init__.py @@ -0,0 +1,9 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +from .ms_deform_attn import MSDeformAttn diff --git a/models/ops/modules/ms_deform_attn.py b/models/ops/modules/ms_deform_attn.py new file mode 100644 index 0000000..2dd7501 --- /dev/null +++ b/models/ops/modules/ms_deform_attn.py @@ -0,0 +1,162 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import warnings +import math + +import torch +from torch import nn +import torch.nn.functional as F +from torch.nn.init import xavier_uniform_, constant_ + +from ..functions import MSDeformAttnFunction + + +def _is_power_of_2(n): + if (not isinstance(n, int)) or (n < 0): + raise ValueError( + "invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)) + ) + return (n & (n - 1) == 0) and n != 0 + + +class MSDeformAttn(nn.Module): + def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4): + """ + Multi-Scale Deformable Attention Module + :param d_model hidden dimension + :param n_levels number of feature levels + :param n_heads number of attention heads + :param n_points number of sampling points per attention head per feature level + """ + super().__init__() + if d_model % n_heads != 0: + raise ValueError( + "d_model must be divisible by n_heads, but got {} and {}".format( + d_model, n_heads + ) + ) + _d_per_head = d_model // n_heads + # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation + if not _is_power_of_2(_d_per_head): + warnings.warn( + "You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 " + "which is more efficient in our CUDA implementation." + ) + + self.im2col_step = 64 + + self.d_model = d_model + self.n_levels = n_levels + self.n_heads = n_heads + self.n_points = n_points + + self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2) + self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points) + self.value_proj = nn.Linear(d_model, d_model) + self.output_proj = nn.Linear(d_model, d_model) + + self._reset_parameters() + + def _reset_parameters(self): + constant_(self.sampling_offsets.weight.data, 0.0) + thetas = torch.arange(self.n_heads, dtype=torch.float32) * ( + 2.0 * math.pi / self.n_heads + ) + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = ( + (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) + .view(self.n_heads, 1, 1, 2) + .repeat(1, self.n_levels, self.n_points, 1) + ) + for i in range(self.n_points): + grid_init[:, :, i, :] *= i + 1 + with torch.no_grad(): + self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) + constant_(self.attention_weights.weight.data, 0.0) + constant_(self.attention_weights.bias.data, 0.0) + xavier_uniform_(self.value_proj.weight.data) + constant_(self.value_proj.bias.data, 0.0) + xavier_uniform_(self.output_proj.weight.data) + constant_(self.output_proj.bias.data, 0.0) + + @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) + def forward( + self, + query, + reference_points, + input_flatten, + input_spatial_shapes, + input_level_start_index, + input_padding_mask=None, + ): + """ + :param query (N, Length_{query}, C) + :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area + or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes + :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C) + :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] + :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] + :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements + + :return output (N, Length_{query}, C) + """ + N, Len_q, _ = query.shape + N, Len_in, _ = input_flatten.shape + assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in + + value = self.value_proj(input_flatten) + if input_padding_mask is not None: + value = value.masked_fill(input_padding_mask[..., None], float(0)) + value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads) + sampling_offsets = self.sampling_offsets(query).view( + N, Len_q, self.n_heads, self.n_levels, self.n_points, 2 + ) + attention_weights = self.attention_weights(query).view( + N, Len_q, self.n_heads, self.n_levels * self.n_points + ) + attention_weights = F.softmax(attention_weights, -1).view( + N, Len_q, self.n_heads, self.n_levels, self.n_points + ) + # N, Len_q, n_heads, n_levels, n_points, 2 + if reference_points.shape[-1] == 2: + offset_normalizer = torch.stack( + [input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1 + ) + sampling_locations = ( + reference_points[:, :, None, :, None, :] + + sampling_offsets / offset_normalizer[None, None, None, :, None, :] + ) + elif reference_points.shape[-1] == 4: + sampling_locations = ( + reference_points[:, :, None, :, None, :2] + + sampling_offsets + / self.n_points + * reference_points[:, :, None, :, None, 2:] + * 0.5 + ) + else: + raise ValueError( + "Last dim of reference_points must be 2 or 4, but get {} instead.".format( + reference_points.shape[-1] + ) + ) + output = MSDeformAttnFunction.apply( + value, + input_spatial_shapes, + input_level_start_index, + sampling_locations, + attention_weights, + self.im2col_step, + ) + output = self.output_proj(output) + return output diff --git a/models/ops/setup.py b/models/ops/setup.py new file mode 100644 index 0000000..a0131bc --- /dev/null +++ b/models/ops/setup.py @@ -0,0 +1,71 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +import os +import glob + +import torch + +from torch.utils.cpp_extension import CUDA_HOME +from torch.utils.cpp_extension import CppExtension +from torch.utils.cpp_extension import CUDAExtension + +from setuptools import find_packages +from setuptools import setup + +requirements = ["torch", "torchvision"] + +def get_extensions(): + this_dir = os.path.dirname(os.path.abspath(__file__)) + extensions_dir = os.path.join(this_dir, "src") + + main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) + source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) + source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) + + sources = main_file + source_cpu + extension = CppExtension + extra_compile_args = {"cxx": []} + define_macros = [] + + if torch.cuda.is_available() and CUDA_HOME is not None: + extension = CUDAExtension + sources += source_cuda + define_macros += [("WITH_CUDA", None)] + extra_compile_args["nvcc"] = [ + "-DCUDA_HAS_FP16=1", + "-D__CUDA_NO_HALF_OPERATORS__", + "-D__CUDA_NO_HALF_CONVERSIONS__", + "-D__CUDA_NO_HALF2_OPERATORS__", + ] + else: + raise NotImplementedError('Cuda is not availabel') + + sources = [os.path.join(extensions_dir, s) for s in sources] + include_dirs = [extensions_dir] + ext_modules = [ + extension( + "MultiScaleDeformableAttention", + sources, + include_dirs=include_dirs, + define_macros=define_macros, + extra_compile_args=extra_compile_args, + ) + ] + return ext_modules + +setup( + name="MultiScaleDeformableAttention", + version="1.0", + author="Weijie Su", + url="https://github.com/fundamentalvision/Deformable-DETR", + description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention", + packages=find_packages(exclude=("configs", "tests",)), + ext_modules=get_extensions(), + cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, +) diff --git a/models/ops/src/cpu/ms_deform_attn_cpu.cpp b/models/ops/src/cpu/ms_deform_attn_cpu.cpp new file mode 100644 index 0000000..e1bf854 --- /dev/null +++ b/models/ops/src/cpu/ms_deform_attn_cpu.cpp @@ -0,0 +1,41 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#include + +#include +#include + + +at::Tensor +ms_deform_attn_cpu_forward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const int im2col_step) +{ + AT_ERROR("Not implement on cpu"); +} + +std::vector +ms_deform_attn_cpu_backward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const at::Tensor &grad_output, + const int im2col_step) +{ + AT_ERROR("Not implement on cpu"); +} + diff --git a/models/ops/src/cpu/ms_deform_attn_cpu.h b/models/ops/src/cpu/ms_deform_attn_cpu.h new file mode 100644 index 0000000..81b7b58 --- /dev/null +++ b/models/ops/src/cpu/ms_deform_attn_cpu.h @@ -0,0 +1,33 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#pragma once +#include + +at::Tensor +ms_deform_attn_cpu_forward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const int im2col_step); + +std::vector +ms_deform_attn_cpu_backward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const at::Tensor &grad_output, + const int im2col_step); + + diff --git a/models/ops/src/cuda/ms_deform_attn_cuda.cu b/models/ops/src/cuda/ms_deform_attn_cuda.cu new file mode 100644 index 0000000..d6d5836 --- /dev/null +++ b/models/ops/src/cuda/ms_deform_attn_cuda.cu @@ -0,0 +1,153 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#include +#include "cuda/ms_deform_im2col_cuda.cuh" + +#include +#include +#include +#include + + +at::Tensor ms_deform_attn_cuda_forward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const int im2col_step) +{ + AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); + AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); + AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); + AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); + AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); + + AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); + AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); + AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); + AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); + AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); + + const int batch = value.size(0); + const int spatial_size = value.size(1); + const int num_heads = value.size(2); + const int channels = value.size(3); + + const int num_levels = spatial_shapes.size(0); + + const int num_query = sampling_loc.size(1); + const int num_point = sampling_loc.size(4); + + const int im2col_step_ = std::min(batch, im2col_step); + + AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); + + auto output = at::zeros({batch, num_query, num_heads, channels}, value.options()); + + const int batch_n = im2col_step_; + auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); + auto per_value_size = spatial_size * num_heads * channels; + auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; + auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; + for (int n = 0; n < batch/im2col_step_; ++n) + { + auto columns = output_n.select(0, n); + AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] { + ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(), + value.data() + n * im2col_step_ * per_value_size, + spatial_shapes.data(), + level_start_index.data(), + sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, + attn_weight.data() + n * im2col_step_ * per_attn_weight_size, + batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, + columns.data()); + + })); + } + + output = output.view({batch, num_query, num_heads*channels}); + + return output; +} + + +std::vector ms_deform_attn_cuda_backward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const at::Tensor &grad_output, + const int im2col_step) +{ + + AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); + AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); + AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); + AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); + AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); + AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous"); + + AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); + AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); + AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); + AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); + AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); + AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor"); + + const int batch = value.size(0); + const int spatial_size = value.size(1); + const int num_heads = value.size(2); + const int channels = value.size(3); + + const int num_levels = spatial_shapes.size(0); + + const int num_query = sampling_loc.size(1); + const int num_point = sampling_loc.size(4); + + const int im2col_step_ = std::min(batch, im2col_step); + + AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); + + auto grad_value = at::zeros_like(value); + auto grad_sampling_loc = at::zeros_like(sampling_loc); + auto grad_attn_weight = at::zeros_like(attn_weight); + + const int batch_n = im2col_step_; + auto per_value_size = spatial_size * num_heads * channels; + auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; + auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; + auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); + + for (int n = 0; n < batch/im2col_step_; ++n) + { + auto grad_output_g = grad_output_n.select(0, n); + AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] { + ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(), + grad_output_g.data(), + value.data() + n * im2col_step_ * per_value_size, + spatial_shapes.data(), + level_start_index.data(), + sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, + attn_weight.data() + n * im2col_step_ * per_attn_weight_size, + batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, + grad_value.data() + n * im2col_step_ * per_value_size, + grad_sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, + grad_attn_weight.data() + n * im2col_step_ * per_attn_weight_size); + + })); + } + + return { + grad_value, grad_sampling_loc, grad_attn_weight + }; +} \ No newline at end of file diff --git a/models/ops/src/cuda/ms_deform_attn_cuda.h b/models/ops/src/cuda/ms_deform_attn_cuda.h new file mode 100644 index 0000000..c7ae53f --- /dev/null +++ b/models/ops/src/cuda/ms_deform_attn_cuda.h @@ -0,0 +1,30 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#pragma once +#include + +at::Tensor ms_deform_attn_cuda_forward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const int im2col_step); + +std::vector ms_deform_attn_cuda_backward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const at::Tensor &grad_output, + const int im2col_step); + diff --git a/models/ops/src/cuda/ms_deform_im2col_cuda.cuh b/models/ops/src/cuda/ms_deform_im2col_cuda.cuh new file mode 100644 index 0000000..6bc2acb --- /dev/null +++ b/models/ops/src/cuda/ms_deform_im2col_cuda.cuh @@ -0,0 +1,1327 @@ +/*! +************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************** +* Modified from DCN (https://github.com/msracver/Deformable-ConvNets) +* Copyright (c) 2018 Microsoft +************************************************************************** +*/ + +#include +#include +#include + +#include +#include + +#include + +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ + i < (n); \ + i += blockDim.x * gridDim.x) + +const int CUDA_NUM_THREADS = 1024; +inline int GET_BLOCKS(const int N, const int num_threads) +{ + return (N + num_threads - 1) / num_threads; +} + + +template +__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data, + const int &height, const int &width, const int &nheads, const int &channels, + const scalar_t &h, const scalar_t &w, const int &m, const int &c) +{ + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const scalar_t lh = h - h_low; + const scalar_t lw = w - w_low; + const scalar_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = nheads * channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * channels + c; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + } + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + } + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + } + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + } + + const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + + +template +__device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data, + const int &height, const int &width, const int &nheads, const int &channels, + const scalar_t &h, const scalar_t &w, const int &m, const int &c, + const scalar_t &top_grad, + const scalar_t &attn_weight, + scalar_t* &grad_value, + scalar_t* grad_sampling_loc, + scalar_t* grad_attn_weight) +{ + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const scalar_t lh = h - h_low; + const scalar_t lw = w - w_low; + const scalar_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = nheads * channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * channels + c; + + const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + const scalar_t top_grad_value = top_grad * attn_weight; + scalar_t grad_h_weight = 0, grad_w_weight = 0; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + grad_h_weight -= hw * v1; + grad_w_weight -= hh * v1; + atomicAdd(grad_value+ptr1, w1*top_grad_value); + } + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + grad_h_weight -= lw * v2; + grad_w_weight += hh * v2; + atomicAdd(grad_value+ptr2, w2*top_grad_value); + } + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + grad_h_weight += hw * v3; + grad_w_weight -= lh * v3; + atomicAdd(grad_value+ptr3, w3*top_grad_value); + } + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + grad_h_weight += lw * v4; + grad_w_weight += lh * v4; + atomicAdd(grad_value+ptr4, w4*top_grad_value); + } + + const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + *grad_attn_weight = top_grad * val; + *grad_sampling_loc = width * grad_w_weight * top_grad_value; + *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value; +} + + +template +__device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data, + const int &height, const int &width, const int &nheads, const int &channels, + const scalar_t &h, const scalar_t &w, const int &m, const int &c, + const scalar_t &top_grad, + const scalar_t &attn_weight, + scalar_t* &grad_value, + scalar_t* grad_sampling_loc, + scalar_t* grad_attn_weight) +{ + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const scalar_t lh = h - h_low; + const scalar_t lw = w - w_low; + const scalar_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = nheads * channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * channels + c; + + const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + const scalar_t top_grad_value = top_grad * attn_weight; + scalar_t grad_h_weight = 0, grad_w_weight = 0; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + grad_h_weight -= hw * v1; + grad_w_weight -= hh * v1; + atomicAdd(grad_value+ptr1, w1*top_grad_value); + } + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + grad_h_weight -= lw * v2; + grad_w_weight += hh * v2; + atomicAdd(grad_value+ptr2, w2*top_grad_value); + } + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + grad_h_weight += hw * v3; + grad_w_weight -= lh * v3; + atomicAdd(grad_value+ptr3, w3*top_grad_value); + } + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + grad_h_weight += lw * v4; + grad_w_weight += lh * v4; + atomicAdd(grad_value+ptr4, w4*top_grad_value); + } + + const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + atomicAdd(grad_attn_weight, top_grad * val); + atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value); + atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value); +} + + +template +__global__ void ms_deformable_im2col_gpu_kernel(const int n, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *data_col) +{ + CUDA_KERNEL_LOOP(index, n) + { + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + scalar_t *data_col_ptr = data_col + index; + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + scalar_t col = 0; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride); + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight; + } + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + } + } + *data_col_ptr = col; + } +} + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2]; + __shared__ scalar_t cache_grad_attn_weight[blockSize]; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + if (tid == 0) + { + scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0]; + int sid=2; + for (unsigned int tid = 1; tid < blockSize; ++tid) + { + _grad_w += cache_grad_sampling_loc[sid]; + _grad_h += cache_grad_sampling_loc[sid + 1]; + _grad_a += cache_grad_attn_weight[tid]; + sid += 2; + } + + + *grad_sampling_loc = _grad_w; + *(grad_sampling_loc + 1) = _grad_h; + *grad_attn_weight = _grad_a; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2]; + __shared__ scalar_t cache_grad_attn_weight[blockSize]; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s=blockSize/2; s>0; s>>=1) + { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1]; + } + __syncthreads(); + } + + if (tid == 0) + { + *grad_sampling_loc = cache_grad_sampling_loc[0]; + *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1]; + *grad_attn_weight = cache_grad_attn_weight[0]; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + extern __shared__ int _s[]; + scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; + scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + if (tid == 0) + { + scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0]; + int sid=2; + for (unsigned int tid = 1; tid < blockDim.x; ++tid) + { + _grad_w += cache_grad_sampling_loc[sid]; + _grad_h += cache_grad_sampling_loc[sid + 1]; + _grad_a += cache_grad_attn_weight[tid]; + sid += 2; + } + + + *grad_sampling_loc = _grad_w; + *(grad_sampling_loc + 1) = _grad_h; + *grad_attn_weight = _grad_a; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + extern __shared__ int _s[]; + scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; + scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1) + { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1]; + if (tid + (s << 1) < spre) + { + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)]; + } + } + __syncthreads(); + } + + if (tid == 0) + { + *grad_sampling_loc = cache_grad_sampling_loc[0]; + *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1]; + *grad_attn_weight = cache_grad_attn_weight[0]; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + extern __shared__ int _s[]; + scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; + scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1) + { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1]; + if (tid + (s << 1) < spre) + { + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)]; + } + } + __syncthreads(); + } + + if (tid == 0) + { + atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]); + atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]); + atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]); + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + + +template +__global__ void ms_deformable_col2im_gpu_kernel_gm(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear_gm( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + grad_sampling_loc, grad_attn_weight); + } + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + + +template +void ms_deformable_im2col_cuda(cudaStream_t stream, + const scalar_t* data_value, + const int64_t* data_spatial_shapes, + const int64_t* data_level_start_index, + const scalar_t* data_sampling_loc, + const scalar_t* data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t* data_col) +{ + const int num_kernels = batch_size * num_query * num_heads * channels; + const int num_actual_kernels = batch_size * num_query * num_heads * channels; + const int num_threads = CUDA_NUM_THREADS; + ms_deformable_im2col_gpu_kernel + <<>>( + num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight, + batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err)); + } + +} + +template +void ms_deformable_col2im_cuda(cudaStream_t stream, + const scalar_t* grad_col, + const scalar_t* data_value, + const int64_t * data_spatial_shapes, + const int64_t * data_level_start_index, + const scalar_t * data_sampling_loc, + const scalar_t * data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t* grad_value, + scalar_t* grad_sampling_loc, + scalar_t* grad_attn_weight) +{ + const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels; + const int num_kernels = batch_size * num_query * num_heads * channels; + const int num_actual_kernels = batch_size * num_query * num_heads * channels; + if (channels > 1024) + { + if ((channels & 1023) == 0) + { + ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + else + { + ms_deformable_col2im_gpu_kernel_gm + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + } + else{ + switch(channels) + { + case 1: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 2: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 4: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 8: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 16: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 32: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 64: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 128: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 256: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 512: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 1024: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + default: + if (channels < 64) + { + ms_deformable_col2im_gpu_kernel_shm_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + else + { + ms_deformable_col2im_gpu_kernel_shm_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + } + } + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err)); + } + +} \ No newline at end of file diff --git a/models/ops/src/ms_deform_attn.h b/models/ops/src/ms_deform_attn.h new file mode 100644 index 0000000..ac0ef2e --- /dev/null +++ b/models/ops/src/ms_deform_attn.h @@ -0,0 +1,62 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#pragma once + +#include "cpu/ms_deform_attn_cpu.h" + +#ifdef WITH_CUDA +#include "cuda/ms_deform_attn_cuda.h" +#endif + + +at::Tensor +ms_deform_attn_forward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const int im2col_step) +{ + if (value.type().is_cuda()) + { +#ifdef WITH_CUDA + return ms_deform_attn_cuda_forward( + value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} + +std::vector +ms_deform_attn_backward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const at::Tensor &grad_output, + const int im2col_step) +{ + if (value.type().is_cuda()) + { +#ifdef WITH_CUDA + return ms_deform_attn_cuda_backward( + value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} + diff --git a/models/ops/src/vision.cpp b/models/ops/src/vision.cpp new file mode 100644 index 0000000..2201f63 --- /dev/null +++ b/models/ops/src/vision.cpp @@ -0,0 +1,16 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#include "ms_deform_attn.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward"); + m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward"); +} diff --git a/models/ops/test.py b/models/ops/test.py new file mode 100644 index 0000000..8dbf6d5 --- /dev/null +++ b/models/ops/test.py @@ -0,0 +1,89 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import time +import torch +import torch.nn as nn +from torch.autograd import gradcheck + +from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch + + +N, M, D = 1, 2, 2 +Lq, L, P = 2, 2, 2 +shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda() +level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1])) +S = sum([(H*W).item() for H, W in shapes]) + + +torch.manual_seed(3) + + +@torch.no_grad() +def check_forward_equal_with_pytorch_double(): + value = torch.rand(N, S, M, D).cuda() * 0.01 + sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() + attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 + attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) + im2col_step = 2 + output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu() + output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu() + fwdok = torch.allclose(output_cuda, output_pytorch) + max_abs_err = (output_cuda - output_pytorch).abs().max() + max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() + + print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') + + +@torch.no_grad() +def check_forward_equal_with_pytorch_float(): + value = torch.rand(N, S, M, D).cuda() * 0.01 + sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() + attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 + attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) + im2col_step = 2 + output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu() + output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu() + fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3) + max_abs_err = (output_cuda - output_pytorch).abs().max() + max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() + + print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') + + +def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True): + + value = torch.rand(N, S, M, channels).cuda() * 0.01 + sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() + attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 + attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) + im2col_step = 2 + func = MSDeformAttnFunction.apply + + value.requires_grad = grad_value + sampling_locations.requires_grad = grad_sampling_loc + attention_weights.requires_grad = grad_attn_weight + + gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step)) + + print(f'* {gradok} check_gradient_numerical(D={channels})') + + +if __name__ == '__main__': + check_forward_equal_with_pytorch_double() + check_forward_equal_with_pytorch_float() + + for channels in [30, 32, 64, 71, 1025, 2048, 3096]: + check_gradient_numerical(channels, True, True, True) + + + diff --git a/models/position_encoding.py b/models/position_encoding.py new file mode 100644 index 0000000..82c7cba --- /dev/null +++ b/models/position_encoding.py @@ -0,0 +1,113 @@ +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +""" +Various positional encodings for the transformer. +""" +import math +import torch +from torch import nn + +from util.misc import NestedTensor + + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. + """ + + def __init__( + self, num_pos_feats=64, temperature=10000, normalize=False, scale=None + ): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, tensor_list: NestedTensor): + x = tensor_list.tensors + mask = tensor_list.mask + assert mask is not None + not_mask = ~mask + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + +class PositionEmbeddingLearned(nn.Module): + """ + Absolute pos embedding, learned. + """ + + def __init__(self, num_pos_feats=256): + super().__init__() + self.row_embed = nn.Embedding(50, num_pos_feats) + self.col_embed = nn.Embedding(50, num_pos_feats) + self.reset_parameters() + + def reset_parameters(self): + nn.init.uniform_(self.row_embed.weight) + nn.init.uniform_(self.col_embed.weight) + + def forward(self, tensor_list: NestedTensor): + x = tensor_list.tensors + h, w = x.shape[-2:] + i = torch.arange(w, device=x.device) + j = torch.arange(h, device=x.device) + x_emb = self.col_embed(i) + y_emb = self.row_embed(j) + pos = ( + torch.cat( + [ + x_emb.unsqueeze(0).repeat(h, 1, 1), + y_emb.unsqueeze(1).repeat(1, w, 1), + ], + dim=-1, + ) + .permute(2, 0, 1) + .unsqueeze(0) + .repeat(x.shape[0], 1, 1, 1) + ) + return pos + + +def build_position_encoding(args): + N_steps = args.hidden_dim // 2 + if args.position_embedding in ("v2", "sine"): + # TODO find a better way of exposing other arguments + position_embedding = PositionEmbeddingSine(N_steps, normalize=True) + elif args.position_embedding in ("v3", "learned"): + position_embedding = PositionEmbeddingLearned(N_steps) + else: + raise ValueError(f"not supported {args.position_embedding}") + + return position_embedding diff --git a/models/segmentation.py b/models/segmentation.py new file mode 100644 index 0000000..0100e44 --- /dev/null +++ b/models/segmentation.py @@ -0,0 +1,427 @@ +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +""" +This file provides the definition of the convolutional heads used to predict masks, as well as the losses +""" +import io +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +from PIL import Image + +import util.box_ops as box_ops +from util.misc import NestedTensor, interpolate, nested_tensor_from_tensor_list + +try: + from panopticapi.utils import id2rgb, rgb2id +except ImportError: + pass + + +class DETRsegm(nn.Module): + def __init__(self, detr, freeze_detr=False): + super().__init__() + self.detr = detr + + if freeze_detr: + for p in self.parameters(): + p.requires_grad_(False) + + hidden_dim, nheads = detr.transformer.d_model, detr.transformer.nhead + self.bbox_attention = MHAttentionMap(hidden_dim, hidden_dim, nheads, dropout=0) + self.mask_head = MaskHeadSmallConv( + hidden_dim + nheads, [1024, 512, 256], hidden_dim + ) + + def forward(self, samples: NestedTensor): + if not isinstance(samples, NestedTensor): + samples = nested_tensor_from_tensor_list(samples) + features, pos = self.detr.backbone(samples) + + bs = features[-1].tensors.shape[0] + + src, mask = features[-1].decompose() + src_proj = self.detr.input_proj(src) + hs, memory = self.detr.transformer( + src_proj, mask, self.detr.query_embed.weight, pos[-1] + ) + + outputs_class = self.detr.class_embed(hs) + outputs_coord = self.detr.bbox_embed(hs).sigmoid() + out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord[-1]} + if self.detr.aux_loss: + out["aux_outputs"] = [ + {"pred_logits": a, "pred_boxes": b} + for a, b in zip(outputs_class[:-1], outputs_coord[:-1]) + ] + + # FIXME h_boxes takes the last one computed, keep this in mind + bbox_mask = self.bbox_attention(hs[-1], memory, mask=mask) + + seg_masks = self.mask_head( + src_proj, + bbox_mask, + [features[2].tensors, features[1].tensors, features[0].tensors], + ) + outputs_seg_masks = seg_masks.view( + bs, self.detr.num_queries, seg_masks.shape[-2], seg_masks.shape[-1] + ) + + out["pred_masks"] = outputs_seg_masks + return out + + +class MaskHeadSmallConv(nn.Module): + """ + Simple convolutional head, using group norm. + Upsampling is done using a FPN approach + """ + + def __init__(self, dim, fpn_dims, context_dim): + super().__init__() + + inter_dims = [ + dim, + context_dim // 2, + context_dim // 4, + context_dim // 8, + context_dim // 16, + context_dim // 64, + ] + self.lay1 = torch.nn.Conv2d(dim, dim, 3, padding=1) + self.gn1 = torch.nn.GroupNorm(8, dim) + self.lay2 = torch.nn.Conv2d(dim, inter_dims[1], 3, padding=1) + self.gn2 = torch.nn.GroupNorm(8, inter_dims[1]) + self.lay3 = torch.nn.Conv2d(inter_dims[1], inter_dims[2], 3, padding=1) + self.gn3 = torch.nn.GroupNorm(8, inter_dims[2]) + self.lay4 = torch.nn.Conv2d(inter_dims[2], inter_dims[3], 3, padding=1) + self.gn4 = torch.nn.GroupNorm(8, inter_dims[3]) + self.lay5 = torch.nn.Conv2d(inter_dims[3], inter_dims[4], 3, padding=1) + self.gn5 = torch.nn.GroupNorm(8, inter_dims[4]) + self.out_lay = torch.nn.Conv2d(inter_dims[4], 1, 3, padding=1) + + self.dim = dim + + self.adapter1 = torch.nn.Conv2d(fpn_dims[0], inter_dims[1], 1) + self.adapter2 = torch.nn.Conv2d(fpn_dims[1], inter_dims[2], 1) + self.adapter3 = torch.nn.Conv2d(fpn_dims[2], inter_dims[3], 1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_uniform_(m.weight, a=1) + nn.init.constant_(m.bias, 0) + + def forward(self, x, bbox_mask, fpns): + def expand(tensor, length): + return tensor.unsqueeze(1).repeat(1, int(length), 1, 1, 1).flatten(0, 1) + + x = torch.cat([expand(x, bbox_mask.shape[1]), bbox_mask.flatten(0, 1)], 1) + + x = self.lay1(x) + x = self.gn1(x) + x = F.relu(x) + x = self.lay2(x) + x = self.gn2(x) + x = F.relu(x) + + cur_fpn = self.adapter1(fpns[0]) + if cur_fpn.size(0) != x.size(0): + cur_fpn = expand(cur_fpn, x.size(0) / cur_fpn.size(0)) + x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest") + x = self.lay3(x) + x = self.gn3(x) + x = F.relu(x) + + cur_fpn = self.adapter2(fpns[1]) + if cur_fpn.size(0) != x.size(0): + cur_fpn = expand(cur_fpn, x.size(0) / cur_fpn.size(0)) + x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest") + x = self.lay4(x) + x = self.gn4(x) + x = F.relu(x) + + cur_fpn = self.adapter3(fpns[2]) + if cur_fpn.size(0) != x.size(0): + cur_fpn = expand(cur_fpn, x.size(0) / cur_fpn.size(0)) + x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest") + x = self.lay5(x) + x = self.gn5(x) + x = F.relu(x) + + x = self.out_lay(x) + return x + + +class MHAttentionMap(nn.Module): + """This is a 2D attention module, which only returns the attention softmax (no multiplication by value)""" + + def __init__(self, query_dim, hidden_dim, num_heads, dropout=0, bias=True): + super().__init__() + self.num_heads = num_heads + self.hidden_dim = hidden_dim + self.dropout = nn.Dropout(dropout) + + self.q_linear = nn.Linear(query_dim, hidden_dim, bias=bias) + self.k_linear = nn.Linear(query_dim, hidden_dim, bias=bias) + + nn.init.zeros_(self.k_linear.bias) + nn.init.zeros_(self.q_linear.bias) + nn.init.xavier_uniform_(self.k_linear.weight) + nn.init.xavier_uniform_(self.q_linear.weight) + self.normalize_fact = float(hidden_dim / self.num_heads) ** -0.5 + + def forward(self, q, k, mask=None): + q = self.q_linear(q) + k = F.conv2d( + k, self.k_linear.weight.unsqueeze(-1).unsqueeze(-1), self.k_linear.bias + ) + qh = q.view( + q.shape[0], q.shape[1], self.num_heads, self.hidden_dim // self.num_heads + ) + kh = k.view( + k.shape[0], + self.num_heads, + self.hidden_dim // self.num_heads, + k.shape[-2], + k.shape[-1], + ) + weights = torch.einsum("bqnc,bnchw->bqnhw", qh * self.normalize_fact, kh) + + if mask is not None: + weights.masked_fill_(mask.unsqueeze(1).unsqueeze(1), float("-inf")) + weights = F.softmax(weights.flatten(2), dim=-1).view_as(weights) + weights = self.dropout(weights) + return weights + + +def dice_loss(inputs, targets, num_boxes): + """ + Compute the DICE loss, similar to generalized IOU for masks + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + """ + inputs = inputs.sigmoid() + inputs = inputs.flatten(1) + numerator = 2 * (inputs * targets).sum(1) + denominator = inputs.sum(-1) + targets.sum(-1) + loss = 1 - (numerator + 1) / (denominator + 1) + return loss.sum() / num_boxes + + +def sigmoid_focal_loss( + inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2 +): + """ + Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + alpha: (optional) Weighting factor in range (0,1) to balance + positive vs negative examples. Default = -1 (no weighting). + gamma: Exponent of the modulating factor (1 - p_t) to + balance easy vs hard examples. + Returns: + Loss tensor + """ + prob = inputs.sigmoid() + ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") + p_t = prob * targets + (1 - prob) * (1 - targets) + loss = ce_loss * ((1 - p_t) ** gamma) + + if alpha >= 0: + alpha_t = alpha * targets + (1 - alpha) * (1 - targets) + loss = alpha_t * loss + + return loss.mean(1).sum() / num_boxes + + +class PostProcessSegm(nn.Module): + def __init__(self, threshold=0.5): + super().__init__() + self.threshold = threshold + + @torch.no_grad() + def forward(self, results, outputs, orig_target_sizes, max_target_sizes): + assert len(orig_target_sizes) == len(max_target_sizes) + max_h, max_w = max_target_sizes.max(0)[0].tolist() + outputs_masks = outputs["pred_masks"].squeeze(2) + outputs_masks = F.interpolate( + outputs_masks, size=(max_h, max_w), mode="bilinear", align_corners=False + ) + outputs_masks = (outputs_masks.sigmoid() > self.threshold).cpu() + + for i, (cur_mask, t, tt) in enumerate( + zip(outputs_masks, max_target_sizes, orig_target_sizes) + ): + img_h, img_w = t[0], t[1] + results[i]["masks"] = cur_mask[:, :img_h, :img_w].unsqueeze(1) + results[i]["masks"] = F.interpolate( + results[i]["masks"].float(), size=tuple(tt.tolist()), mode="nearest" + ).byte() + + return results + + +class PostProcessPanoptic(nn.Module): + """This class converts the output of the model to the final panoptic result, in the format expected by the + coco panoptic API """ + + def __init__(self, is_thing_map, threshold=0.85): + """ + Parameters: + is_thing_map: This is a whose keys are the class ids, and the values a boolean indicating whether + the class is a thing (True) or a stuff (False) class + threshold: confidence threshold: segments with confidence lower than this will be deleted + """ + super().__init__() + self.threshold = threshold + self.is_thing_map = is_thing_map + + def forward(self, outputs, processed_sizes, target_sizes=None): + """ This function computes the panoptic prediction from the model's predictions. + Parameters: + outputs: This is a dict coming directly from the model. See the model doc for the content. + processed_sizes: This is a list of tuples (or torch tensors) of sizes of the images that were passed to the + model, ie the size after data augmentation but before batching. + target_sizes: This is a list of tuples (or torch tensors) corresponding to the requested final size + of each prediction. If left to None, it will default to the processed_sizes + """ + if target_sizes is None: + target_sizes = processed_sizes + assert len(processed_sizes) == len(target_sizes) + out_logits, raw_masks, raw_boxes = ( + outputs["pred_logits"], + outputs["pred_masks"], + outputs["pred_boxes"], + ) + assert len(out_logits) == len(raw_masks) == len(target_sizes) + preds = [] + + def to_tuple(tup): + if isinstance(tup, tuple): + return tup + return tuple(tup.cpu().tolist()) + + for cur_logits, cur_masks, cur_boxes, size, target_size in zip( + out_logits, raw_masks, raw_boxes, processed_sizes, target_sizes + ): + # we filter empty queries and detection below threshold + scores, labels = cur_logits.softmax(-1).max(-1) + keep = labels.ne(outputs["pred_logits"].shape[-1] - 1) & ( + scores > self.threshold + ) + cur_scores, cur_classes = cur_logits.softmax(-1).max(-1) + cur_scores = cur_scores[keep] + cur_classes = cur_classes[keep] + cur_masks = cur_masks[keep] + cur_masks = interpolate( + cur_masks[None], to_tuple(size), mode="bilinear" + ).squeeze(0) + cur_boxes = box_ops.box_cxcywh_to_xyxy(cur_boxes[keep]) + + h, w = cur_masks.shape[-2:] + assert len(cur_boxes) == len(cur_classes) + + # It may be that we have several predicted masks for the same stuff class. + # In the following, we track the list of masks ids for each stuff class (they are merged later on) + cur_masks = cur_masks.flatten(1) + stuff_equiv_classes = defaultdict(lambda: []) + for k, label in enumerate(cur_classes): + if not self.is_thing_map[label.item()]: + stuff_equiv_classes[label.item()].append(k) + + def get_ids_area(masks, scores, dedup=False): + # This helper function creates the final panoptic segmentation image + # It also returns the area of the masks that appears on the image + + m_id = masks.transpose(0, 1).softmax(-1) + + if m_id.shape[-1] == 0: + # We didn't detect any mask :( + m_id = torch.zeros((h, w), dtype=torch.long, device=m_id.device) + else: + m_id = m_id.argmax(-1).view(h, w) + + if dedup: + # Merge the masks corresponding to the same stuff class + for equiv in stuff_equiv_classes.values(): + if len(equiv) > 1: + for eq_id in equiv: + m_id.masked_fill_(m_id.eq(eq_id), equiv[0]) + + final_h, final_w = to_tuple(target_size) + + seg_img = Image.fromarray(id2rgb(m_id.view(h, w).cpu().numpy())) + seg_img = seg_img.resize( + size=(final_w, final_h), resample=Image.NEAREST + ) + + np_seg_img = ( + torch.ByteTensor(torch.ByteStorage.from_buffer(seg_img.tobytes())) + .view(final_h, final_w, 3) + .numpy() + ) + m_id = torch.from_numpy(rgb2id(np_seg_img)) + + area = [] + for i in range(len(scores)): + area.append(m_id.eq(i).sum().item()) + return area, seg_img + + area, seg_img = get_ids_area(cur_masks, cur_scores, dedup=True) + if cur_classes.numel() > 0: + # We know filter empty masks as long as we find some + while True: + filtered_small = torch.as_tensor( + [area[i] <= 4 for i, c in enumerate(cur_classes)], + dtype=torch.bool, + device=keep.device, + ) + if filtered_small.any().item(): + cur_scores = cur_scores[~filtered_small] + cur_classes = cur_classes[~filtered_small] + cur_masks = cur_masks[~filtered_small] + area, seg_img = get_ids_area(cur_masks, cur_scores) + else: + break + + else: + cur_classes = torch.ones(1, dtype=torch.long, device=cur_classes.device) + + segments_info = [] + for i, a in enumerate(area): + cat = cur_classes[i].item() + segments_info.append( + { + "id": i, + "isthing": self.is_thing_map[cat], + "category_id": cat, + "area": a, + } + ) + del cur_classes + + with io.BytesIO() as out: + seg_img.save(out, format="PNG") + predictions = { + "png_string": out.getvalue(), + "segments_info": segments_info, + } + preds.append(predictions) + return preds diff --git a/models/swin_transformer.py b/models/swin_transformer.py new file mode 100644 index 0000000..2856682 --- /dev/null +++ b/models/swin_transformer.py @@ -0,0 +1,743 @@ +# -------------------------------------------------------- +# Swin Transformer +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Ze Liu, Yutong Lin, Yixuan Wei +# -------------------------------------------------------- + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +import numpy as np +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + +from mmcv_custom import load_checkpoint +from mmdet.utils import get_root_logger + + +class Mlp(nn.Module): + """ Multilayer perceptron.""" + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.0, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = ( + x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + ) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view( + B, H // window_size, W // window_size, window_size, window_size, -1 + ) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + """ Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__( + self, + dim, + window_size, + num_heads, + qkv_bias=True, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + ): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads) + ) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = ( + coords_flatten[:, :, None] - coords_flatten[:, None, :] + ) # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute( + 1, 2, 0 + ).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=0.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ Forward function. + + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B_, N, 3, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = ( + qkv[0], + qkv[1], + qkv[2], + ) # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = q @ k.transpose(-2, -1) + + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1) + ].view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], + -1, + ) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1 + ).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze( + 1 + ).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class SwinTransformerBlock(nn.Module): + """ Swin Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__( + self, + dim, + num_heads, + window_size=7, + shift_size=0, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + ): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + assert ( + 0 <= self.shift_size < self.window_size + ), "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, + window_size=to_2tuple(self.window_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + + self.H = None + self.W = None + + def forward(self, x, mask_matrix): + """ Forward function. + + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + mask_matrix: Attention mask for cyclic shift. + """ + B, L, C = x.shape + H, W = self.H, self.W + assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # pad feature maps to multiples of window size + pad_l = pad_t = 0 + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll( + x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2) + ) + attn_mask = mask_matrix + else: + shifted_x = x + attn_mask = None + + # partition windows + x_windows = window_partition( + shifted_x, self.window_size + ) # nW*B, window_size, window_size, C + x_windows = x_windows.view( + -1, self.window_size * self.window_size, C + ) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn( + x_windows, mask=attn_mask + ) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll( + shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2) + ) + else: + x = shifted_x + + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + +class PatchMerging(nn.Module): + """ Patch Merging Layer + + Args: + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x, H, W): + """ Forward function. + + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + """ + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + x = x.view(B, H, W, C) + + # padding + pad_input = (H % 2 == 1) or (W % 2 == 1) + if pad_input: + x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of feature channels + depth (int): Depths of this stage. + num_heads (int): Number of attention head. + window_size (int): Local window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__( + self, + dim, + depth, + num_heads, + window_size=7, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False, + ): + super().__init__() + self.window_size = window_size + self.shift_size = window_size // 2 + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList( + [ + SwinTransformerBlock( + dim=dim, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] + if isinstance(drop_path, list) + else drop_path, + norm_layer=norm_layer, + ) + for i in range(depth) + ] + ) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, H, W): + """ Forward function. + + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + """ + + # calculate attention mask for SW-MSA + Hp = int(np.ceil(H / self.window_size)) * self.window_size + Wp = int(np.ceil(W / self.window_size)) * self.window_size + img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 + h_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + w_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition( + img_mask, self.window_size + ) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill( + attn_mask == 0, float(0.0) + ) + + for blk in self.blocks: + blk.H, blk.W = H, W + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, attn_mask) + else: + x = blk(x, attn_mask) + if self.downsample is not None: + x_down = self.downsample(x, H, W) + Wh, Ww = (H + 1) // 2, (W + 1) // 2 + return x, H, W, x_down, Wh, Ww + else: + return x, H, W, x, H, W + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + + Args: + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + patch_size = to_2tuple(patch_size) + self.patch_size = patch_size + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size + ) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + """Forward function.""" + # padding + _, _, H, W = x.size() + if W % self.patch_size[1] != 0: + x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) + if H % self.patch_size[0] != 0: + x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) + + x = self.proj(x) # B C Wh Ww + if self.norm is not None: + Wh, Ww = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww) + + return x + + +class SwinTransformer(nn.Module): + """ Swin Transformer backbone. + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + + Args: + pretrain_img_size (int): Input image size for training the pretrained model, + used in absolute postion embedding. Default 224. + patch_size (int | tuple(int)): Patch size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + depths (tuple[int]): Depths of each Swin Transformer stage. + num_heads (tuple[int]): Number of attention head of each stage. + window_size (int): Window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. + drop_rate (float): Dropout rate. + attn_drop_rate (float): Attention dropout rate. Default: 0. + drop_path_rate (float): Stochastic depth rate. Default: 0.2. + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False. + patch_norm (bool): If True, add normalization after patch embedding. Default: True. + out_indices (Sequence[int]): Output from which stages. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__( + self, + pretrain_img_size=224, + patch_size=4, + in_chans=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + ape=False, + patch_norm=True, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + use_checkpoint=False, + ): + super().__init__() + self.drop_path_rate = drop_path_rate + self.pretrain_img_size = pretrain_img_size + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.out_indices = out_indices + self.frozen_stages = frozen_stages + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None, + ) + + # absolute position embedding + if self.ape: + pretrain_img_size = to_2tuple(pretrain_img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [ + pretrain_img_size[0] // patch_size[0], + pretrain_img_size[1] // patch_size[1], + ] + + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]) + ) + trunc_normal_(self.absolute_pos_embed, std=0.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) + ] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer( + dim=int(embed_dim * 2 ** i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint, + ) + self.layers.append(layer) + + num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] + self.num_features = num_features + + # add a norm layer for each output + for i_layer in out_indices: + layer = norm_layer(num_features[i_layer]) + layer_name = f"norm{i_layer}" + self.add_module(layer_name, layer) + + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1 and self.ape: + self.absolute_pos_embed.requires_grad = False + + if self.frozen_stages >= 2: + self.pos_drop.eval() + for i in range(0, self.frozen_stages - 1): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + if isinstance(pretrained, str): + self.apply(_init_weights) + logger = get_root_logger() + load_checkpoint(self, pretrained, strict=False, logger=logger) + elif pretrained is None: + self.apply(_init_weights) + else: + raise TypeError("pretrained must be a str or None") + + def forward(self, x): + """Forward function.""" + x = self.patch_embed(x) + + Wh, Ww = x.size(2), x.size(3) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed = F.interpolate( + self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic" + ) + x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C + else: + x = x.flatten(2).transpose(1, 2) + x = self.pos_drop(x) + + outs = {} + for i in range(self.num_layers): + layer = self.layers[i] + x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww) + + if i in self.out_indices: + norm_layer = getattr(self, f"norm{i}") + x_out = norm_layer(x_out) + + out = ( + x_out.view(-1, H, W, self.num_features[i]) + .permute(0, 3, 1, 2) + .contiguous() + ) + outs[str(i)] = out + + return outs + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(SwinTransformer, self).train(mode) + self._freeze_stages() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..db93a67 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,5 @@ +pycocotools +tqdm +cython +scipy +wandb diff --git a/tools/launch.py b/tools/launch.py new file mode 100644 index 0000000..7ded7fc --- /dev/null +++ b/tools/launch.py @@ -0,0 +1,216 @@ +# -------------------------------------------------------------------------------------------------------------------------- +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# -------------------------------------------------------------------------------------------------------------------------- +# Modified from https://github.com/pytorch/pytorch/blob/173f224570017b4b1a3a1a13d0bff280a54d9cd9/torch/distributed/launch.py +# -------------------------------------------------------------------------------------------------------------------------- + +r""" +`torch.distributed.launch` is a module that spawns up multiple distributed +training processes on each of the training nodes. +The utility can be used for single-node distributed training, in which one or +more processes per node will be spawned. The utility can be used for either +CPU training or GPU training. If the utility is used for GPU training, +each distributed process will be operating on a single GPU. This can achieve +well-improved single-node training performance. It can also be used in +multi-node distributed training, by spawning up multiple processes on each node +for well-improved multi-node distributed training performance as well. +This will especially be benefitial for systems with multiple Infiniband +interfaces that have direct-GPU support, since all of them can be utilized for +aggregated communication bandwidth. +In both cases of single-node distributed training or multi-node distributed +training, this utility will launch the given number of processes per node +(``--nproc_per_node``). If used for GPU training, this number needs to be less +or euqal to the number of GPUs on the current system (``nproc_per_node``), +and each process will be operating on a single GPU from *GPU 0 to +GPU (nproc_per_node - 1)*. +**How to use this module:** +1. Single-Node multi-process distributed training +:: + >>> python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_YOU_HAVE + YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all other + arguments of your training script) +2. Multi-Node multi-process distributed training: (e.g. two nodes) +Node 1: *(IP: 192.168.1.1, and has a free port: 1234)* +:: + >>> python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_YOU_HAVE + --nnodes=2 --node_rank=0 --master_addr="192.168.1.1" + --master_port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 + and all other arguments of your training script) +Node 2: +:: + >>> python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_YOU_HAVE + --nnodes=2 --node_rank=1 --master_addr="192.168.1.1" + --master_port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 + and all other arguments of your training script) +3. To look up what optional arguments this module offers: +:: + >>> python -m torch.distributed.launch --help +**Important Notices:** +1. This utilty and multi-process distributed (single-node or +multi-node) GPU training currently only achieves the best performance using +the NCCL distributed backend. Thus NCCL backend is the recommended backend to +use for GPU training. +2. In your training program, you must parse the command-line argument: +``--local_rank=LOCAL_PROCESS_RANK``, which will be provided by this module. +If your training program uses GPUs, you should ensure that your code only +runs on the GPU device of LOCAL_PROCESS_RANK. This can be done by: +Parsing the local_rank argument +:: + >>> import argparse + >>> parser = argparse.ArgumentParser() + >>> parser.add_argument("--local_rank", type=int) + >>> args = parser.parse_args() +Set your device to local rank using either +:: + >>> torch.cuda.set_device(arg.local_rank) # before your code runs +or +:: + >>> with torch.cuda.device(arg.local_rank): + >>> # your code to run +3. In your training program, you are supposed to call the following function +at the beginning to start the distributed backend. You need to make sure that +the init_method uses ``env://``, which is the only supported ``init_method`` +by this module. +:: + torch.distributed.init_process_group(backend='YOUR BACKEND', + init_method='env://') +4. In your training program, you can either use regular distributed functions +or use :func:`torch.nn.parallel.DistributedDataParallel` module. If your +training program uses GPUs for training and you would like to use +:func:`torch.nn.parallel.DistributedDataParallel` module, +here is how to configure it. +:: + model = torch.nn.parallel.DistributedDataParallel(model, + device_ids=[arg.local_rank], + output_device=arg.local_rank) +Please ensure that ``device_ids`` argument is set to be the only GPU device id +that your code will be operating on. This is generally the local rank of the +process. In other words, the ``device_ids`` needs to be ``[args.local_rank]``, +and ``output_device`` needs to be ``args.local_rank`` in order to use this +utility +5. Another way to pass ``local_rank`` to the subprocesses via environment variable +``LOCAL_RANK``. This behavior is enabled when you launch the script with +``--use_env=True``. You must adjust the subprocess example above to replace +``args.local_rank`` with ``os.environ['LOCAL_RANK']``; the launcher +will not pass ``--local_rank`` when you specify this flag. +.. warning:: + ``local_rank`` is NOT globally unique: it is only unique per process + on a machine. Thus, don't use it to decide if you should, e.g., + write to a networked filesystem. See + https://github.com/pytorch/pytorch/issues/12042 for an example of + how things can go wrong if you don't do this correctly. +""" + + +import sys +import subprocess +import os +import socket +from argparse import ArgumentParser, REMAINDER + +import torch + + +def parse_args(): + """ + Helper function parsing the command line options + @retval ArgumentParser + """ + parser = ArgumentParser( + description="PyTorch distributed training launch " + "helper utilty that will spawn up " + "multiple distributed processes" + ) + + # Optional arguments for the launch helper + parser.add_argument( + "--nnodes", + type=int, + default=1, + help="The number of nodes to use for distributed " "training", + ) + parser.add_argument( + "--node_rank", + type=int, + default=0, + help="The rank of the node for multi-node distributed " "training", + ) + parser.add_argument( + "--nproc_per_node", + type=int, + default=1, + help="The number of processes to launch on each node, " + "for GPU training, this is recommended to be set " + "to the number of GPUs in your system so that " + "each process can be bound to a single GPU.", + ) + parser.add_argument( + "--master_addr", + default="127.0.0.1", + type=str, + help="Master node (rank 0)'s address, should be either " + "the IP address or the hostname of node 0, for " + "single node multi-proc training, the " + "--master_addr can simply be 127.0.0.1", + ) + parser.add_argument( + "--master_port", + default=29500, + type=int, + help="Master node (rank 0)'s free port that needs to " + "be used for communciation during distributed " + "training", + ) + + # positional + parser.add_argument( + "training_script", + type=str, + help="The full path to the single GPU training " + "program/script to be launched in parallel, " + "followed by all the arguments for the " + "training script", + ) + + # rest from the training program + parser.add_argument("training_script_args", nargs=REMAINDER) + return parser.parse_args() + + +def main(): + args = parse_args() + + # world size in terms of number of processes + dist_world_size = args.nproc_per_node * args.nnodes + + # set PyTorch distributed related environmental variables + current_env = os.environ.copy() + current_env["MASTER_ADDR"] = args.master_addr + current_env["MASTER_PORT"] = str(args.master_port) + current_env["WORLD_SIZE"] = str(dist_world_size) + + processes = [] + + for local_rank in range(0, args.nproc_per_node): + # each process's rank + dist_rank = args.nproc_per_node * args.node_rank + local_rank + current_env["RANK"] = str(dist_rank) + current_env["LOCAL_RANK"] = str(local_rank) + + cmd = [args.training_script] + args.training_script_args + + process = subprocess.Popen(cmd, env=current_env) + processes.append(process) + + for process in processes: + process.wait() + if process.returncode != 0: + raise subprocess.CalledProcessError( + returncode=process.returncode, cmd=process.args + ) + + +if __name__ == "__main__": + main() diff --git a/tools/run_dist_launch.sh b/tools/run_dist_launch.sh new file mode 100644 index 0000000..6c44d3a --- /dev/null +++ b/tools/run_dist_launch.sh @@ -0,0 +1,29 @@ +#!/usr/bin/env bash +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ + +set -x + +GPUS=$1 +RUN_COMMAND=${@:2} +if [ $GPUS -lt 8 ]; then + GPUS_PER_NODE=${GPUS_PER_NODE:-$GPUS} +else + GPUS_PER_NODE=${GPUS_PER_NODE:-8} +fi +MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} +MASTER_PORT=${MASTER_PORT:-"29500"} +NODE_RANK=${NODE_RANK:-0} + +let "NNODES=GPUS/GPUS_PER_NODE" + +python ./tools/launch.py \ + --nnodes ${NNODES} \ + --node_rank ${NODE_RANK} \ + --master_addr ${MASTER_ADDR} \ + --master_port ${MASTER_PORT} \ + --nproc_per_node ${GPUS_PER_NODE} \ + ${RUN_COMMAND} diff --git a/tools/run_dist_slurm.sh b/tools/run_dist_slurm.sh new file mode 100644 index 0000000..bd73d0b --- /dev/null +++ b/tools/run_dist_slurm.sh @@ -0,0 +1,33 @@ +#!/usr/bin/env bash +# -------------------------------------------------------------------------------------------------------------------------- +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# -------------------------------------------------------------------------------------------------------------------------- +# Modified from https://github.com/open-mmlab/mmdetection/blob/3b53fe15d87860c6941f3dda63c0f27422da6266/tools/slurm_train.sh +# -------------------------------------------------------------------------------------------------------------------------- + +set -x + +PARTITION=$1 +JOB_NAME=$2 +GPUS=$3 +RUN_COMMAND=${@:4} +if [ $GPUS -lt 8 ]; then + GPUS_PER_NODE=${GPUS_PER_NODE:-$GPUS} +else + GPUS_PER_NODE=${GPUS_PER_NODE:-8} +fi +CPUS_PER_TASK=${CPUS_PER_TASK:-4} +SRUN_ARGS=${SRUN_ARGS:-""} + +srun -p ${PARTITION} \ + --job-name=${JOB_NAME} \ + --gres=gpu:${GPUS_PER_NODE} \ + --ntasks=${GPUS} \ + --ntasks-per-node=${GPUS_PER_NODE} \ + --cpus-per-task=${CPUS_PER_TASK} \ + --kill-on-bad-exit=1 \ + ${SRUN_ARGS} \ + ${RUN_COMMAND} + diff --git a/util/__init__.py b/util/__init__.py new file mode 100644 index 0000000..4ebdc90 --- /dev/null +++ b/util/__init__.py @@ -0,0 +1,8 @@ +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ diff --git a/util/box_ops.py b/util/box_ops.py new file mode 100644 index 0000000..de09980 --- /dev/null +++ b/util/box_ops.py @@ -0,0 +1,94 @@ +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +""" +Utilities for bounding box manipulation and GIoU. +""" +import torch +from torchvision.ops.boxes import box_area + + +def box_cxcywh_to_xyxy(x): + x_c, y_c, w, h = x.unbind(-1) + b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)] + return torch.stack(b, dim=-1) + + +def box_xyxy_to_cxcywh(x): + x0, y0, x1, y1 = x.unbind(-1) + b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)] + return torch.stack(b, dim=-1) + + +# modified from torchvision to also return the union +def box_iou(boxes1, boxes2): + area1 = box_area(boxes1) + area2 = box_area(boxes2) + + lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] + rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] + + wh = (rb - lt).clamp(min=0) # [N,M,2] + inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] + + union = area1[:, None] + area2 - inter + + iou = inter / union + return iou, union + + +def generalized_box_iou(boxes1, boxes2): + """ + Generalized IoU from https://giou.stanford.edu/ + + The boxes should be in [x0, y0, x1, y1] format + + Returns a [N, M] pairwise matrix, where N = len(boxes1) + and M = len(boxes2) + """ + # degenerate boxes gives inf / nan results + # so do an early check + assert (boxes1[:, 2:] >= boxes1[:, :2]).all() + assert (boxes2[:, 2:] >= boxes2[:, :2]).all() + iou, union = box_iou(boxes1, boxes2) + + lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) + rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) + + wh = (rb - lt).clamp(min=0) # [N,M,2] + area = wh[:, :, 0] * wh[:, :, 1] + + return iou - (area - union) / area + + +def masks_to_boxes(masks): + """Compute the bounding boxes around the provided masks + + The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. + + Returns a [N, 4] tensors, with the boxes in xyxy format + """ + if masks.numel() == 0: + return torch.zeros((0, 4), device=masks.device) + + h, w = masks.shape[-2:] + + y = torch.arange(0, h, dtype=torch.float) + x = torch.arange(0, w, dtype=torch.float) + y, x = torch.meshgrid(y, x) + + x_mask = masks * x.unsqueeze(0) + x_max = x_mask.flatten(1).max(-1)[0] + x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] + + y_mask = masks * y.unsqueeze(0) + y_max = y_mask.flatten(1).max(-1)[0] + y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] + + return torch.stack([x_min, y_min, x_max, y_max], 1) diff --git a/util/misc.py b/util/misc.py new file mode 100644 index 0000000..8d0fff7 --- /dev/null +++ b/util/misc.py @@ -0,0 +1,514 @@ +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +""" +Misc functions, including distributed helpers. + +Mostly copy-paste from torchvision references. +""" +import os +import subprocess +import time +from collections import defaultdict, deque +import datetime +import pickle +from typing import Optional, List + +import torch +import torch.nn as nn +import torch.distributed as dist +from torch import Tensor + +# needed due to empty tensor bug in pytorch and torchvision 0.5 +import torchvision + + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! + """ + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value, + ) + + +def all_gather(data): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors) + Args: + data: any picklable object + Returns: + list[data]: list of data gathered from each rank + """ + world_size = get_world_size() + if world_size == 1: + return [data] + + # serialized to a Tensor + buffer = pickle.dumps(data) + storage = torch.ByteStorage.from_buffer(buffer) + tensor = torch.ByteTensor(storage).to("cuda") + + # obtain Tensor size of each rank + local_size = torch.tensor([tensor.numel()], device="cuda") + size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)] + dist.all_gather(size_list, local_size) + size_list = [int(size.item()) for size in size_list] + max_size = max(size_list) + + # receiving Tensor from all ranks + # we pad the tensor because torch all_gather does not support + # gathering tensors of different shapes + tensor_list = [] + for _ in size_list: + tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda")) + if local_size != max_size: + padding = torch.empty( + size=(max_size - local_size,), dtype=torch.uint8, device="cuda" + ) + tensor = torch.cat((tensor, padding), dim=0) + dist.all_gather(tensor_list, tensor) + + data_list = [] + for size, tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + + return data_list + + +def reduce_dict(input_dict, average=True): + """ + Args: + input_dict (dict): all the values will be reduced + average (bool): whether to do average or sum + Reduce the values in the dictionary from all processes so that all processes + have the averaged results. Returns a dict with the same fields as + input_dict, after reduction. + """ + world_size = get_world_size() + if world_size < 2: + return input_dict + with torch.no_grad(): + names = [] + values = [] + # sort the keys so that they are consistent across processes + for k in sorted(input_dict.keys()): + names.append(k) + values.append(input_dict[k]) + values = torch.stack(values, dim=0) + dist.all_reduce(values) + if average: + values /= world_size + reduced_dict = {k: v for k, v in zip(names, values)} + return reduced_dict + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError( + "'{}' object has no attribute '{}'".format(type(self).__name__, attr) + ) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append("{}: {}".format(name, str(meter))) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + i = 0 + if not header: + header = "" + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt="{avg:.4f}") + data_time = SmoothedValue(fmt="{avg:.4f}") + space_fmt = ":" + str(len(str(len(iterable)))) + "d" + if torch.cuda.is_available(): + log_msg = self.delimiter.join( + [ + header, + "[{0" + space_fmt + "}/{1}]", + "eta: {eta}", + "{meters}", + "time: {time}", + "data: {data}", + "max mem: {memory:.0f}", + ] + ) + else: + log_msg = self.delimiter.join( + [ + header, + "[{0" + space_fmt + "}/{1}]", + "eta: {eta}", + "{meters}", + "time: {time}", + "data: {data}", + ] + ) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == len(iterable) - 1: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print( + log_msg.format( + i, + len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB, + ) + ) + else: + print( + log_msg.format( + i, + len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + ) + ) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print( + "{} Total time: {} ({:.4f} s / it)".format( + header, total_time_str, total_time / len(iterable) + ) + ) + + +def get_sha(): + cwd = os.path.dirname(os.path.abspath(__file__)) + + def _run(command): + return subprocess.check_output(command, cwd=cwd).decode("ascii").strip() + + sha = "N/A" + diff = "clean" + branch = "N/A" + try: + sha = _run(["git", "rev-parse", "HEAD"]) + subprocess.check_output(["git", "diff"], cwd=cwd) + diff = _run(["git", "diff-index", "HEAD"]) + diff = "has uncommited changes" if diff else "clean" + branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"]) + except Exception: + pass + message = f"sha: {sha}, status: {diff}, branch: {branch}" + return message + + +def collate_fn(batch): + batch = list(zip(*batch)) + batch[0] = nested_tensor_from_tensor_list(batch[0]) + return tuple(batch) + + +def _max_by_axis(the_list): + # type: (List[List[int]]) -> List[int] + maxes = the_list[0] + for sublist in the_list[1:]: + for index, item in enumerate(sublist): + maxes[index] = max(maxes[index], item) + return maxes + + +def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): + # TODO make this more general + if tensor_list[0].ndim == 3: + # TODO make it support different-sized images + max_size = _max_by_axis([list(img.shape) for img in tensor_list]) + # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) + batch_shape = [len(tensor_list)] + max_size + b, c, h, w = batch_shape + dtype = tensor_list[0].dtype + device = tensor_list[0].device + tensor = torch.zeros(batch_shape, dtype=dtype, device=device) + mask = torch.ones((b, h, w), dtype=torch.bool, device=device) + for img, pad_img, m in zip(tensor_list, tensor, mask): + pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + m[: img.shape[1], : img.shape[2]] = False + else: + raise ValueError("not supported") + return NestedTensor(tensor, mask) + + +class NestedTensor(object): + def __init__(self, tensors, mask: Optional[Tensor]): + self.tensors = tensors + self.mask = mask + + def to(self, device, non_blocking=False): + # type: (Device) -> NestedTensor # noqa + cast_tensor = self.tensors.to(device, non_blocking=non_blocking) + mask = self.mask + if mask is not None: + assert mask is not None + cast_mask = mask.to(device, non_blocking=non_blocking) + else: + cast_mask = None + return NestedTensor(cast_tensor, cast_mask) + + def record_stream(self, *args, **kwargs): + self.tensors.record_stream(*args, **kwargs) + if self.mask is not None: + self.mask.record_stream(*args, **kwargs) + + def decompose(self): + return self.tensors, self.mask + + def __repr__(self): + return str(self.tensors) + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ + + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop("force", False) + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def get_local_size(): + if not is_dist_avail_and_initialized(): + return 1 + return int(os.environ["LOCAL_SIZE"]) + + +def get_local_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return int(os.environ["LOCAL_RANK"]) + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def init_distributed_mode(args): + if "RANK" in os.environ and "WORLD_SIZE" in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ["WORLD_SIZE"]) + args.gpu = int(os.environ["LOCAL_RANK"]) + args.dist_url = "env://" + os.environ["LOCAL_SIZE"] = str(torch.cuda.device_count()) + elif "SLURM_PROCID" in os.environ: + proc_id = int(os.environ["SLURM_PROCID"]) + ntasks = int(os.environ["SLURM_NTASKS"]) + node_list = os.environ["SLURM_NODELIST"] + num_gpus = torch.cuda.device_count() + addr = subprocess.getoutput( + "scontrol show hostname {} | head -n1".format(node_list) + ) + os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "29500") + os.environ["MASTER_ADDR"] = addr + os.environ["WORLD_SIZE"] = str(ntasks) + os.environ["RANK"] = str(proc_id) + os.environ["LOCAL_RANK"] = str(proc_id % num_gpus) + os.environ["LOCAL_SIZE"] = str(num_gpus) + args.dist_url = "env://" + args.world_size = ntasks + args.rank = proc_id + args.gpu = proc_id % num_gpus + else: + print("Not using distributed mode") + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = "nccl" + print( + "| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True + ) + torch.distributed.init_process_group( + backend=args.dist_backend, + init_method=args.dist_url, + world_size=args.world_size, + rank=args.rank, + ) + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) + + +@torch.no_grad() +def accuracy(output, target, topk=(1,)): + """Computes the precision@k for the specified values of k""" + if target.numel() == 0: + return [torch.zeros([], device=output.device)] + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +def interpolate( + input, size=None, scale_factor=None, mode="nearest", align_corners=None +): + # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor + """ + Equivalent to nn.functional.interpolate, but with support for empty batch sizes. + This will eventually be supported natively by PyTorch, and this + class can go away. + """ + return torchvision.ops.misc.interpolate( + input, size, scale_factor, mode, align_corners + ) + + +def get_total_grad_norm(parameters, norm_type=2): + parameters = list(filter(lambda p: p.grad is not None, parameters)) + norm_type = float(norm_type) + device = parameters[0].grad.device + total_norm = torch.norm( + torch.stack( + [torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters] + ), + norm_type, + ) + return total_norm + + +def inverse_sigmoid(x, eps=1e-5): + x = x.clamp(min=0, max=1) + x1 = x.clamp(min=eps) + x2 = (1 - x).clamp(min=eps) + return torch.log(x1 / x2) + diff --git a/util/plot_utils.py b/util/plot_utils.py new file mode 100644 index 0000000..759f34d --- /dev/null +++ b/util/plot_utils.py @@ -0,0 +1,111 @@ +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +""" +Plotting utilities to visualize training logs. +""" +import torch +import pandas as pd +import seaborn as sns +import matplotlib.pyplot as plt + +from pathlib import Path, PurePath + + +def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col=0, log_name='log.txt'): + ''' + Function to plot specific fields from training log(s). Plots both training and test results. + + :: Inputs - logs = list containing Path objects, each pointing to individual dir with a log file + - fields = which results to plot from each log file - plots both training and test for each field. + - ewm_col = optional, which column to use as the exponential weighted smoothing of the plots + - log_name = optional, name of log file if different than default 'log.txt'. + + :: Outputs - matplotlib plots of results in fields, color coded for each log file. + - solid lines are training results, dashed lines are test results. + + ''' + func_name = "plot_utils.py::plot_logs" + + # verify logs is a list of Paths (list[Paths]) or single Pathlib object Path, + # convert single Path to list to avoid 'not iterable' error + + if not isinstance(logs, list): + if isinstance(logs, PurePath): + logs = [logs] + print(f"{func_name} info: logs param expects a list argument, converted to list[Path].") + else: + raise ValueError(f"{func_name} - invalid argument for logs parameter.\n \ + Expect list[Path] or single Path obj, received {type(logs)}") + + # verify valid dir(s) and that every item in list is Path object + for i, dir in enumerate(logs): + if not isinstance(dir, PurePath): + raise ValueError(f"{func_name} - non-Path object in logs argument of {type(dir)}: \n{dir}") + if dir.exists(): + continue + raise ValueError(f"{func_name} - invalid directory in logs argument:\n{dir}") + + # load log file(s) and plot + dfs = [pd.read_json(Path(p) / log_name, lines=True) for p in logs] + + fig, axs = plt.subplots(ncols=len(fields), figsize=(16, 5)) + + for df, color in zip(dfs, sns.color_palette(n_colors=len(logs))): + for j, field in enumerate(fields): + if field == 'mAP': + coco_eval = pd.DataFrame(pd.np.stack(df.test_coco_eval.dropna().values)[:, 1]).ewm(com=ewm_col).mean() + axs[j].plot(coco_eval, c=color) + else: + df.interpolate().ewm(com=ewm_col).mean().plot( + y=[f'train_{field}', f'test_{field}'], + ax=axs[j], + color=[color] * 2, + style=['-', '--'] + ) + for ax, field in zip(axs, fields): + ax.legend([Path(p).name for p in logs]) + ax.set_title(field) + + +def plot_precision_recall(files, naming_scheme='iter'): + if naming_scheme == 'exp_id': + # name becomes exp_id + names = [f.parts[-3] for f in files] + elif naming_scheme == 'iter': + names = [f.stem for f in files] + else: + raise ValueError(f'not supported {naming_scheme}') + fig, axs = plt.subplots(ncols=2, figsize=(16, 5)) + for f, color, name in zip(files, sns.color_palette("Blues", n_colors=len(files)), names): + data = torch.load(f) + # precision is n_iou, n_points, n_cat, n_area, max_det + precision = data['precision'] + recall = data['params'].recThrs + scores = data['scores'] + # take precision for all classes, all areas and 100 detections + precision = precision[0, :, :, 0, -1].mean(1) + scores = scores[0, :, :, 0, -1].mean(1) + prec = precision.mean() + rec = data['recall'][0, :, 0, -1].mean() + print(f'{naming_scheme} {name}: mAP@50={prec * 100: 05.1f}, ' + + f'score={scores.mean():0.3f}, ' + + f'f1={2 * prec * rec / (prec + rec + 1e-8):0.3f}' + ) + axs[0].plot(recall, precision, c=color) + axs[1].plot(recall, scores, c=color) + + axs[0].set_title('Precision / Recall') + axs[0].legend(names) + axs[1].set_title('Scores / Recall') + axs[1].legend(names) + return fig, axs + + +