[SCF] Add thread_dim_mapping attribute to scf.foreach_thread
An optional thread_dim_mapping index array attribute specifies, for each virtual thread dimension, how it remaps 1-1 to a set of concrete processing element resources (e.g. a CUDA grid dimension or a level of concrete nested async parallelism). At this time, the specification is backend-dependent and is not verified by the op beyond being an index array attribute. It is the responsibility of the lowering to interpret the index array in the context of the concrete target the op is lowered to, or to ignore it when the specification is ill-formed or unsupported for a particular target.

Differential Revision: https://reviews.llvm.org/D128633
This commit is contained in:
parent
5d50f51c97
commit
a0f843fdaf
@@ -339,6 +339,15 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
     application per thread. Further lowerings are responsible for specifying
     how this is materialized on concrete hardware resources.
 
+    An optional thread_dim_mapping index array attribute specifies, for each
+    virtual thread dimension, how it remaps 1-1 to a set of concrete processing
+    element resources (e.g. a CUDA grid dimension or a level of concrete nested
+    async parallelism). At this time, the specification is backend-dependent and
+    is not verified by the op, beyond being an index array attribute.
+    It is the responsibility of the lowering to interpret the index array in the
+    context of the concrete target the op is lowered to, or to ignore it when
+    the specification is ill-formed or unsupported for a particular target.
+
     The only allowed terminator is `scf.foreach_thread.perform_concurrently`,
     which dictates how the partial results of all parallel invocations should be
     reconciled into a full value.
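To make the contract above concrete, here is a minimal sketch (not part of this patch) of how a lowering might consume the attribute; `remapThreadDims` is a hypothetical helper, and the convention shown (entry i names the concrete dimension for virtual dimension i) matches the `[1, 0]` example further down.

```c++
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical lowering-side helper: scatter each virtual thread dimension i
// onto the concrete processing-element dimension threadDimMapping[i].
// An empty mapping behaves as the identity, matching the "{}" default below.
std::vector<int64_t>
remapThreadDims(const std::vector<int64_t> &virtualDims,
                const std::vector<int64_t> &threadDimMapping) {
  if (threadDimMapping.empty())
    return virtualDims;
  assert(threadDimMapping.size() == virtualDims.size() &&
         "expected one mapping entry per virtual thread dimension");
  std::vector<int64_t> concreteDims(virtualDims.size());
  for (size_t i = 0; i < virtualDims.size(); ++i)
    concreteDims[threadDimMapping[i]] = virtualDims[i]; // 1-1 remap
  return concreteDims;
}
// remapThreadDims({tid_1, tid_2}, {1, 0}) == {tid_2, tid_1}
```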
@@ -398,8 +407,27 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
     // Sequential context.
     //
     ```
 
+    Example with thread_dim_mapping attribute:
+
+    ```mlir
+    //
+    // Sequential context.
+    //
+    %matmul_and_pointwise:2 = scf.foreach_thread (%thread_id_1, %thread_id_2) in
+        (%num_threads_1, %num_threads_2) -> (tensor<?x?xT>, tensor<?xT>) {
+      //
+      // Parallel context, each thread with id = **(%thread_id_2, %thread_id_1)**
+      // runs its version of the code.
+      //
+      scf.foreach_thread.perform_concurrently {
+        ...
+      }
+    } { thread_dim_mapping = [1, 0] }
+    // Implicit synchronization point.
+    //
+    // Sequential context.
+    //
+    ```
   }];
 
-  let arguments = (ins Variadic<Index>:$num_threads);
+  let arguments = (ins Variadic<Index>:$num_threads,
+                   DefaultValuedAttr<I64ArrayAttr, "{}">:$thread_dim_mapping);
 
   let results = (outs Variadic<AnyType>:$results);
   let regions = (region SizedRegion<1>:$region);
@@ -411,11 +439,13 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
   let skipDefaultBuilders = 1;
   let builders = [
     // Bodyless builder, result types must be specified.
-    OpBuilder<(ins "TypeRange":$resultTypes, "ValueRange":$num_threads)>,
+    OpBuilder<(ins "TypeRange":$resultTypes, "ValueRange":$num_threads,
+      CArg<"ArrayRef<int64_t>", "{}">:$thread_dim_mapping)>,
     // Builder that takes a bodyBuilder lambda, result types are inferred from
     // the terminator.
     OpBuilder<(ins "ValueRange":$num_threads,
-      "function_ref<void(OpBuilder &, Location, ValueRange)>":$bodyBuilder)>
+      "ArrayRef<int64_t>":$thread_dim_mapping,
+      "function_ref<void(OpBuilder &, Location, ValueRange)>":$bodyBuilder)>
   ];
   let extraClassDeclaration = [{
     int64_t getRank() { return getNumThreads().size(); }
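A usage sketch for the updated bodyless builder (inferred from the declaration above, not code from this patch); `b`, `loc`, and `numThreads` stand in for an in-scope `OpBuilder`, `Location`, and operand range:

```c++
// The trailing CArg defaults to {}, so existing call sites keep compiling;
// passing a non-empty mapping attaches the thread_dim_mapping attribute.
SmallVector<int64_t> threadDimMapping = {1, 0};
auto foreachOp = b.create<scf::ForeachThreadOp>(
    loc, /*resultTypes=*/TypeRange{}, numThreads, threadDimMapping);
```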
@@ -1135,8 +1135,12 @@ ParseResult ForeachThreadOp::parse(OpAsmParser &parser,
 // Bodyless builder, result types must be specified.
 void ForeachThreadOp::build(mlir::OpBuilder &builder,
                             mlir::OperationState &result, TypeRange resultTypes,
-                            ValueRange numThreads) {
+                            ValueRange numThreads,
+                            ArrayRef<int64_t> threadDimMapping) {
   result.addOperands(numThreads);
+  result.addAttribute(
+      // TODO: getThreadDimMappingAttrName() but it is not a static member.
+      "thread_dim_mapping", builder.getI64ArrayAttr(threadDimMapping));
 
   Region *bodyRegion = result.addRegion();
   OpBuilder::InsertionGuard g(builder);
@@ -1156,9 +1160,12 @@ void ForeachThreadOp::build(mlir::OpBuilder &builder,
 // the terminator.
 void ForeachThreadOp::build(
     mlir::OpBuilder &builder, mlir::OperationState &result,
-    ValueRange numThreads,
+    ValueRange numThreads, ArrayRef<int64_t> threadDimMapping,
     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
   result.addOperands(numThreads);
+  result.addAttribute(
+      // TODO: getThreadDimMappingAttrName() but it is not a static member.
+      "thread_dim_mapping", builder.getI64ArrayAttr(threadDimMapping));
 
   OpBuilder::InsertionGuard g(builder);
   Region *bodyRegion = result.addRegion();
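Correspondingly, a sketch of calling the lambda-taking overload (again inferred, not from the patch); unlike the bodyless builder it has no default for the mapping, so call sites pass it explicitly:

```c++
// Result types are inferred from the terminator created by the lambda;
// threadIds carries one block argument per entry of numThreads.
auto foreachOp = b.create<scf::ForeachThreadOp>(
    loc, numThreads, /*threadDimMapping=*/ArrayRef<int64_t>{1, 0},
    [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange threadIds) {
      // ... emit the per-thread body and its perform_concurrently terminator.
    });
```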
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/PatternMatch.h"
@@ -999,7 +1000,8 @@ struct ForeachThreadOpInterface
     TypeRange newResultTypes;
     auto newForeachThreadOp = rewriter.create<ForeachThreadOp>(
         foreachThreadOp.getLoc(), newResultTypes,
-        foreachThreadOp.getNumThreads());
+        foreachThreadOp.getNumThreads(),
+        extractFromI64ArrayAttr(foreachThreadOp.getThreadDimMapping()));
     newForeachThreadOp.getBody()->getTerminator()->erase();
 
     // Move over block contents of the old op.
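For context, the attribute round-trips as follows (a sketch, with names mirroring the diff above): the builders store the mapping with `builder.getI64ArrayAttr(...)`, and `extractFromI64ArrayAttr` from the newly included StaticValueUtils.h recovers the raw `int64_t` values, so the bufferized clone keeps the mapping:

```c++
ArrayAttr attr = foreachThreadOp.getThreadDimMapping();
SmallVector<int64_t> mapping = extractFromI64ArrayAttr(attr);
// builder.getI64ArrayAttr(mapping) reproduces the original attribute.
```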
@@ -130,6 +130,7 @@ func.func @scf_foreach_thread_out_of_place(%in: tensor<100xf32>,
     scf.foreach_thread.parallel_insert_slice %1 into %out[%thread_idx][1][1] :
         tensor<1xf32> into tensor<100xf32>
   }
-  }
+  // CHECK: } {thread_dim_mapping = [5]}
+  } {thread_dim_mapping = [5]}
   return
 }
@@ -338,11 +338,11 @@ func.func @elide_terminator() -> () {
   %num_threads = arith.constant 100 : index
 
   // CHECK: scf.foreach_thread
-  // CHECK-NEXT: }
+  // CHECK-NEXT: } {thread_dim_mapping = [42]}
   // CHECK-NEXT: return
   scf.foreach_thread (%thread_idx) in (%num_threads) -> () {
     scf.foreach_thread.perform_concurrently {
     }
-  }
+  } {thread_dim_mapping = [42]}
   return
 }