Create tensor.expand_shape instead of tensor.reshape when possible #3647

Open · nirvedhmeshram opened this issue Aug 19, 2024 · 8 comments

@nirvedhmeshram (Collaborator)

For the following starting IR:

  func.func @torch_jit(%105: !torch.vtensor<[?,?,768],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
  %int1 = torch.constant.int 1
  %int0 = torch.constant.int 0
  %int2 = torch.constant.int 2
  %int64 = torch.constant.int 64
  %int12 = torch.constant.int 12
  %132 = torch.aten.size.int %105, %int1 : !torch.vtensor<[?,?,768],f32>, !torch.int -> !torch.int
  %124 = torch.aten.size.int %105, %int0 : !torch.vtensor<[?,?,768],f32>, !torch.int -> !torch.int
  %136 = torch.prim.ListConstruct %124, %132, %int12, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %137 = torch.aten.view %105, %136 : !torch.vtensor<[?,?,768],f32>, !torch.list<int> -> !torch.vtensor<[?,?,?,?],f32>
  %138 = torch.aten.transpose.int %137, %int1, %int2 : !torch.vtensor<[?,?,?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,?,?],f32>
  return %138 : !torch.vtensor<[?,?,?,?],f32>
  }

We get this output after torch-to-linalg conversion + cse:

func.func @torch_jit(%arg0: !torch.vtensor<[?,?,768],f32>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.assume_strict_symbolic_shapes} {
%0 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[?,?,768],f32> -> tensor<?x?x768xf32>
%c1_i64 = arith.constant 1 : i64
%c0_i64 = arith.constant 0 : i64
%int64 = torch.constant.int 64
%int12 = torch.constant.int 12
%c3_i64 = arith.constant 3 : i64
%1 = arith.addi %c1_i64, %c3_i64 : i64
%2 = arith.cmpi sge, %c1_i64, %c0_i64 : i64
%3 = arith.select %2, %c1_i64, %1 : i64
%4 = arith.index_cast %3 : i64 to index
%dim = tensor.dim %0, %4 : tensor<?x?x768xf32>
%5 = arith.index_cast %dim : index to i64
%6 = torch_c.from_i64 %5
%7 = arith.addi %c0_i64, %c3_i64 : i64
%8 = arith.cmpi sge, %c0_i64, %c0_i64 : i64
%9 = arith.select %8, %c0_i64, %7 : i64
%10 = arith.index_cast %9 : i64 to index
%dim_0 = tensor.dim %0, %10 : tensor<?x?x768xf32>
%11 = arith.index_cast %dim_0 : index to i64
%12 = torch_c.from_i64 %11
%13 = torch_c.to_i64 %12
%14 = torch_c.to_i64 %6
%c12_i64 = arith.constant 12 : i64
%c64_i64 = arith.constant 64 : i64
%from_elements = tensor.from_elements %13, %14, %c12_i64, %c64_i64 : tensor<4xi64>
%reshape = tensor.reshape %0(%from_elements) : (tensor<?x?x768xf32>, tensor<4xi64>) -> tensor<?x?x?x?xf32>
%c0 = arith.constant 0 : index
%dim_1 = tensor.dim %reshape, %c0 : tensor<?x?x?x?xf32>
%c1 = arith.constant 1 : index
%dim_2 = tensor.dim %reshape, %c1 : tensor<?x?x?x?xf32>
%c2 = arith.constant 2 : index
%dim_3 = tensor.dim %reshape, %c2 : tensor<?x?x?x?xf32>
%c3 = arith.constant 3 : index
%dim_4 = tensor.dim %reshape, %c3 : tensor<?x?x?x?xf32>
%15 = tensor.empty(%dim_1, %dim_3, %dim_2, %dim_4) : tensor<?x?x?x?xf32>
%16 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%reshape : tensor<?x?x?x?xf32>) outs(%15 : tensor<?x?x?x?xf32>) {
^bb0(%in: f32, %out: f32):
  linalg.yield %in : f32
} -> tensor<?x?x?x?xf32>
%cast = tensor.cast %16 : tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32>
%17 = torch_c.from_builtin_tensor %cast : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
return %17 : !torch.vtensor<[?,?,?,?],f32>
}

Here

  %from_elements = tensor.from_elements %13, %14, %c12_i64, %c64_i64 : tensor<4xi64>
  %reshape = tensor.reshape %0(%from_elements) : (tensor<?x?x768xf32>, tensor<4xi64>) -> tensor<?x?x?x?xf32>

is a harder op for the backends to handle, and it drops the semantic information that this particular reshape could just be an expand_shape.

I think the IR we would want at this level is:

%reshape = tensor.expand_shape %a [[0], [1], [2, 3]] output_shape [%13, %14, %c12_i64, %c64_i64]
   : tensor<?x?x768xf32> into tensor<?x?x?x?xf32>
@nirvedhmeshram (Collaborator, Author)

@MaheshRavishankar please feel free to edit the above issue if we want to add any more details on this.

@MaheshRavishankar (Contributor)

I think the IR we would want at this level is:

%reshape = tensor.expand_shape %a [[0], [1], [2, 3]] output_shape [%13, %14, %c12_i64, %c64_i64]
   : tensor<?x?x768xf32> into tensor<?x?x?x?xf32>

I don't know if it is possible to do this in torch, but what we really want is:

%reshape = tensor.expand_shape %a [[0], [1], [2, 3]] output_shape [%13, %14, 12, 64]
   : tensor<?x?x768xf32> into tensor<?x?x12x64xf32>

I don't know if there is already a canonicalization for tensor.expand_shape that converts from what Nirvedh added to this form. If it is not possible to directly generate the op from torch with the static dims in the return type, then we will need such a canonicalization.
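
A rough sketch of what such a canonicalization would do (illustrative names; %d0/%d1 stand for the two dynamic sizes and %c12/%c64 for the constant sizes as index values):

// Before: the constant sizes only appear as SSA values, so the result type is fully dynamic.
%e = tensor.expand_shape %a [[0], [1], [2, 3]] output_shape [%d0, %d1, %c12, %c64]
   : tensor<?x?x768xf32> into tensor<?x?x?x?xf32>

// After: constant sizes folded into the result type, with a cast back for existing users.
%e_static = tensor.expand_shape %a [[0], [1], [2, 3]] output_shape [%d0, %d1, 12, 64]
   : tensor<?x?x768xf32> into tensor<?x?x12x64xf32>
%e_cast = tensor.cast %e_static : tensor<?x?x12x64xf32> to tensor<?x?x?x?xf32>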

@MaheshRavishankar (Contributor)

@kumardeepakamd I assigned it to you for now. Please redirect.

@vivekkhandelwal1 (Collaborator)

Hi @MaheshRavishankar @nirvedhmeshram, do we still want the TorchToLinalg lowering to be modified?

@MaheshRavishankar (Contributor)

Yes!

@zjgarvey (Collaborator)

Working on this now.

@zjgarvey (Collaborator)

I've been able to add a canonicalizer for AtenViewOp that converts the original IR to:

  func.func @torch_jit(%arg0: !torch.vtensor<[?,?,768],f32>) -> !torch.vtensor<[?,12,?,64],f32> {
    %int1 = torch.constant.int 1
    %int0 = torch.constant.int 0
    %int2 = torch.constant.int 2
    %int64 = torch.constant.int 64
    %int12 = torch.constant.int 12
    %0 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?,768],f32>, !torch.int -> !torch.int
    %1 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?,768],f32>, !torch.int -> !torch.int
    %2 = torch.prim.ListConstruct %1, %0, %int12, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %3 = torch.aten.view %arg0, %2 : !torch.vtensor<[?,?,768],f32>, !torch.list<int> -> !torch.vtensor<[?,?,12,64],f32>
    %4 = torch.aten.transpose.int %3, %int1, %int2 : !torch.vtensor<[?,?,12,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,12,?,64],f32>
    return %4 : !torch.vtensor<[?,12,?,64],f32>
  }

But this still gets lowered to reshape:

module {
  ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
  func.func @torch_jit(%arg0: tensor<?x?x768xf32>) -> tensor<?x12x?x64xf32> {
    %c1_i64 = arith.constant 1 : i64
    %c0_i64 = arith.constant 0 : i64
    %c12_i64 = arith.constant 12 : i64
    %c64_i64 = arith.constant 64 : i64
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c768_i64 = arith.constant 768 : i64
    %dim = tensor.dim %arg0, %c1 : tensor<?x?x768xf32>
    %0 = arith.index_cast %dim : index to i64
    %dim_0 = tensor.dim %arg0, %c0 : tensor<?x?x768xf32>
    %1 = arith.index_cast %dim_0 : index to i64
    %2 = arith.cmpi slt, %1, %c0_i64 : i64
    %3 = arith.select %2, %c1_i64, %1 : i64
    %4 = arith.extui %2 : i1 to i64
    %5 = arith.muli %3, %0 : i64
    %6 = arith.addi %4, %c1_i64 : i64
    %7 = arith.cmpi slt, %0, %c0_i64 : i64
    %8 = arith.select %7, %3, %5 : i64
    %9 = arith.select %7, %6, %4 : i64
    %10 = arith.muli %8, %c768_i64 : i64
    %11 = arith.cmpi sle, %9, %c1_i64 : i64
    cf.assert %11, "must have at most one inferred (negative) dimension"
    %12 = arith.muli %1, %0 : i64
    %13 = arith.muli %12, %c768_i64 : i64
    %14 = arith.divsi %13, %10 : i64
    %15 = arith.select %2, %14, %1 : i64
    %16 = arith.select %7, %14, %0 : i64
    %from_elements = tensor.from_elements %15, %16, %c12_i64, %c64_i64 : tensor<4xi64>
    %reshape = tensor.reshape %arg0(%from_elements) : (tensor<?x?x768xf32>, tensor<4xi64>) -> tensor<?x?x12x64xf32>
    %17 = arith.index_cast %15 : i64 to index
    %18 = arith.index_cast %16 : i64 to index
    %19 = tensor.empty(%17, %18) : tensor<?x12x?x64xf32>
    %transposed = linalg.transpose ins(%reshape : tensor<?x?x12x64xf32>) outs(%19 : tensor<?x12x?x64xf32>) permutation = [0, 2, 1, 3] 
    return %transposed : tensor<?x12x?x64xf32>
  }
}

Is it necessary to get this all the way to a tensor.expand_shape op, or is it sufficient to have the static shape information attached?
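
For comparison, the expand_shape form Mahesh sketched earlier, specialized to this lowering, would be roughly the following (reusing the dynamic sizes %17 and %18 computed above):

%expanded = tensor.expand_shape %arg0 [[0], [1], [2, 3]] output_shape [%17, %18, 12, 64]
   : tensor<?x?x768xf32> into tensor<?x?x12x64xf32>

This keeps the 768 = 12 * 64 split explicit in the reassociation instead of hiding it behind a shape tensor.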

@zjgarvey (Collaborator)

Maybe it would be better to convert this to an AtenUnflattenIntOp. I'm going to keep working on this.
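
For reference, the unflatten form of the original view at the torch level would look roughly like this (a sketch reusing %int2, %int12, and %int64 from the starting IR):

%sizes = torch.prim.ListConstruct %int12, %int64 : (!torch.int, !torch.int) -> !torch.list<int>
%split = torch.aten.unflatten.int %arg0, %int2, %sizes
   : !torch.vtensor<[?,?,768],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[?,?,12,64],f32>

Since the dimension being split and the factors are explicit here, this form maps naturally onto a tensor.expand_shape with the [[0], [1], [2, 3]] reassociation.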
