forked from OSchip/llvm-project
[MLIR][Linalg] Generate the right type of load/store when lowering max/min pooling ops
While lowering min/max pooling ops to loops, generate the right kind of load/stores (std or affine) instead of always generating std load/stores. Differential Revision: https://reviews.llvm.org/D83080
This commit is contained in:
parent
7356b4243a
commit
6d6d5db251
|
@ -333,23 +333,28 @@ static void emitScalarImplementation(ArrayRef<Value> allIvs, ConvOp convOp) {
|
||||||
|
|
||||||
template <typename IndexedValueType>
|
template <typename IndexedValueType>
|
||||||
void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMaxOp op) {
|
void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMaxOp op) {
|
||||||
auto indices = getInputAndOutputIndices(allIvs, op);
|
InputAndOutputIndices indices = getInputAndOutputIndices(allIvs, op);
|
||||||
// Emit scalar form.
|
// Emit scalar form.
|
||||||
Value lhs = std_load(op.output(), indices.outputs);
|
IndexedValueType output(op.output());
|
||||||
Value rhs = std_load(op.input(), indices.inputs);
|
IndexedValueType input(op.input());
|
||||||
|
Value lhs = output(indices.outputs);
|
||||||
|
Value rhs = input(indices.inputs);
|
||||||
using edsc::op::sgt;
|
using edsc::op::sgt;
|
||||||
Value maxValue = std_select(sgt(lhs, rhs), lhs, rhs);
|
Value maxValue = std_select(sgt(lhs, rhs), lhs, rhs);
|
||||||
std_store(maxValue, op.output(), indices.outputs);
|
output(indices.outputs) = maxValue;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename IndexedValueType>
|
template <typename IndexedValueType>
|
||||||
void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMinOp op) {
|
void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMinOp op) {
|
||||||
auto indices = getInputAndOutputIndices(allIvs, op);
|
InputAndOutputIndices indices = getInputAndOutputIndices(allIvs, op);
|
||||||
// Emit scalar form.
|
// Emit scalar form.
|
||||||
Value lhs = std_load(op.output(), indices.outputs);
|
IndexedValueType output(op.output());
|
||||||
Value rhs = std_load(op.input(), indices.inputs);
|
IndexedValueType input(op.input());
|
||||||
|
Value lhs = output(indices.outputs);
|
||||||
|
Value rhs = input(indices.inputs);
|
||||||
using edsc::op::slt;
|
using edsc::op::slt;
|
||||||
Value minValue = std_select(slt(lhs, rhs), lhs, rhs);
|
Value minValue = std_select(slt(lhs, rhs), lhs, rhs);
|
||||||
std_store(minValue, op.output(), indices.outputs);
|
output(indices.outputs) = minValue;
|
||||||
}
|
}
|
||||||
template <typename IndexedValueType>
|
template <typename IndexedValueType>
|
||||||
void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingSumOp op) {
|
void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingSumOp op) {
|
||||||
|
|
|
@ -123,3 +123,27 @@ func @named_batch_matmul(%A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>, %C: memre
|
||||||
// CHECK: %[[inc:.*]] = mulf %[[va]], %[[vb]] : f32
|
// CHECK: %[[inc:.*]] = mulf %[[va]], %[[vb]] : f32
|
||||||
// CHECK: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
|
// CHECK: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
|
||||||
// CHECK: affine.store %[[res]], %[[mC]][%[[b]], %[[m]], %[[n]]] : memref<?x?x?xf32>
|
// CHECK: affine.store %[[res]], %[[mC]][%[[b]], %[[m]], %[[n]]] : memref<?x?x?xf32>
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @pooling_max_min
|
||||||
|
func @pooling_max_min(%arg0: memref<?x?xf32>,
|
||||||
|
%arg1: memref<?x?xi32>,
|
||||||
|
%arg2: memref<?x?xf32>) {
|
||||||
|
linalg.pooling_max(%arg0, %arg1, %arg2) { strides = [2, 1] }:
|
||||||
|
memref<?x?xf32>, memref<?x?xi32>, memref<?x?xf32>
|
||||||
|
linalg.pooling_min(%arg0, %arg1, %arg2) { strides = [2, 1] }:
|
||||||
|
memref<?x?xf32>, memref<?x?xi32>, memref<?x?xf32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// This is a basic check to make sure the right load/stores are used. loops.mlir
|
||||||
|
// checks for the rest.
|
||||||
|
// CHECK: affine.load
|
||||||
|
// CHECK-NEXT: affine.load
|
||||||
|
// CHECK-NEXT: cmpf
|
||||||
|
// CHECK-NEXT: select
|
||||||
|
// CHECK-NEXT: affine.store
|
||||||
|
// The min pooling body.
|
||||||
|
// CHECK: affine.load
|
||||||
|
// CHECK-NEXT: affine.load
|
||||||
|
// CHECK-NEXT: cmpf
|
||||||
|
// CHECK-NEXT: select
|
||||||
|
// CHECK-NEXT: affine.store
|
||||||
|
|
Loading…
Reference in New Issue