[SCF] Add thread_dim_mapping attribute to scf.foreach_thread

An optional thread_dim_mapping index array attribute specifies for each
virtual thread dimension, how it remaps 1-1 to a set of concrete processing
element resources (e.g. a CUDA grid dimension or a level of concrete nested
async parallelism). At this time, the specification is backend-dependent and
is not verified by the op, beyond being an index array attribute.
It is the responsibility of the lowering to interpret the index array in the
context of the concrete target the op is lowered to, or to ignore it when
the specification is ill-formed or unsupported for a particular target.

Differential Revision: https://reviews.llvm.org/D128633
This commit is contained in:
Nicolas Vasilache 2022-06-24 02:26:22 -07:00
parent 5d50f51c97
commit a0f843fdaf
5 changed files with 49 additions and 9 deletions

View File

@ -339,6 +339,15 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
application per thread. Further lowerings are responsible for specifying
how this is materialized on concrete hardware resources.
An optional thread_dim_mapping index array attribute specifies for each
virtual thread dimension, how it remaps 1-1 to a set of concrete processing
element resources (e.g. a CUDA grid dimension or a level of concrete nested
async parallelism). At this time, the specification is backend-dependent and
is not verified by the op, beyond being an index array attribute.
It is the responsibility of the lowering to interpret the index array in the
context of the concrete target the op is lowered to, or to ignore it when
the specification is ill-formed or unsupported for a particular target.
The only allowed terminator is `scf.foreach_thread.perform_concurrently`,
which dictates how the partial results of all parallel invocations should be
reconciled into a full value.
@ -398,8 +407,27 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
// Sequential context.
//
```
Example with thread_dim_mapping attribute:
//
// Sequential context.
//
%matmul_and_pointwise:2 = scf.foreach_thread (%thread_id_1, %thread_id_2) in
(%num_threads_1, %num_threads_2) -> (tensor<?x?xT>, tensor<?xT>) {
//
// Parallel context, each thread with id = **(%thread_id_2, %thread_id_1)**
// runs its version of the code.
//
scf.foreach_thread.perform_concurrently {
...
}
} { thread_dim_mapping = [1, 0] }
// Implicit synchronization point.
// Sequential context.
//
}];
let arguments = (ins Variadic<Index>:$num_threads);
let arguments = (ins Variadic<Index>:$num_threads,
DefaultValuedAttr<I64ArrayAttr, "{}">:$thread_dim_mapping);
let results = (outs Variadic<AnyType>:$results);
let regions = (region SizedRegion<1>:$region);
@ -411,11 +439,13 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
let skipDefaultBuilders = 1;
let builders = [
// Bodyless builder, result types must be specified.
OpBuilder<(ins "TypeRange":$resultTypes, "ValueRange":$num_threads)>,
OpBuilder<(ins "TypeRange":$resultTypes, "ValueRange":$num_threads,
CArg<"ArrayRef<int64_t>", "{}">:$thread_dim_mapping)>,
// Builder that takes a bodyBuilder lambda, result types are inferred from
// the terminator.
OpBuilder<(ins "ValueRange":$num_threads,
"function_ref<void(OpBuilder &, Location, ValueRange)>":$bodyBuilder)>
"ArrayRef<int64_t>":$thread_dim_mapping,
"function_ref<void(OpBuilder &, Location, ValueRange)>":$bodyBuilder)>
];
let extraClassDeclaration = [{
int64_t getRank() { return getNumThreads().size(); }

View File

@ -1135,8 +1135,12 @@ ParseResult ForeachThreadOp::parse(OpAsmParser &parser,
// Bodyless builder, result types must be specified.
void ForeachThreadOp::build(mlir::OpBuilder &builder,
mlir::OperationState &result, TypeRange resultTypes,
ValueRange numThreads) {
ValueRange numThreads,
ArrayRef<int64_t> threadDimMapping) {
result.addOperands(numThreads);
result.addAttribute(
// TODO: getThreadDimMappingAttrName() but it is not a static member.
"thread_dim_mapping", builder.getI64ArrayAttr(threadDimMapping));
Region *bodyRegion = result.addRegion();
OpBuilder::InsertionGuard g(builder);
@ -1156,9 +1160,12 @@ void ForeachThreadOp::build(mlir::OpBuilder &builder,
// the terminator.
void ForeachThreadOp::build(
mlir::OpBuilder &builder, mlir::OperationState &result,
ValueRange numThreads,
ValueRange numThreads, ArrayRef<int64_t> threadDimMapping,
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
result.addOperands(numThreads);
result.addAttribute(
// TODO: getThreadDimMappingAttrName() but it is not a static member.
"thread_dim_mapping", builder.getI64ArrayAttr(threadDimMapping));
OpBuilder::InsertionGuard g(builder);
Region *bodyRegion = result.addRegion();

View File

@ -14,6 +14,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
@ -999,7 +1000,8 @@ struct ForeachThreadOpInterface
TypeRange newResultTypes;
auto newForeachThreadOp = rewriter.create<ForeachThreadOp>(
foreachThreadOp.getLoc(), newResultTypes,
foreachThreadOp.getNumThreads());
foreachThreadOp.getNumThreads(),
extractFromI64ArrayAttr(foreachThreadOp.getThreadDimMapping()));
newForeachThreadOp.getBody()->getTerminator()->erase();
// Move over block contents of the old op.

View File

@ -130,6 +130,7 @@ func.func @scf_foreach_thread_out_of_place(%in: tensor<100xf32>,
scf.foreach_thread.parallel_insert_slice %1 into %out[%thread_idx][1][1] :
tensor<1xf32> into tensor<100xf32>
}
}
// CHECK: } {thread_dim_mapping = [5]}
} {thread_dim_mapping = [5]}
return
}

View File

@ -338,11 +338,11 @@ func.func @elide_terminator() -> () {
%num_threads = arith.constant 100 : index
// CHECK: scf.foreach_thread
// CHECK-NEXT: }
// CHECK-NEXT: } {thread_dim_mapping = [42]}
// CHECK-NEXT: return
scf.foreach_thread (%thread_idx) in (%num_threads) -> () {
scf.foreach_thread.perform_concurrently {
}
}
} {thread_dim_mapping = [42]}
return
}