[mlir][Linalg] Fix load/store operations generated while lowering to loops when output has zero rank.

While lowering to loops, no indices should be used in the load/store
operations if the buffer has zero rank.

Differential Revision: https://reviews.llvm.org/D75391
MaheshRavishankar 2020-03-04 17:03:07 -08:00
parent f708c823f0
commit 755c050200
2 changed files with 127 additions and 29 deletions
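
Every hunk below applies the same guard: query the rank of the shaped operand and build the canonical affine indexing only when the rank is non-zero, otherwise emit an index-free load/store. A minimal standalone sketch of that control flow follows; `Buffer`, `loadScalar`, `loadIndexed`, and `emitLoad` are hypothetical stand-ins for illustration, not MLIR's actual C++ API.

#include <cstdio>
#include <vector>

// Illustration-only stand-in for a shaped buffer; not an MLIR type.
struct Buffer {
  int rank;     // 0 for a zero-rank (scalar) memref
  float value;  // payload returned by both load forms in this toy model
};

// Zero-rank form: no subscripts, i.e. `load %buf[]`.
float loadScalar(const Buffer &buf) { return buf.value; }

// Ranked form: subscripts derived from the loop induction variables,
// i.e. `load %buf[%i, %j, ...]`.
float loadIndexed(const Buffer &buf, const std::vector<int> &ivs) {
  (void)ivs;  // indices are irrelevant in this toy model
  return buf.value;
}

// The guard the patch adds to each emitter: compute indices only when the
// buffer actually has a rank; otherwise emit the index-free load.
float emitLoad(const Buffer &buf, const std::vector<int> &allIvs) {
  if (buf.rank)
    return loadIndexed(buf, allIvs);
  return loadScalar(buf);
}

int main() {
  Buffer zeroRank{0, 3.0f};
  Buffer ranked{1, 4.0f};
  std::printf("%f %f\n", emitLoad(zeroRank, {}), emitLoad(ranked, {0}));
  return 0;
}

The actual patch performs this branch at each of the load and store sites in both the GenericOp and IndexedGenericOp lowerings, as the hunks below show.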


@@ -242,21 +242,25 @@ public:
     // 1.a. Emit std_load from input views.
     for (unsigned i = 0; i < nInputs; ++i) {
       Value input = genericOp.getInput(i);
-      if (!input.getType().cast<ShapedType>().getRank()) {
-        indexedValues[i] = std_load(input);
-      } else {
+      if (input.getType().cast<ShapedType>().getRank()) {
         ValueHandleArray indexing(makeCanonicalAffineApplies(
             b, loc, genericOp.getInputIndexingMap(i), allIvs));
         indexedValues[i] = std_load(input, indexing);
+      } else {
+        indexedValues[i] = std_load(input);
       }
     }
     // 1.b. Emit std_load from output views.
     for (unsigned i = 0; i < nOutputs; ++i) {
-      ValueHandleArray indexing(makeCanonicalAffineApplies(
-          b, loc, genericOp.getOutputIndexingMap(i), allIvs));
-      indexedValues[nInputs + i] =
-          std_load(genericOp.getOutputBuffer(i), indexing);
+      Value output = genericOp.getOutputBuffer(i);
+      if (output.getType().cast<ShapedType>().getRank()) {
+        ValueHandleArray indexing(makeCanonicalAffineApplies(
+            b, loc, genericOp.getOutputIndexingMap(i), allIvs));
+        indexedValues[nInputs + i] = std_load(output, indexing);
+      } else {
+        indexedValues[nInputs + i] = std_load(output);
+      }
     }
     auto funcOp = genericOp.getFunction();
@@ -267,9 +271,14 @@ public:
       // 3. Emit std_store.
       for (unsigned i = 0; i < nOutputs; ++i) {
-        ValueHandleArray indexing(makeCanonicalAffineApplies(
-            b, loc, genericOp.getOutputIndexingMap(i), allIvs));
-        std_store(callOp->getResult(i), genericOp.getOutputBuffer(i), indexing);
+        Value output = genericOp.getOutputBuffer(i);
+        if (output.getType().cast<ShapedType>().getRank()) {
+          ValueHandleArray indexing(makeCanonicalAffineApplies(
+              b, loc, genericOp.getOutputIndexingMap(i), allIvs));
+          std_store(callOp->getResult(i), output, indexing);
+        } else {
+          std_store(callOp->getResult(i), output);
+        }
       }
       return;
     }
@@ -288,10 +297,15 @@ public:
     auto *yieldOp = cast<YieldOp>(block.back()).getOperation();
     assert(yieldOp->getNumOperands() == nOutputs);
     for (unsigned i = 0; i < nOutputs; ++i) {
-      ValueHandleArray indexing(makeCanonicalAffineApplies(
-          b, loc, genericOp.getOutputIndexingMap(i), allIvs));
-      std_store(map.lookup(yieldOp->getOperand(i)),
-                genericOp.getOutputBuffer(i), indexing);
+      Value output = genericOp.getOutputBuffer(i);
+      if (output.getType().cast<ShapedType>().getRank()) {
+        ValueHandleArray indexing(makeCanonicalAffineApplies(
+            b, loc, genericOp.getOutputIndexingMap(i), allIvs));
+        std_store(map.lookup(yieldOp->getOperand(i)),
+                  genericOp.getOutputBuffer(i), indexing);
+      } else {
+        std_store(map.lookup(yieldOp->getOperand(i)), output);
+      }
     }
   }
 };
@@ -348,21 +362,25 @@ public:
     // 1.a. Emit std_load from input views.
     for (unsigned i = 0; i < nInputs; ++i) {
       Value input = indexedGenericOp.getInput(i);
-      if (!input.getType().cast<ShapedType>().getRank()) {
-        indexedValues[nLoops + i] = std_load(input);
-      } else {
+      if (input.getType().cast<ShapedType>().getRank()) {
         ValueHandleArray indexing(makeCanonicalAffineApplies(
             b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs));
         indexedValues[nLoops + i] = std_load(input, indexing);
+      } else {
+        indexedValues[nLoops + i] = std_load(input);
       }
     }
     // 1.b. Emit std_load from output views.
     for (unsigned i = 0; i < nOutputs; ++i) {
-      ValueHandleArray indexing(makeCanonicalAffineApplies(
-          b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
-      indexedValues[nLoops + nInputs + i] =
-          std_load(indexedGenericOp.getOutputBuffer(i), indexing);
+      Value output = indexedGenericOp.getOutputBuffer(i);
+      if (output.getType().cast<ShapedType>().getRank()) {
+        ValueHandleArray indexing(makeCanonicalAffineApplies(
+            b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
+        indexedValues[nLoops + nInputs + i] = std_load(output, indexing);
+      } else {
+        indexedValues[nLoops + nInputs + i] = std_load(output);
+      }
     }
     if (auto funcOp = indexedGenericOp.getFunction()) {
@@ -372,10 +390,14 @@ public:
       // 3. Emit std_store.
      for (unsigned i = 0; i < nOutputs; ++i) {
-        ValueHandleArray indexing(makeCanonicalAffineApplies(
-            b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
-        std_store(callOp->getResult(i), indexedGenericOp.getOutputBuffer(i),
-                  indexing);
+        Value output = indexedGenericOp.getOutputBuffer(i);
+        if (output.getType().cast<ShapedType>().getRank()) {
+          ValueHandleArray indexing(makeCanonicalAffineApplies(
+              b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
+          std_store(callOp->getResult(i), output, indexing);
+        } else {
+          std_store(callOp->getResult(i), output);
+        }
       }
       return;
     }
@@ -394,10 +416,14 @@ public:
     auto *yieldOp = cast<YieldOp>(block.back()).getOperation();
     assert(yieldOp->getNumOperands() == nOutputs);
     for (unsigned i = 0; i < nOutputs; ++i) {
-      ValueHandleArray indexing(makeCanonicalAffineApplies(
-          b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
-      std_store(map.lookup(yieldOp->getOperand(i)),
-                indexedGenericOp.getOutputBuffer(i), indexing);
+      Value output = indexedGenericOp.getOutputBuffer(i);
+      if (output.getType().cast<ShapedType>().getRank()) {
+        ValueHandleArray indexing(makeCanonicalAffineApplies(
+            b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
+        std_store(map.lookup(yieldOp->getOperand(i)), output, indexing);
+      } else {
+        std_store(map.lookup(yieldOp->getOperand(i)), output);
+      }
     }
   }
 };


@@ -411,3 +411,75 @@ func @indexed_generic_op_zero_rank(%arg0: memref<i32>, %arg1: memref<3x4xi32>)
 // CHECK: %[[ij_int:.*]] = index_cast %[[ij]] : index to i32
 // CHECK: %[[result:.*]] = addi %[[a]], %[[ij_int]] : i32
 // CHECK: store %[[result]], %[[ARG1]][%[[i]], %[[j]]]
+#reduce_1D_access = [
+  affine_map<(i) -> (i)>,
+  affine_map<(i) -> (0)>
+]
+#trait_reduce_1D = {
+  args_in = 1,
+  args_out = 1,
+  indexing_maps = #reduce_1D_access,
+  iterator_types = ["reduction"],
+  library_call = "some_reduce_external_fn"
+}
+func @generic_op_1D_reduce(%arg0: memref<?xf32>, %arg1: memref<f32>)
+{
+  linalg.generic #trait_reduce_1D %arg0, %arg1 {
+    ^bb(%a: f32, %b: f32) :
+      %0 = addf %a, %b : f32
+      linalg.yield %0 : f32
+  } : memref<?xf32>, memref<f32>
+  return
+}
+// CHECK-LABEL: @generic_op_1D_reduce
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<f32>
+// CHECK: loop.for %[[i:.*]] = {{.*}}
+// CHECK:   %[[a:.*]] = load %[[ARG0]][%[[i]]]
+// CHECK:   %[[b:.*]] = load %[[ARG1]][]
+// CHECK:   %[[c:.*]] = addf %[[a]], %[[b]] : f32
+// CHECK:   store %[[c]], %[[ARG1]][]
+#reduce_init_1D_access = [
+  affine_map<(i) -> (i)>,
+  affine_map<(i) -> (0)>,
+  affine_map<(i) -> (0)>
+]
+#trait_reduce_init_1D = {
+  args_in = 2,
+  args_out = 1,
+  indexing_maps = #reduce_init_1D_access,
+  iterator_types = ["reduction"],
+  library_call = "some_reduce_external_fn"
+}
+func @indexed_generic_op_1D_reduce(%arg0: memref<?xf32>,
+                                   %arg1: memref<f32>,
+                                   %arg2: memref<f32>)
+{
+  linalg.indexed_generic #trait_reduce_init_1D %arg0, %arg1, %arg2 {
+    ^bb(%i : index, %a: f32, %b: f32, %c: f32) :
+      %0 = constant 0 : index
+      %1 = cmpi "eq", %0, %i : index
+      %2 = select %1, %b, %c : f32
+      %3 = addf %a, %2 : f32
+      linalg.yield %3 : f32
+  } : memref<?xf32>, memref<f32>, memref<f32>
+  return
+}
+// CHECK-LABEL: @indexed_generic_op_1D_reduce
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<f32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<f32>
+// CHECK: loop.for %[[i:.*]] = {{.*}}
+// CHECK:   %[[a:.*]] = load %[[ARG0]][%[[i]]]
+// CHECK:   %[[b:.*]] = load %[[ARG1]][]
+// CHECK:   %[[c:.*]] = load %[[ARG2]][]
+// CHECK:   %[[d:.*]] = select %{{.*}}, %[[b]], %[[c]]
+// CHECK:   %[[e:.*]] = addf %[[a]], %[[d]]
+// CHECK:   store %[[e]], %[[ARG2]][]