[mlir][Linalg] NFC : Fix check for scalar case handling in LinalgToLoops

The invertPermutation method does not return a nullptr anymore, but
rather returns an empty map for the scalar case. Update the check in
LinalgToLoops to reflect this.
Also add test case for generating scalar code.
This commit is contained in:
MaheshRavishankar 2020-04-11 23:01:40 -07:00
parent 03391df90e
commit 3b2f26ab05
2 changed files with 45 additions and 2 deletions

View File

@ -652,8 +652,8 @@ LinalgOpToLoopsImpl<LoopTy, ConcreteOpTy>::doit(Operation *op,
linalgOp.indexing_maps().template getAsRange<AffineMapAttr>();
auto maps =
functional::map([](AffineMapAttr a) { return a.getValue(); }, mapsRange);
auto invertedMap = inversePermutation(concatAffineMaps(maps));
if (!invertedMap) {
AffineMap invertedMap = inversePermutation(concatAffineMaps(maps));
if (invertedMap.isEmpty()) {
LinalgScopedEmitter<IndexedValueTy, ConcreteOpTy>::emitScalarImplementation(
{}, linalgOp);
return LinalgLoops();

View File

@ -913,3 +913,46 @@ func @generic_const_init(%arg0: memref<?xf32>) {
// CHECKPARALLEL: %[[CONST:.*]] = constant 1.000000e+00 : f32
// CHECKPARALLEL: loop.parallel (%[[i:.*]])
// CHECKPARALLEL: store %[[CONST]], %[[ARG0]]
#scalar_access = [
affine_map<() -> ()>,
affine_map<() -> ()>,
affine_map<() -> ()>
]
#scalar_trait = {
args_in = 2,
args_out = 1,
iterator_types = [],
indexing_maps = #scalar_access,
library_call = "some_external_fn"
}
func @scalar_code(%arg0: memref<f32>, %arg1 : memref<f32>, %arg2 : memref<f32>)
{
linalg.generic #scalar_trait %arg0, %arg1, %arg2 {
^bb(%a : f32, %b : f32, %c : f32) :
%0 = addf %a, %b : f32
linalg.yield %0 : f32
} : memref<f32>, memref<f32>, memref<f32>
return
}
// CHECKLOOP-LABEL: @scalar_code
// CHECKLOOP-SAME: %[[ARG0]]: memref<f32>
// CHECKLOOP-SAME: %[[ARG1]]: memref<f32>
// CHECKLOOP-SAME: %[[ARG2]]: memref<f32>
// CHECKLOOP-NOT: loop.for
// CHECKLOOP-DAG: load %[[ARG0]][]
// CHECKLOOP-DAG: load %[[ARG1]][]
// CHECKLOOP-DAG: load %[[ARG2]][]
// CHECKLOOP: addf
// CHECKLOOP: store %{{.*}}, %[[ARG2]][]
// CHECKPARALLEL-LABEL: @scalar_code
// CHECKPARALLEL-SAME: %[[ARG0]]: memref<f32>
// CHECKPARALLEL-SAME: %[[ARG1]]: memref<f32>
// CHECKPARALLEL-SAME: %[[ARG2]]: memref<f32>
// CHECKPARALLEL-NOT: loop.for
// CHECKPARALLEL-DAG: load %[[ARG0]][]
// CHECKPARALLEL-DAG: load %[[ARG1]][]
// CHECKPARALLEL-DAG: load %[[ARG2]][]
// CHECKPARALLEL: addf
// CHECKPARALLEL: store %{{.*}}, %[[ARG2]][]