Fix conversion with rtmdet-inst, vit, conformer (#2453)
* fix

* fix scaled_dot_product_attention
RunningLeon authored Sep 22, 2023
1 parent 01a88be commit 0c21b16
Showing 2 changed files with 36 additions and 6 deletions.
16 changes: 10 additions & 6 deletions mmdeploy/mmcv/ops/nms.py
@@ -610,19 +610,19 @@ def multiclass_nms__torchscript(boxes: Tensor,
    Use batched_nms from torchvision instead of custom nms.
    """
    assert not output_index, 'output_index is not supported on this backend.'
    # TODO: simplify inference for non-batch model
    from torchvision.ops import batched_nms
    batch_size = scores.shape[0]
    num_boxes = scores.shape[1]
    num_classes = scores.shape[2]
    box_per_cls = len(boxes.shape) == 4
    scores = torch.where(scores > score_threshold, scores, scores.new_zeros(1))

    pre_topk_inds = None
    # pre-topk
    if pre_top_k > 0:
        max_scores, _ = scores.max(-1)
        _, topk_inds = max_scores.topk(pre_top_k)
        pre_topk_inds = topk_inds
        batch_inds = torch.arange(batch_size).view(-1, 1).long()
        boxes = boxes[batch_inds, topk_inds, ...]
        scores = scores[batch_inds, topk_inds, :]
@@ -646,10 +646,14 @@ def multiclass_nms__torchscript(boxes: Tensor,

    keeps = torch.cat(keeps)
    scores = scores.permute(0, 2, 1)
    dets, labels = _select_nms_index(
        scores, boxes, keeps, batch_size, keep_top_k=keep_top_k)

    return dets, labels
    return _select_nms_index(
        scores,
        boxes,
        keeps,
        batch_size,
        keep_top_k=keep_top_k,
        pre_inds=pre_topk_inds,
        output_index=output_index)


class AscendBatchNMSOp(torch.autograd.Function):
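A note on the change above: multiclass_nms__torchscript now forwards output_index and the pre-top-k indices to _select_nms_index instead of asserting that index output is unsupported. Presumably this is what the rtmdet-inst conversion needs: the indices kept by NMS refer to boxes that survived the pre-top-k filter, so pre_topk_inds is required to map them back to positions in the original box tensor (for example, so a mask branch can gather its per-box coefficients). A minimal sketch of that mapping, with made-up shapes and names:

    import torch

    # Hypothetical shapes for illustration only.
    batch_size, num_boxes, pre_top_k, num_keep = 2, 100, 30, 5

    # Indices selected by the pre-top-k step: (batch, pre_top_k).
    pre_topk_inds = torch.randint(0, num_boxes, (batch_size, pre_top_k))

    # Indices kept by NMS, expressed w.r.t. the filtered boxes: (batch, num_keep).
    keep_inds = torch.randint(0, pre_top_k, (batch_size, num_keep))

    # Map NMS keep indices back to positions in the original box tensor.
    batch_inds = torch.arange(batch_size).view(-1, 1)
    original_inds = pre_topk_inds[batch_inds, keep_inds]  # (batch, num_keep)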
26 changes: 26 additions & 0 deletions mmdeploy/pytorch/functions/multi_head_attention_forward.py
@@ -53,3 +53,29 @@ def _scaled_dot_product_attention__tensorrt(q: Tensor,
                                            **kwargs) -> Tuple[Tensor, Tensor]:
    """Rewrite for custom ops."""
    return ScaledDotProductAttentionTRT.apply(q, k, v, attn_mask)


@FUNCTION_REWRITER.register_rewriter(
    func_name='torch.nn.functional.scaled_dot_product_attention',
    backend=Backend.DEFAULT.value)
def scaled_dot_product_attention__default(query,
                                          key,
                                          value,
                                          attn_mask=None,
                                          dropout_p=0.,
                                          scale=None,
                                          is_causal=False):
    """Rewrite to export to onnx on torch>=2.0.0."""
    scale = scale or query.size(-1)**0.5
    if is_causal and attn_mask is not None:
        attn_mask = torch.ones(
            query.size(-2), key.size(-2), dtype=torch.bool).tril(diagonal=0)
    if attn_mask is not None and attn_mask.dtype == torch.bool:
        attn_mask = attn_mask.masked_fill(not attn_mask, -float('inf'))

    attn_weight = query @ key.transpose(-2, -1) / scale
    if attn_mask is not None:
        attn_weight += attn_mask
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, True)
    return attn_weight @ value
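The rewrite above decomposes F.scaled_dot_product_attention into plain matmul/softmax ops that the ONNX exporter can trace on torch>=2.0. For the default path (no mask, no dropout, default scale) the decomposition should match PyTorch's fused op; a quick sanity check along those lines, with toy shapes chosen only for illustration:

    import torch
    import torch.nn.functional as F

    # Toy shapes: (batch, heads, seq_len, head_dim).
    query = torch.randn(1, 8, 16, 64)
    key = torch.randn(1, 8, 16, 64)
    value = torch.randn(1, 8, 16, 64)

    # Manual decomposition, mirroring the rewrite's default path.
    scale = query.size(-1)**0.5
    attn_weight = torch.softmax(query @ key.transpose(-2, -1) / scale, dim=-1)
    manual_out = attn_weight @ value

    # Fused reference (torch>=2.0).
    ref_out = F.scaled_dot_product_attention(query, key, value)

    assert torch.allclose(manual_out, ref_out, atol=1e-5)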
