forked from OSchip/llvm-project
add missing memref cast fold pattern for dim op
- Add the missing canonicalization pattern that folds memref_cast + dim into dim (needed to propagate constants when a dynamic shape is folded to a static one).
- Also fix an outdated/inconsistent comment in StandardOps/Ops.td.

Signed-off-by: Uday Bondhugula <uday@polymagelabs.com>

Closes tensorflow/mlir#126

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/126 from bondhugula:quickfix 4566e75e49685c532faffff91d64c5d83d4da524
PiperOrigin-RevId: 269020058
This commit is contained in:
parent
d780bdef20
commit
1e6a93b7ca
|
@ -556,6 +556,7 @@ def DimOp : Std_Op<"dim", [NoSideEffect]> {
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
|
let hasCanonicalizer = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def DivFOp : FloatArithmeticOp<"divf"> {
|
def DivFOp : FloatArithmeticOp<"divf"> {
|
||||||
|
@ -580,9 +581,9 @@ def ExtractElementOp : Std_Op<"extract_element", [NoSideEffect]> {
|
||||||
with the same type as the elements of the tensor or vector. The arity of
|
with the same type as the elements of the tensor or vector. The arity of
|
||||||
indices matches the rank of the accessed value (i.e., if a tensor is of rank
|
indices matches the rank of the accessed value (i.e., if a tensor is of rank
|
||||||
3, then 3 indices are required for the extract). The indices should all be
|
3, then 3 indices are required for the extract). The indices should all be
|
||||||
of affine_int type. For example:
|
of index type. For example:
|
||||||
|
|
||||||
%0 = extract_element %0[%1, %2] : vector<4x4xi32>
|
%3 = extract_element %0[%1, %2] : vector<4x4xi32>
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (ins AnyTypeOf<[AnyVector, AnyTensor]>:$aggregate,
|
let arguments = (ins AnyTypeOf<[AnyVector, AnyTensor]>:$aggregate,
|
||||||
|
|
|
@ -1381,6 +1381,12 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void DimOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||||
|
MLIRContext *context) {
|
||||||
|
/// dim(memrefcast) -> dim
|
||||||
|
results.insert<MemRefCastFolder>(getOperationName(), context);
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// DivISOp
|
// DivISOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -256,21 +256,25 @@ func @xor_self_tensor(%arg0: tensor<4x5xi32>) -> tensor<4x5xi32> {
|
||||||
|
|
||||||
// CHECK-LABEL: func @memref_cast_folding
|
// CHECK-LABEL: func @memref_cast_folding
|
||||||
func @memref_cast_folding(%arg0: memref<4 x f32>, %arg1: f32) -> f32 {
|
func @memref_cast_folding(%arg0: memref<4 x f32>, %arg1: f32) -> f32 {
|
||||||
|
// CHECK-NOT: memref_cast
|
||||||
%1 = memref_cast %arg0 : memref<4xf32> to memref<?xf32>
|
%1 = memref_cast %arg0 : memref<4xf32> to memref<?xf32>
|
||||||
|
|
||||||
// CHECK-NEXT: %c0 = constant 0 : index
|
|
||||||
%c0 = constant 0 : index
|
%c0 = constant 0 : index
|
||||||
|
// CHECK-NOT: dim
|
||||||
|
%dim = dim %1, 0 : memref<? x f32>
|
||||||
|
|
||||||
|
// CHECK: affine.load %arg0[%c4 - 1]
|
||||||
|
affine.load %1[%dim - 1] : memref<?xf32>
|
||||||
|
|
||||||
// CHECK-NEXT: store %arg1, %arg0[%c0] : memref<4xf32>
|
// CHECK-NEXT: store %arg1, %arg0[%c0] : memref<4xf32>
|
||||||
store %arg1, %1[%c0] : memref<?xf32>
|
store %arg1, %1[%c0] : memref<?xf32>
|
||||||
|
|
||||||
// CHECK-NEXT: %0 = load %arg0[%c0] : memref<4xf32>
|
// CHECK-NEXT: %{{.*}} = load %arg0[%c0] : memref<4xf32>
|
||||||
%0 = load %1[%c0] : memref<?xf32>
|
%0 = load %1[%c0] : memref<?xf32>
|
||||||
|
|
||||||
// CHECK-NEXT: dealloc %arg0 : memref<4xf32>
|
// CHECK-NEXT: dealloc %arg0 : memref<4xf32>
|
||||||
dealloc %1: memref<?xf32>
|
dealloc %1: memref<?xf32>
|
||||||
|
|
||||||
// CHECK-NEXT: return %0
|
// CHECK-NEXT: return %{{.*}}
|
||||||
return %0 : f32
|
return %0 : f32
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue