diff --git a/thop/profile.py b/thop/profile.py index 6b15d27..8f1aece 100644 --- a/thop/profile.py +++ b/thop/profile.py @@ -203,6 +203,10 @@ def add_hooks(m: nn.Module): ) types_collection.add(m_type) + def remove_buffers(m: nn.Module): + m._buffers.pop("total_ops") + m._buffers.pop("total_params") + prev_training_status = model.training model.eval() @@ -239,8 +243,7 @@ def dfs_count(module: nn.Module, prefix="\t") -> (int, int): for m, (op_handler, params_handler) in handler_collection.items(): op_handler.remove() params_handler.remove() - m._buffers.pop("total_ops") - m._buffers.pop("total_params") + model.apply(remove_buffers) if ret_layer_info: return total_ops, total_params, ret_dict