[mlir][Linalg] Allow specifying zero-rank shaped type operands to linalg.generic ops.

Fixing a bug where using a zero-rank shaped type operand to
linalg.generic ops hit an unrelated assert. This also meant that
lowering the operation to loops was not supported. Adding roundtrip
tests and lowering to loops test for zero-rank shaped type operand
with fixes to make the test pass.

Differential Revision: https://reviews.llvm.org/D74638
This commit is contained in:
MaheshRavishankar 2020-02-18 09:50:47 -08:00
parent 870c1fd4c8
commit a8355b5c0f
5 changed files with 69 additions and 8 deletions

View File

@ -361,11 +361,10 @@ static LogicalResult verifyGenericOp(GenericOpType op) {
if (!cst || cst.getValue() != 0)
return op.emitOpError("expected indexing_map #")
<< idx << " to be 0 to match 0-D view: " << view;
}
if (m.getNumResults() != view.getRank())
} else if (m.getNumResults() != view.getRank()) {
return op.emitOpError("expected indexing_map #")
<< idx << " results to match view rank: " << view;
}
}
auto concatMap = concatAffineMaps(indexingMaps);

View File

@ -238,9 +238,14 @@ public:
// 1.a. Emit std_load from input views.
for (unsigned i = 0; i < nInputs; ++i) {
ValueHandleArray indexing(makeCanonicalAffineApplies(
b, loc, genericOp.getInputIndexingMap(i), allIvs));
indexedValues[i] = std_load(genericOp.getInput(i), indexing);
Value input = genericOp.getInput(i);
if (!input.getType().cast<ShapedType>().getRank()) {
indexedValues[i] = std_load(input);
} else {
ValueHandleArray indexing(makeCanonicalAffineApplies(
b, loc, genericOp.getInputIndexingMap(i), allIvs));
indexedValues[i] = std_load(input, indexing);
}
}
// 1.b. Emit std_load from output views.

View File

@ -351,12 +351,12 @@ AffineMap mlir::inversePermutation(AffineMap map) {
AffineMap mlir::concatAffineMaps(ArrayRef<AffineMap> maps) {
unsigned numResults = 0;
for (auto m : maps)
numResults += m ? m.getNumResults() : 0;
numResults += (m && !m.isSingleConstant()) ? m.getNumResults() : 0;
unsigned numDims = 0;
SmallVector<AffineExpr, 8> results;
results.reserve(numResults);
for (auto m : maps) {
if (!m)
if (!m || m.isSingleConstant())
continue;
assert(m.getNumSymbols() == 0 && "expected map without symbols");
results.append(m.getResults().begin(), m.getResults().end());

View File

@ -356,3 +356,36 @@ func @indexed_generic_region(
// CHECK: %[[result_2:.*]] = addf %[[c]], %[[ijk_float]] : f32
// CHECK: store %[[result_1]], %{{.*}}[%[[i]], %[[j]], %[[k]]]
// CHECK: store %[[result_2]], %{{.*}}[%[[i]], %[[k]], %[[j]]]
// -----
// Lowering-to-loops test for a zero-rank (0-D) shaped operand.
// The first indexing map yields the single constant result 0, which is what
// the verifier requires for a 0-D view; the second map is the identity over
// the 2-D output.
#broadcast_access = [
affine_map<(i, j) -> (0)>,
affine_map<(i,j) -> (i,j)>
]
#trait_broadcast = {
args_in = 1,
args_out = 1,
indexing_maps = #broadcast_access,
iterator_types = ["parallel", "parallel"],
library_call = "some_broadcast_external_fn"
}
// Broadcast a scalar memref<f32> into every element of a 3x4 output.
func @generic_op_zero_rank(%arg0 : memref<f32>, %arg1: memref<3x4xf32>)
{
linalg.generic #trait_broadcast %arg0, %arg1 {
^bb(%a: f32, %b : f32) :
linalg.yield %a : f32
} : memref<f32>, memref<3x4xf32>
return
}
// The 0-D operand is loaded with an empty index list (`load %[[ARG0]][]`),
// matching the rank-0 special case added to the loop lowering.
// CHECK-LABEL: @generic_op_zero_rank
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<f32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<3x4xf32>
// CHECK: loop.for %[[i:.*]] = {{.*}}
// CHECK: loop.for %[[j:.*]] = {{.*}}
// CHECK: %[[a:.*]] = load %[[ARG0]][]
// CHECK: store %[[a]], %[[ARG1]][%[[i]], %[[j]]]

View File

@ -345,6 +345,30 @@ func @indexed_generic_with_tensor_input_and_output(
// -----
// Roundtrip test: the same zero-rank broadcast trait as in the loops test,
// but with a tensor operand/result, exercising the verifier's acceptance of
// a constant-0 indexing map for a 0-D shaped type.
#broadcast_access = [
affine_map<(i, j) -> (0)>,
affine_map<(i,j) -> (i,j)>
]
#trait_broadcast = {
args_in = 1,
args_out = 1,
indexing_maps = #broadcast_access,
iterator_types = ["parallel", "parallel"],
library_call = "some_broadcast_external_fn"
}
// Broadcast a 0-D tensor<f32> value into a 3x4 tensor result; the op only
// needs to parse and verify (no CHECK lines for this test are in view).
func @generic_op_zero_rank(%arg0 : tensor<f32>) -> (tensor<3x4xf32>)
{
%0 = linalg.generic #trait_broadcast %arg0 {
^bb(%a: f32) :
linalg.yield %a : f32
} : tensor<f32> -> tensor<3x4xf32>
return %0 : tensor<3x4xf32>
}
// -----
// CHECK-DAG: #[[strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
// CHECK-DAG: #[[strided3D:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>