forked from OSchip/llvm-project
[mlir][Linalg] Allow specifying zero-rank shaped type operands to linalg.generic ops.
Fixing a bug where using a zero-rank shaped type operand to linalg.generic ops hit an unrelated assert. This also meant that lowering the operation to loops was not supported. Adding roundtrip tests and lowering to loops test for zero-rank shaped type operand with fixes to make the test pass. Differential Revision: https://reviews.llvm.org/D74638
This commit is contained in:
parent
870c1fd4c8
commit
a8355b5c0f
|
@@ -361,11 +361,10 @@ static LogicalResult verifyGenericOp(GenericOpType op) {
|
|||
if (!cst || cst.getValue() != 0)
|
||||
return op.emitOpError("expected indexing_map #")
|
||||
<< idx << " to be 0 to match 0-D view: " << view;
|
||||
}
|
||||
|
||||
if (m.getNumResults() != view.getRank())
|
||||
} else if (m.getNumResults() != view.getRank()) {
|
||||
return op.emitOpError("expected indexing_map #")
|
||||
<< idx << " results to match view rank: " << view;
|
||||
}
|
||||
}
|
||||
|
||||
auto concatMap = concatAffineMaps(indexingMaps);
|
||||
|
|
|
@@ -238,9 +238,14 @@ public:
|
|||
|
||||
// 1.a. Emit std_load from input views.
|
||||
for (unsigned i = 0; i < nInputs; ++i) {
|
||||
ValueHandleArray indexing(makeCanonicalAffineApplies(
|
||||
b, loc, genericOp.getInputIndexingMap(i), allIvs));
|
||||
indexedValues[i] = std_load(genericOp.getInput(i), indexing);
|
||||
Value input = genericOp.getInput(i);
|
||||
if (!input.getType().cast<ShapedType>().getRank()) {
|
||||
indexedValues[i] = std_load(input);
|
||||
} else {
|
||||
ValueHandleArray indexing(makeCanonicalAffineApplies(
|
||||
b, loc, genericOp.getInputIndexingMap(i), allIvs));
|
||||
indexedValues[i] = std_load(input, indexing);
|
||||
}
|
||||
}
|
||||
|
||||
// 1.b. Emit std_load from output views.
|
||||
|
|
|
@@ -351,12 +351,12 @@ AffineMap mlir::inversePermutation(AffineMap map) {
|
|||
AffineMap mlir::concatAffineMaps(ArrayRef<AffineMap> maps) {
|
||||
unsigned numResults = 0;
|
||||
for (auto m : maps)
|
||||
numResults += m ? m.getNumResults() : 0;
|
||||
numResults += (m && !m.isSingleConstant()) ? m.getNumResults() : 0;
|
||||
unsigned numDims = 0;
|
||||
SmallVector<AffineExpr, 8> results;
|
||||
results.reserve(numResults);
|
||||
for (auto m : maps) {
|
||||
if (!m)
|
||||
if (!m || m.isSingleConstant())
|
||||
continue;
|
||||
assert(m.getNumSymbols() == 0 && "expected map without symbols");
|
||||
results.append(m.getResults().begin(), m.getResults().end());
|
||||
|
|
|
@@ -356,3 +356,36 @@ func @indexed_generic_region(
|
|||
// CHECK: %[[result_2:.*]] = addf %[[c]], %[[ijk_float]] : f32
|
||||
// CHECK: store %[[result_1]], %{{.*}}[%[[i]], %[[j]], %[[k]]]
|
||||
// CHECK: store %[[result_2]], %{{.*}}[%[[i]], %[[k]], %[[j]]]
|
||||
|
||||
// -----
|
||||
|
||||
|
||||
#broadcast_access = [
|
||||
affine_map<(i, j) -> (0)>,
|
||||
affine_map<(i,j) -> (i,j)>
|
||||
]
|
||||
|
||||
#trait_broadcast = {
|
||||
args_in = 1,
|
||||
args_out = 1,
|
||||
indexing_maps = #broadcast_access,
|
||||
iterator_types = ["parallel", "parallel"],
|
||||
library_call = "some_broadcast_external_fn"
|
||||
}
|
||||
|
||||
func @generic_op_zero_rank(%arg0 : memref<f32>, %arg1: memref<3x4xf32>)
|
||||
{
|
||||
linalg.generic #trait_broadcast %arg0, %arg1 {
|
||||
^bb(%a: f32, %b : f32) :
|
||||
linalg.yield %a : f32
|
||||
} : memref<f32>, memref<3x4xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @generic_op_zero_rank
|
||||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<f32>
|
||||
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<3x4xf32>
|
||||
// CHECK: loop.for %[[i:.*]] = {{.*}}
|
||||
// CHECK: loop.for %[[j:.*]] = {{.*}}
|
||||
// CHECK: %[[a:.*]] = load %[[ARG0]][]
|
||||
// CHECK: store %[[a]], %[[ARG1]][%[[i]], %[[j]]]
|
||||
|
|
|
@@ -345,6 +345,30 @@ func @indexed_generic_with_tensor_input_and_output(
|
|||
|
||||
// -----
|
||||
|
||||
#broadcast_access = [
|
||||
affine_map<(i, j) -> (0)>,
|
||||
affine_map<(i,j) -> (i,j)>
|
||||
]
|
||||
|
||||
#trait_broadcast = {
|
||||
args_in = 1,
|
||||
args_out = 1,
|
||||
indexing_maps = #broadcast_access,
|
||||
iterator_types = ["parallel", "parallel"],
|
||||
library_call = "some_broadcast_external_fn"
|
||||
}
|
||||
|
||||
func @generic_op_zero_rank(%arg0 : tensor<f32>) -> (tensor<3x4xf32>)
|
||||
{
|
||||
%0 = linalg.generic #trait_broadcast %arg0 {
|
||||
^bb(%a: f32) :
|
||||
linalg.yield %a : f32
|
||||
} : tensor<f32> -> tensor<3x4xf32>
|
||||
return %0 : tensor<3x4xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-DAG: #[[strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
|
||||
// CHECK-DAG: #[[strided3D:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>
|
||||
|
||||
|
|
Loading…
Reference in New Issue