Create tensor.expand_shape instead of tensor.reshape when possible #3647
@MaheshRavishankar please feel free to edit the above issue if we want to add any more details on this.
I don't know if it is possible to do this in torch, but what we really want is …
I don't know if there is already a canonicalization for …
@kumardeepakamd I assigned it to you for now. Please redirect.
Hi @MaheshRavishankar @nirvedhmeshram, do we still want the TorchToLinalg lowering to be modified?
Yes!
Working on this now.
I've been able to write a canonicalizer for this, which produces:

```mlir
module {
  func.func @torch_jit(%arg0: !torch.vtensor<[?,?,768],f32>) -> !torch.vtensor<[?,12,?,64],f32> {
    %int1 = torch.constant.int 1
    %int0 = torch.constant.int 0
    %int2 = torch.constant.int 2
    %int64 = torch.constant.int 64
    %int12 = torch.constant.int 12
    %0 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?,768],f32>, !torch.int -> !torch.int
    %1 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?,768],f32>, !torch.int -> !torch.int
    %2 = torch.prim.ListConstruct %1, %0, %int12, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %3 = torch.aten.view %arg0, %2 : !torch.vtensor<[?,?,768],f32>, !torch.list<int> -> !torch.vtensor<[?,?,12,64],f32>
    %4 = torch.aten.transpose.int %3, %int1, %int2 : !torch.vtensor<[?,?,12,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,12,?,64],f32>
    return %4 : !torch.vtensor<[?,12,?,64],f32>
  }
}
```

But this still gets lowered to `tensor.reshape`:

```mlir
module {
  ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
  func.func @torch_jit(%arg0: tensor<?x?x768xf32>) -> tensor<?x12x?x64xf32> {
    %c1_i64 = arith.constant 1 : i64
    %c0_i64 = arith.constant 0 : i64
    %c12_i64 = arith.constant 12 : i64
    %c64_i64 = arith.constant 64 : i64
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c768_i64 = arith.constant 768 : i64
    %dim = tensor.dim %arg0, %c1 : tensor<?x?x768xf32>
    %0 = arith.index_cast %dim : index to i64
    %dim_0 = tensor.dim %arg0, %c0 : tensor<?x?x768xf32>
    %1 = arith.index_cast %dim_0 : index to i64
    %2 = arith.cmpi slt, %1, %c0_i64 : i64
    %3 = arith.select %2, %c1_i64, %1 : i64
    %4 = arith.extui %2 : i1 to i64
    %5 = arith.muli %3, %0 : i64
    %6 = arith.addi %4, %c1_i64 : i64
    %7 = arith.cmpi slt, %0, %c0_i64 : i64
    %8 = arith.select %7, %3, %5 : i64
    %9 = arith.select %7, %6, %4 : i64
    %10 = arith.muli %8, %c768_i64 : i64
    %11 = arith.cmpi sle, %9, %c1_i64 : i64
    cf.assert %11, "must have at most one inferred (negative) dimension"
    %12 = arith.muli %1, %0 : i64
    %13 = arith.muli %12, %c768_i64 : i64
    %14 = arith.divsi %13, %10 : i64
    %15 = arith.select %2, %14, %1 : i64
    %16 = arith.select %7, %14, %0 : i64
    %from_elements = tensor.from_elements %15, %16, %c12_i64, %c64_i64 : tensor<4xi64>
    %reshape = tensor.reshape %arg0(%from_elements) : (tensor<?x?x768xf32>, tensor<4xi64>) -> tensor<?x?x12x64xf32>
    %17 = arith.index_cast %15 : i64 to index
    %18 = arith.index_cast %16 : i64 to index
    %19 = tensor.empty(%17, %18) : tensor<?x12x?x64xf32>
    %transposed = linalg.transpose ins(%reshape : tensor<?x?x12x64xf32>) outs(%19 : tensor<?x12x?x64xf32>) permutation = [0, 2, 1, 3]
    return %transposed : tensor<?x12x?x64xf32>
  }
}
```

Is it necessary to get this to an `expand_shape` op, or is it sufficient to have the static shape information attached?
Maybe it would be better to convert this to an …
For the following starting IR:
We get this output after torch-to-linalg conversion + CSE:
Here `tensor.reshape` is a harder op for the backends to handle, and it drops the semantic information that this particular reshape could just be an `expand_shape`.
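For illustration only (this comparison is not part of the issue text; `%t`, `%shape`, `%d0`, and `%d1` are hypothetical values), the two ops make that difference visible: `tensor.reshape` consumes an opaque runtime shape tensor, while `tensor.expand_shape` records an explicit reassociation of dimensions.

```mlir
// tensor.reshape: the target shape is an opaque tensor<4xi64> computed at
// runtime, so the relationship between source and result dims is lost and
// backends must treat this as an arbitrary reshape.
%r = tensor.reshape %t(%shape)
    : (tensor<?x?x768xf32>, tensor<4xi64>) -> tensor<?x?x12x64xf32>

// tensor.expand_shape: the reassociation [[0], [1], [2, 3]] states that
// dims 0 and 1 pass through and the result's last two dims (12 and 64)
// come from splitting source dim 2 (768).
%e = tensor.expand_shape %t [[0], [1], [2, 3]]
    output_shape [%d0, %d1, 12, 64]
    : tensor<?x?x768xf32> into tensor<?x?x12x64xf32>
```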
I think the IR we would want at this level is:
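The snippet that followed was not captured above; as a hedged sketch (the SSA names and exact surrounding ops are assumptions, not the issue author's IR), the reshape-plus-transpose portion might instead look like:

```mlir
// Hypothetical sketch: the tensor.reshape from the lowered IR above replaced
// by a tensor.expand_shape that preserves the [?, ?, 768] -> [?, ?, 12, 64]
// dimension grouping.
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%d0 = tensor.dim %arg0, %c0 : tensor<?x?x768xf32>
%d1 = tensor.dim %arg0, %c1 : tensor<?x?x768xf32>
%expanded = tensor.expand_shape %arg0 [[0], [1], [2, 3]]
    output_shape [%d0, %d1, 12, 64]
    : tensor<?x?x768xf32> into tensor<?x?x12x64xf32>
// The transpose is unchanged; only its input now carries the explicit
// expansion structure.
%init = tensor.empty(%d0, %d1) : tensor<?x12x?x64xf32>
%transposed = linalg.transpose
    ins(%expanded : tensor<?x?x12x64xf32>)
    outs(%init : tensor<?x12x?x64xf32>)
    permutation = [0, 2, 1, 3]
```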