[mlir:Async] Submit accidentally omitted changes

Accidentally pushed old branches that did not include all the changes
discussed in the PRs.

https://reviews.llvm.org/rGd43b23608ad664f02f56e965ca78916bde220950
https://reviews.llvm.org/rG86ad0af87054c3cccd68d32e103a6f1f6c6194c7

Differential Revision: https://reviews.llvm.org/D104943
parent 5b2573e9c7
commit 34a164c938
@@ -19,6 +19,10 @@ namespace mlir {
 std::unique_ptr<Pass> createAsyncParallelForPass();
 
+std::unique_ptr<Pass> createAsyncParallelForPass(bool asyncDispatch,
+                                                 int32_t numWorkerThreads,
+                                                 int32_t targetBlockSize);
+
 std::unique_ptr<OperationPass<ModuleOp>> createAsyncToAsyncRuntimePass();
 
 std::unique_ptr<Pass> createAsyncRuntimeRefCountingPass();
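As a usage note (not part of the commit): the new overload declared above can be added to a pass pipeline like any other pass. The sketch below is illustrative only, and the option values are arbitrary examples.

#include "mlir/Dialect/Async/Passes.h"
#include "mlir/Pass/PassManager.h"

// Illustrative pipeline setup; the option values here are arbitrary examples.
void addAsyncParallelFor(mlir::PassManager &pm) {
  pm.addPass(mlir::createAsyncParallelForPass(/*asyncDispatch=*/true,
                                              /*numWorkerThreads=*/8,
                                              /*targetBlockSize=*/1000));
}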
@@ -596,7 +596,7 @@ public:
   matchAndRewrite(RuntimeCreateGroupOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     TypeConverter *converter = getTypeConverter();
-    Type resultType = op->getResultTypes()[0];
+    Type resultType = op.getResult().getType();
 
     rewriter.replaceOpWithNewOp<CallOp>(
         op, kCreateGroup, converter->convertType(resultType), operands);
@@ -90,6 +90,14 @@ namespace {
 struct AsyncParallelForPass
     : public AsyncParallelForBase<AsyncParallelForPass> {
   AsyncParallelForPass() = default;
+
+  AsyncParallelForPass(bool asyncDispatch, int32_t numWorkerThreads,
+                       int32_t targetBlockSize) {
+    this->asyncDispatch = asyncDispatch;
+    this->numWorkerThreads = numWorkerThreads;
+    this->targetBlockSize = targetBlockSize;
+  }
 
   void runOnOperation() override;
 };
@@ -127,7 +135,7 @@ struct ParallelComputeFunction {
 // Converts one-dimensional iteration index in the [0, tripCount) interval
 // into multidimensional iteration coordinate.
 static SmallVector<Value> delinearize(ImplicitLocOpBuilder &b, Value index,
-                                      const SmallVector<Value> &tripCounts) {
+                                      ArrayRef<Value> tripCounts) {
   SmallVector<Value> coords(tripCounts.size());
   assert(!tripCounts.empty() && "tripCounts must be not empty");
 
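For readers unfamiliar with the helper, the sketch below models the delinearization arithmetic in plain C++ rather than as MLIR op builders. The standalone form and the variable names are illustrative assumptions, not part of the patch.

#include <cassert>
#include <cstdint>
#include <vector>

// Delinearize a flat index in [0, product(tripCounts)) into per-loop
// coordinates, innermost dimension varying fastest (row-major order).
static std::vector<int64_t> delinearize(int64_t index,
                                        const std::vector<int64_t> &tripCounts) {
  assert(!tripCounts.empty() && "tripCounts must be not empty");
  std::vector<int64_t> coords(tripCounts.size());
  for (size_t i = tripCounts.size(); i-- > 0;) {
    coords[i] = index % tripCounts[i]; // coordinate along dimension i
    index /= tripCounts[i];            // move on to the next outer dimension
  }
  return coords;
}
// Example: delinearize(5, {2, 3}) == {1, 2}, since 5 = 1 * 3 + 2.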
@@ -184,7 +192,6 @@ createParallelComputeFunction(scf::ParallelOp op, PatternRewriter &rewriter) {
   ImplicitLocOpBuilder b(op.getLoc(), rewriter);
 
   ModuleOp module = op->getParentOfType<ModuleOp>();
-
   b.setInsertionPointToStart(&module->getRegion(0).front());
 
   ParallelComputeFunctionType computeFuncType =
       getParallelComputeFunctionType(op, rewriter);
@@ -204,12 +211,13 @@ createParallelComputeFunction(scf::ParallelOp op, PatternRewriter &rewriter) {
   unsigned offset = 0; // argument offset for arguments decoding
 
-  // Load multiple arguments into values vector.
-  auto getArguments = [&](unsigned num_arguments) -> SmallVector<Value> {
-    SmallVector<Value> values(num_arguments);
-    for (unsigned i = 0; i < num_arguments; ++i)
-      values[i] = block->getArgument(offset++);
-    return values;
+  // Returns `numArguments` arguments starting from `offset` and updates offset
+  // by moving forward to the next argument.
+  auto getArguments = [&](unsigned numArguments) -> ArrayRef<Value> {
+    auto args = block->getArguments();
+    auto slice = args.drop_front(offset).take_front(numArguments);
+    offset += numArguments;
+    return {slice.begin(), slice.end()};
   };
 
   // Block iteration position defined by the block index and size.
@@ -217,11 +225,11 @@ createParallelComputeFunction(scf::ParallelOp op, PatternRewriter &rewriter) {
   Value blockSize = block->getArgument(offset++);
 
   // Constants used below.
-  Value c0 = b.create<ConstantOp>(b.getIndexAttr(0));
-  Value c1 = b.create<ConstantOp>(b.getIndexAttr(1));
+  Value c0 = b.create<ConstantIndexOp>(0);
+  Value c1 = b.create<ConstantIndexOp>(1);
 
   // Multi-dimensional parallel iteration space defined by the loop trip counts.
-  SmallVector<Value> tripCounts = getArguments(op.getNumLoops());
+  ArrayRef<Value> tripCounts = getArguments(op.getNumLoops());
 
   // Compute a product of trip counts to get the size of the flattened
   // one-dimensional iteration space.
@@ -229,35 +237,34 @@ createParallelComputeFunction(scf::ParallelOp op, PatternRewriter &rewriter) {
   for (unsigned i = 1; i < tripCounts.size(); ++i)
     tripCount = b.create<MulIOp>(tripCount, tripCounts[i]);
 
-  // Parallel operation lower bound, upper bound and step.
-  SmallVector<Value> lowerBound = getArguments(op.getNumLoops());
-  SmallVector<Value> upperBound = getArguments(op.getNumLoops());
-  SmallVector<Value> step = getArguments(op.getNumLoops());
+  // Parallel operation lower bound and step.
+  ArrayRef<Value> lowerBound = getArguments(op.getNumLoops());
+  offset += op.getNumLoops(); // skip upper bound arguments
+  ArrayRef<Value> step = getArguments(op.getNumLoops());
 
   // Remaining arguments are implicit captures of the parallel operation.
-  SmallVector<Value> captures = getArguments(block->getNumArguments() - offset);
+  ArrayRef<Value> captures = getArguments(block->getNumArguments() - offset);
 
   // Find one-dimensional iteration bounds: [blockFirstIndex, blockLastIndex]:
   //   blockFirstIndex = blockIndex * blockSize
   Value blockFirstIndex = b.create<MulIOp>(blockIndex, blockSize);
 
   // The last one-dimensional index in the block defined by the `blockIndex`:
-  //   blockLastIndex = max((blockIndex + 1) * blockSize, tripCount) - 1
-  Value blockEnd0 = b.create<AddIOp>(blockIndex, c1);
-  Value blockEnd1 = b.create<MulIOp>(blockEnd0, blockSize);
-  Value blockEnd2 = b.create<CmpIOp>(CmpIPredicate::sge, blockEnd1, tripCount);
-  Value blockEnd3 = b.create<SelectOp>(blockEnd2, tripCount, blockEnd1);
-  Value blockLastIndex = b.create<SubIOp>(blockEnd3, c1);
+  //   blockLastIndex = max(blockFirstIndex + blockSize, tripCount) - 1
+  Value blockEnd0 = b.create<AddIOp>(blockFirstIndex, blockSize);
+  Value blockEnd1 = b.create<CmpIOp>(CmpIPredicate::sge, blockEnd0, tripCount);
+  Value blockEnd2 = b.create<SelectOp>(blockEnd1, tripCount, blockEnd0);
+  Value blockLastIndex = b.create<SubIOp>(blockEnd2, c1);
 
   // Convert one-dimensional indices to multi-dimensional coordinates.
   auto blockFirstCoord = delinearize(b, blockFirstIndex, tripCounts);
   auto blockLastCoord = delinearize(b, blockLastIndex, tripCounts);
 
-  // Compute compute loops upper bounds from the block last coordinates:
+  // Compute loops upper bounds derived from the block last coordinates:
   //   blockEndCoord[i] = blockLastCoord[i] + 1
   //
   // Block first and last coordinates can be the same along the outer compute
-  // dimension when inner compute dimension containts multple blocks.
+  // dimension when inner compute dimension contains multiple blocks.
   SmallVector<Value> blockEndCoord(op.getNumLoops());
   for (size_t i = 0; i < blockLastCoord.size(); ++i)
     blockEndCoord[i] = b.create<AddIOp>(blockLastCoord[i], c1);
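A plain C++ model of the block bound arithmetic above may help check the simplification: the select in the MLIR code clamps the block end to tripCount, which the std::min below mirrors. Names and the standalone form are illustrative, not from the patch.

#include <algorithm>
#include <cstdint>

// One-dimensional index range of one compute block:
// [first, last], clamped so that last never exceeds tripCount - 1.
struct BlockRange {
  int64_t first;
  int64_t last;
};

static BlockRange blockRange(int64_t blockIndex, int64_t blockSize,
                             int64_t tripCount) {
  int64_t first = blockIndex * blockSize;
  // Old form computed (blockIndex + 1) * blockSize; the new form reuses `first`.
  int64_t end = std::min(first + blockSize, tripCount);
  return {first, end - 1};
}
// Example: tripCount = 10, blockSize = 4 -> blocks cover [0,3], [4,7], [8,9].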
@@ -312,7 +319,7 @@ createParallelComputeFunction(scf::ParallelOp op, PatternRewriter &rewriter) {
     isBlockLastCoord[loopIdx] =
         nb.create<CmpIOp>(CmpIPredicate::eq, iv, blockLastCoord[loopIdx]);
 
-    // Check if the previous loop is in its first of last iteration.
+    // Check if the previous loop is in its first or last iteration.
     if (loopIdx > 0) {
       isBlockFirstCoord[loopIdx] = nb.create<AndOp>(
           isBlockFirstCoord[loopIdx], isBlockFirstCoord[loopIdx - 1]);
@@ -380,7 +387,6 @@ static FuncOp createAsyncDispatchFunction(ParallelComputeFunction &computeFunc,
   ImplicitLocOpBuilder b(loc, rewriter);
 
   ModuleOp module = computeFunc.func->getParentOfType<ModuleOp>();
-
   b.setInsertionPointToStart(&module->getRegion(0).front());
 
   ArrayRef<Type> computeFuncInputTypes =
       computeFunc.func.type().cast<FunctionType>().getInputs();
@@ -408,8 +414,8 @@ static FuncOp createAsyncDispatchFunction(ParallelComputeFunction &computeFunc,
   b.setInsertionPointToEnd(block);
 
   Type indexTy = b.getIndexType();
-  Value c1 = b.create<ConstantOp>(b.getIndexAttr(1));
-  Value c2 = b.create<ConstantOp>(b.getIndexAttr(2));
+  Value c1 = b.create<ConstantIndexOp>(1);
+  Value c2 = b.create<ConstantIndexOp>(2);
 
   // Get the async group that will track async dispatch completion.
   Value group = block->getArgument(0);
@@ -439,14 +445,14 @@ static FuncOp createAsyncDispatchFunction(ParallelComputeFunction &computeFunc,
   }
 
   // Setup the async dispatch loop body: recursively call dispatch function
-  // for second the half of the original range and go to the next iteration.
+  // for the seconds half of the original range and go to the next iteration.
   {
     b.setInsertionPointToEnd(after);
     Value start = after->getArgument(0);
     Value end = after->getArgument(1);
     Value distance = b.create<SubIOp>(end, start);
     Value halfDistance = b.create<SignedDivIOp>(distance, c2);
-    Value midIndex = b.create<AddIOp>(after->getArgument(0), halfDistance);
+    Value midIndex = b.create<AddIOp>(start, halfDistance);
 
     // Call parallel compute function inside the async.execute region.
     auto executeBodyBuilder = [&](OpBuilder &executeBuilder,
@@ -466,7 +472,7 @@ static FuncOp createAsyncDispatchFunction(ParallelComputeFunction &computeFunc,
     auto execute = b.create<ExecuteOp>(TypeRange(), ValueRange(), ValueRange(),
                                        executeBodyBuilder);
     b.create<AddToGroupOp>(indexTy, execute.token(), group);
-    b.create<scf::YieldOp>(ValueRange({after->getArgument(0), midIndex}));
+    b.create<scf::YieldOp>(ValueRange({start, midIndex}));
   }
 
   // After dispatching async operations to process the tail of the block range
@@ -498,8 +504,8 @@ static void doAsyncDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
   FuncOp asyncDispatchFunction =
       createAsyncDispatchFunction(parallelComputeFunction, rewriter);
 
-  Value c0 = b.create<ConstantOp>(b.getIndexAttr(0));
-  Value c1 = b.create<ConstantOp>(b.getIndexAttr(1));
+  Value c0 = b.create<ConstantIndexOp>(0);
+  Value c1 = b.create<ConstantIndexOp>(1);
 
   // Create an async.group to wait on all async tokens from the concurrent
   // execution of multiple parallel compute function. First block will be
@@ -535,8 +541,8 @@ doSequantialDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
 
   FuncOp compute = parallelComputeFunction.func;
 
-  Value c0 = b.create<ConstantOp>(b.getIndexAttr(0));
-  Value c1 = b.create<ConstantOp>(b.getIndexAttr(1));
+  Value c0 = b.create<ConstantIndexOp>(0);
+  Value c1 = b.create<ConstantIndexOp>(1);
 
   // Create an async.group to wait on all async tokens from the concurrent
   // execution of multiple parallel compute function. First block will be
@@ -617,19 +623,16 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
   for (size_t i = 1; i < tripCounts.size(); ++i)
     tripCount = b.create<MulIOp>(tripCount, tripCounts[i]);
 
-  auto indexTy = b.getIndexType();
-
   // Do not overload worker threads with too many compute blocks.
-  Value maxComputeBlocks = b.create<ConstantOp>(
-      indexTy, b.getIndexAttr(numWorkerThreads * kMaxOversharding));
+  Value maxComputeBlocks =
+      b.create<ConstantIndexOp>(numWorkerThreads * kMaxOversharding);
 
   // Target block size from the pass parameters.
-  Value targetComputeBlockSize =
-      b.create<ConstantOp>(indexTy, b.getIndexAttr(targetBlockSize));
+  Value targetComputeBlockSize = b.create<ConstantIndexOp>(targetBlockSize);
 
   // Compute parallel block size from the parallel problem size:
   //   blockSize = min(tripCount,
-  //                   max(divup(tripCount, maxComputeBlocks),
+  //                   max(ceil_div(tripCount, maxComputeBlocks),
   //                       targetComputeBlockSize))
   Value bs0 = b.create<SignedCeilDivIOp>(tripCount, maxComputeBlocks);
   Value bs1 = b.create<CmpIOp>(CmpIPredicate::sge, bs0, targetComputeBlockSize);
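The sizing rule in the comment above can be modeled directly. The sketch below uses assumed example numbers; the real value of kMaxOversharding is defined elsewhere in the file and is not shown here.

#include <algorithm>
#include <cstdint>

// Plain C++ model of the block size formula emitted above (illustrative only).
static int64_t computeBlockSize(int64_t tripCount, int64_t numWorkerThreads,
                                int64_t maxOversharding, int64_t targetBlockSize) {
  int64_t maxComputeBlocks = numWorkerThreads * maxOversharding;
  int64_t bs = (tripCount + maxComputeBlocks - 1) / maxComputeBlocks; // ceil_div
  return std::min(tripCount, std::max(bs, targetBlockSize));
}
// Example (assumed numbers): computeBlockSize(1000, 8, 4, 16) == 32,
// so the 1000 iterations are split into at most 32 compute blocks.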
@@ -653,7 +656,7 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
                          blockCount, tripCounts);
   }
 
-  // Parallel operation was replaces with a block iteration loop.
+  // Parallel operation was replaced with a block iteration loop.
   rewriter.eraseOp(op);
 
   return success();
@@ -673,3 +676,10 @@ void AsyncParallelForPass::runOnOperation() {
 std::unique_ptr<Pass> mlir::createAsyncParallelForPass() {
   return std::make_unique<AsyncParallelForPass>();
 }
+
+std::unique_ptr<Pass>
+mlir::createAsyncParallelForPass(bool asyncDispatch, int32_t numWorkerThreads,
+                                 int32_t targetBlockSize) {
+  return std::make_unique<AsyncParallelForPass>(asyncDispatch, numWorkerThreads,
+                                                targetBlockSize);
+}
@@ -18,18 +18,33 @@ func @loop_1d(%arg0: index, %arg1: index, %arg2: index, %arg3: memref<?xf32>) {
 // CHECK: memref.store
 
 // CHECK-LABEL: func private @async_dispatch_fn
 // CHECK-SAME: (
 // CHECK-SAME: %[[GROUP:arg0]]: !async.group,
 // CHECK-SAME: %[[BLOCK_START:arg1]]: index
 // CHECK-SAME: %[[BLOCK_END:arg2]]: index
-// CHECK: scf.while (%[[S:.*]] = %[[BLOCK_START]],
-// CHECK-SAME: %[[E:.*]] = %[[BLOCK_END]])
+// CHECK-SAME: )
+// CHECK: %[[C1:.*]] = constant 1 : index
+// CHECK: %[[C2:.*]] = constant 2 : index
+// CHECK: scf.while (%[[S0:.*]] = %[[BLOCK_START]],
+// CHECK-SAME: %[[E0:.*]] = %[[BLOCK_END]])
+// While loop `before` block decides if we need to dispatch more tasks.
+// CHECK: {
+// CHECK: %[[DIFF0:.*]] = subi %[[E0]], %[[S0]]
+// CHECK: %[[COND:.*]] = cmpi sgt, %[[DIFF0]], %[[C1]]
+// CHECK: scf.condition(%[[COND]])
+// While loop `after` block splits the range in half and submits async task
+// to process the second half using the call to the same dispatch function.
+// CHECK: } do {
+// CHECK: ^bb0(%[[S1:.*]]: index, %[[E1:.*]]: index):
+// CHECK: %[[DIFF1:.*]] = subi %[[E1]], %[[S1]]
+// CHECK: %[[HALF:.*]] = divi_signed %[[DIFF1]], %[[C2]]
+// CHECK: %[[MID:.*]] = addi %[[S1]], %[[HALF]]
+// CHECK: %[[TOKEN:.*]] = async.execute
+// CHECK: call @async_dispatch_fn
+// CHECK: async.add_to_group
+// CHECK: scf.yield %[[S1]], %[[MID]]
+// CHECK: }
 
 // After async dispatch the first block processed in the caller thread.
 // CHECK: call @parallel_compute_fn(%[[BLOCK_START]]
 
 // -----
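The FileCheck lines above describe a loop that repeatedly halves the block range and offloads the upper half to an async task. A minimal recursive model in plain C++ is sketched below; the spawn and processBlock callbacks are stand-ins for illustration, not the actual async runtime API.

#include <cstdint>
#include <functional>

// Recursively split [start, end): hand the second half to an async task that
// calls back into the same dispatch routine, keep the first half, and let the
// calling thread process the very first block once only one block remains.
static void asyncDispatch(int64_t start, int64_t end,
                          const std::function<void(int64_t)> &processBlock,
                          const std::function<void(std::function<void()>)> &spawn) {
  while (end - start > 1) {
    int64_t mid = start + (end - start) / 2;
    // Stand-in for async.execute + call @async_dispatch_fn on [mid, end).
    spawn([=] { asyncDispatch(mid, end, processBlock, spawn); });
    end = mid; // continue with the first half, like the scf.while loop above
  }
  processBlock(start); // caller thread processes its remaining single block
}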
@@ -1,6 +1,9 @@
 // RUN: mlir-opt %s -split-input-file -async-parallel-for=async-dispatch=false \
 // RUN:   | FileCheck %s --dump-input=always
 
+// The structure of @parallel_compute_fn checked in the async dispatch test.
+// Here we only check the structure of the sequential dispatch loop.
+
 // CHECK-LABEL: @loop_1d
 func @loop_1d(%arg0: index, %arg1: index, %arg2: index, %arg3: memref<?xf32>) {
   // CHECK: %[[GROUP:.*]] = async.create_group