[mlir][Linalg] Fix load/store operations generated while lowering to loops when output has zero rank.

While lowering to loops, no indices should be used in the load/store operation if the buffer is zero-rank.

Differential Revision: https://reviews.llvm.org/D75391
commit 755c050200
parent f708c823f0
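The patch hinges on how ranked and zero-rank memrefs are addressed: a zero-rank memref takes an empty index list in load/store, while a ranked memref requires one index per dimension. A minimal MLIR sketch of the two forms (the function and value names here are illustrative, not part of this patch):

func @zero_rank_vs_ranked(%buf0: memref<f32>, %buf1: memref<?xf32>, %i: index) {
  // Zero-rank memref: load/store take no indices.
  %a = load %buf0[] : memref<f32>
  store %a, %buf0[] : memref<f32>
  // Ranked memref: one index per dimension is required.
  %b = load %buf1[%i] : memref<?xf32>
  store %b, %buf1[%i] : memref<?xf32>
  return
}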
@@ -242,21 +242,25 @@ public:
     // 1.a. Emit std_load from input views.
     for (unsigned i = 0; i < nInputs; ++i) {
       Value input = genericOp.getInput(i);
-      if (!input.getType().cast<ShapedType>().getRank()) {
-        indexedValues[i] = std_load(input);
-      } else {
+      if (input.getType().cast<ShapedType>().getRank()) {
         ValueHandleArray indexing(makeCanonicalAffineApplies(
             b, loc, genericOp.getInputIndexingMap(i), allIvs));
         indexedValues[i] = std_load(input, indexing);
+      } else {
+        indexedValues[i] = std_load(input);
       }
     }
 
     // 1.b. Emit std_load from output views.
     for (unsigned i = 0; i < nOutputs; ++i) {
-      ValueHandleArray indexing(makeCanonicalAffineApplies(
-          b, loc, genericOp.getOutputIndexingMap(i), allIvs));
-      indexedValues[nInputs + i] =
-          std_load(genericOp.getOutputBuffer(i), indexing);
+      Value output = genericOp.getOutputBuffer(i);
+      if (output.getType().cast<ShapedType>().getRank()) {
+        ValueHandleArray indexing(makeCanonicalAffineApplies(
+            b, loc, genericOp.getOutputIndexingMap(i), allIvs));
+        indexedValues[nInputs + i] = std_load(output, indexing);
+      } else {
+        indexedValues[nInputs + i] = std_load(output);
+      }
     }
 
     auto funcOp = genericOp.getFunction();
@@ -267,9 +271,14 @@ public:
 
     // 3. Emit std_store.
     for (unsigned i = 0; i < nOutputs; ++i) {
-      ValueHandleArray indexing(makeCanonicalAffineApplies(
-          b, loc, genericOp.getOutputIndexingMap(i), allIvs));
-      std_store(callOp->getResult(i), genericOp.getOutputBuffer(i), indexing);
+      Value output = genericOp.getOutputBuffer(i);
+      if (output.getType().cast<ShapedType>().getRank()) {
+        ValueHandleArray indexing(makeCanonicalAffineApplies(
+            b, loc, genericOp.getOutputIndexingMap(i), allIvs));
+        std_store(callOp->getResult(i), output, indexing);
+      } else {
+        std_store(callOp->getResult(i), output);
+      }
     }
     return;
   }
@@ -288,10 +297,15 @@ public:
     auto *yieldOp = cast<YieldOp>(block.back()).getOperation();
     assert(yieldOp->getNumOperands() == nOutputs);
     for (unsigned i = 0; i < nOutputs; ++i) {
-      ValueHandleArray indexing(makeCanonicalAffineApplies(
-          b, loc, genericOp.getOutputIndexingMap(i), allIvs));
-      std_store(map.lookup(yieldOp->getOperand(i)),
-                genericOp.getOutputBuffer(i), indexing);
+      Value output = genericOp.getOutputBuffer(i);
+      if (output.getType().cast<ShapedType>().getRank()) {
+        ValueHandleArray indexing(makeCanonicalAffineApplies(
+            b, loc, genericOp.getOutputIndexingMap(i), allIvs));
+        std_store(map.lookup(yieldOp->getOperand(i)),
+                  genericOp.getOutputBuffer(i), indexing);
+      } else {
+        std_store(map.lookup(yieldOp->getOperand(i)), output);
+      }
     }
   }
 };
@@ -348,21 +362,25 @@ public:
     // 1.a. Emit std_load from input views.
     for (unsigned i = 0; i < nInputs; ++i) {
       Value input = indexedGenericOp.getInput(i);
-      if (!input.getType().cast<ShapedType>().getRank()) {
-        indexedValues[nLoops + i] = std_load(input);
-      } else {
+      if (input.getType().cast<ShapedType>().getRank()) {
         ValueHandleArray indexing(makeCanonicalAffineApplies(
             b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs));
         indexedValues[nLoops + i] = std_load(input, indexing);
+      } else {
+        indexedValues[nLoops + i] = std_load(input);
       }
     }
 
     // 1.b. Emit std_load from output views.
     for (unsigned i = 0; i < nOutputs; ++i) {
-      ValueHandleArray indexing(makeCanonicalAffineApplies(
-          b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
-      indexedValues[nLoops + nInputs + i] =
-          std_load(indexedGenericOp.getOutputBuffer(i), indexing);
+      Value output = indexedGenericOp.getOutputBuffer(i);
+      if (output.getType().cast<ShapedType>().getRank()) {
+        ValueHandleArray indexing(makeCanonicalAffineApplies(
+            b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
+        indexedValues[nLoops + nInputs + i] = std_load(output, indexing);
+      } else {
+        indexedValues[nLoops + nInputs + i] = std_load(output);
+      }
     }
 
     if (auto funcOp = indexedGenericOp.getFunction()) {
@@ -372,10 +390,14 @@ public:
 
       // 3. Emit std_store.
       for (unsigned i = 0; i < nOutputs; ++i) {
-        ValueHandleArray indexing(makeCanonicalAffineApplies(
-            b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
-        std_store(callOp->getResult(i), indexedGenericOp.getOutputBuffer(i),
-                  indexing);
+        Value output = indexedGenericOp.getOutputBuffer(i);
+        if (output.getType().cast<ShapedType>().getRank()) {
+          ValueHandleArray indexing(makeCanonicalAffineApplies(
+              b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
+          std_store(callOp->getResult(i), output, indexing);
+        } else {
+          std_store(callOp->getResult(i), output);
+        }
       }
       return;
     }
@@ -394,10 +416,14 @@ public:
     auto *yieldOp = cast<YieldOp>(block.back()).getOperation();
     assert(yieldOp->getNumOperands() == nOutputs);
     for (unsigned i = 0; i < nOutputs; ++i) {
-      ValueHandleArray indexing(makeCanonicalAffineApplies(
-          b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
-      std_store(map.lookup(yieldOp->getOperand(i)),
-                indexedGenericOp.getOutputBuffer(i), indexing);
+      Value output = indexedGenericOp.getOutputBuffer(i);
+      if (output.getType().cast<ShapedType>().getRank()) {
+        ValueHandleArray indexing(makeCanonicalAffineApplies(
+            b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
+        std_store(map.lookup(yieldOp->getOperand(i)), output, indexing);
+      } else {
+        std_store(map.lookup(yieldOp->getOperand(i)), output);
+      }
     }
   }
 };
@@ -411,3 +411,75 @@ func @indexed_generic_op_zero_rank(%arg0: memref<i32>, %arg1: memref<3x4xi32>)
 // CHECK: %[[ij_int:.*]] = index_cast %[[ij]] : index to i32
 // CHECK: %[[result:.*]] = addi %[[a]], %[[ij_int]] : i32
 // CHECK: store %[[result]], %[[ARG1]][%[[i]], %[[j]]]
+
+#reduce_1D_access = [
+  affine_map<(i) -> (i)>,
+  affine_map<(i) -> (0)>
+]
+
+#trait_reduce_1D = {
+  args_in = 1,
+  args_out = 1,
+  indexing_maps = #reduce_1D_access,
+  iterator_types = ["reduction"],
+  library_call = "some_reduce_external_fn"
+}
+
+func @generic_op_1D_reduce(%arg0: memref<?xf32>, %arg1: memref<f32>)
+{
+  linalg.generic #trait_reduce_1D %arg0, %arg1 {
+    ^bb(%a: f32, %b: f32) :
+      %0 = addf %a, %b : f32
+      linalg.yield %0 : f32
+  } : memref<?xf32>, memref<f32>
+  return
+}
+// CHECK-LABEL: @generic_op_1D_reduce
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<f32>
+// CHECK: loop.for %[[i:.*]] = {{.*}}
+// CHECK: %[[a:.*]] = load %[[ARG0]][%[[i]]]
+// CHECK: %[[b:.*]] = load %[[ARG1]][]
+// CHECK: %[[c:.*]] = addf %[[a]], %[[b]] : f32
+// CHECK: store %[[c]], %[[ARG1]][]
+
+
+#reduce_init_1D_access = [
+  affine_map<(i) -> (i)>,
+  affine_map<(i) -> (0)>,
+  affine_map<(i) -> (0)>
+]
+
+#trait_reduce_init_1D = {
+  args_in = 2,
+  args_out = 1,
+  indexing_maps = #reduce_init_1D_access,
+  iterator_types = ["reduction"],
+  library_call = "some_reduce_external_fn"
+}
+
+func @indexed_generic_op_1D_reduce(%arg0: memref<?xf32>,
+                                   %arg1: memref<f32>,
+                                   %arg2: memref<f32>)
+{
+  linalg.indexed_generic #trait_reduce_init_1D %arg0, %arg1, %arg2 {
+    ^bb(%i : index, %a: f32, %b: f32, %c: f32) :
+      %0 = constant 0 : index
+      %1 = cmpi "eq", %0, %i : index
+      %2 = select %1, %b, %c : f32
+      %3 = addf %a, %2 : f32
+      linalg.yield %3 : f32
+  } : memref<?xf32>, memref<f32>, memref<f32>
+  return
+}
+// CHECK-LABEL: @indexed_generic_op_1D_reduce
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<f32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<f32>
+// CHECK: loop.for %[[i:.*]] = {{.*}}
+// CHECK: %[[a:.*]] = load %[[ARG0]][%[[i]]]
+// CHECK: %[[b:.*]] = load %[[ARG1]][]
+// CHECK: %[[c:.*]] = load %[[ARG2]][]
+// CHECK: %[[d:.*]] = select %{{.*}}, %[[b]], %[[c]]
+// CHECK: %[[e:.*]] = addf %[[a]], %[[d]]
+// CHECK: store %[[e]], %[[ARG2]][]