add missing memref cast fold pattern for dim op

- add missing canonicalization pattern to fold memref_cast + dim to
  dim (needed to propagate constant when folding a dynamic shape to
  a static one)

- also fix an outdated/inconsistent comment in StandardOps/Ops.td

Signed-off-by: Uday Bondhugula <uday@polymagelabs.com>

Closes tensorflow/mlir#126

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/126 from bondhugula:quickfix 4566e75e49685c532faffff91d64c5d83d4da524
PiperOrigin-RevId: 269020058
This commit is contained in:
Uday Bondhugula 2019-09-13 18:18:21 -07:00 committed by A. Unique TensorFlower
parent d780bdef20
commit 1e6a93b7ca
3 changed files with 17 additions and 6 deletions

View File

@ -556,6 +556,7 @@ def DimOp : Std_Op<"dim", [NoSideEffect]> {
}]; }];
let hasFolder = 1; let hasFolder = 1;
let hasCanonicalizer = 1;
} }
def DivFOp : FloatArithmeticOp<"divf"> { def DivFOp : FloatArithmeticOp<"divf"> {
@ -580,9 +581,9 @@ def ExtractElementOp : Std_Op<"extract_element", [NoSideEffect]> {
with the same type as the elements of the tensor or vector. The arity of with the same type as the elements of the tensor or vector. The arity of
indices matches the rank of the accessed value (i.e., if a tensor is of rank indices matches the rank of the accessed value (i.e., if a tensor is of rank
3, then 3 indices are required for the extract). The indices should all be 3, then 3 indices are required for the extract). The indices should all be
of affine_int type. For example: of index type. For example:
%0 = extract_element %0[%1, %2] : vector<4x4xi32> %3 = extract_element %0[%1, %2] : vector<4x4xi32>
}]; }];
let arguments = (ins AnyTypeOf<[AnyVector, AnyTensor]>:$aggregate, let arguments = (ins AnyTypeOf<[AnyVector, AnyTensor]>:$aggregate,

View File

@ -1381,6 +1381,12 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
return {}; return {};
} }
void DimOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
/// dim(memrefcast) -> dim
results.insert<MemRefCastFolder>(getOperationName(), context);
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// DivISOp // DivISOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -256,21 +256,25 @@ func @xor_self_tensor(%arg0: tensor<4x5xi32>) -> tensor<4x5xi32> {
// CHECK-LABEL: func @memref_cast_folding // CHECK-LABEL: func @memref_cast_folding
func @memref_cast_folding(%arg0: memref<4 x f32>, %arg1: f32) -> f32 { func @memref_cast_folding(%arg0: memref<4 x f32>, %arg1: f32) -> f32 {
// CHECK-NOT: memref_cast
%1 = memref_cast %arg0 : memref<4xf32> to memref<?xf32> %1 = memref_cast %arg0 : memref<4xf32> to memref<?xf32>
// CHECK-NEXT: %c0 = constant 0 : index
%c0 = constant 0 : index %c0 = constant 0 : index
// CHECK-NOT: dim
%dim = dim %1, 0 : memref<? x f32>
// CHECK: affine.load %arg0[%c4 - 1]
affine.load %1[%dim - 1] : memref<?xf32>
// CHECK-NEXT: store %arg1, %arg0[%c0] : memref<4xf32> // CHECK-NEXT: store %arg1, %arg0[%c0] : memref<4xf32>
store %arg1, %1[%c0] : memref<?xf32> store %arg1, %1[%c0] : memref<?xf32>
// CHECK-NEXT: %0 = load %arg0[%c0] : memref<4xf32> // CHECK-NEXT: %{{.*}} = load %arg0[%c0] : memref<4xf32>
%0 = load %1[%c0] : memref<?xf32> %0 = load %1[%c0] : memref<?xf32>
// CHECK-NEXT: dealloc %arg0 : memref<4xf32> // CHECK-NEXT: dealloc %arg0 : memref<4xf32>
dealloc %1: memref<?xf32> dealloc %1: memref<?xf32>
// CHECK-NEXT: return %0 // CHECK-NEXT: return %{{.*}}
return %0 : f32 return %0 : f32
} }