[mlir][openacc] Add missing attributes and operands for acc.loop

This patch add the missing operands to the acc.loop operation. Only the device_type
information is not part of the operation for now.

Reviewed By: rriddle, kiranchandramohan

Differential Revision: https://reviews.llvm.org/D86753
This commit is contained in:
Valentin Clement 2020-08-31 19:49:44 -04:00 committed by clementval
parent f862d85807
commit 2bbbcae782
4 changed files with 174 additions and 34 deletions

View File

@ -31,12 +31,11 @@ namespace acc {
/// 2.9.2. gang /// 2.9.2. gang
/// 2.9.3. worker /// 2.9.3. worker
/// 2.9.4. vector /// 2.9.4. vector
/// 2.9.5. seq
/// ///
/// Value can be combined bitwise to reflect the mapping applied to the /// Value can be combined bitwise to reflect the mapping applied to the
/// construct. e.g. `acc.loop gang vector`, the `gang` and `vector` could be /// construct. e.g. `acc.loop gang vector`, the `gang` and `vector` could be
/// combined and the final mapping value would be 5 (4 | 1). /// combined and the final mapping value would be 5 (4 | 1).
enum OpenACCExecMapping { NONE = 0, VECTOR = 1, WORKER = 2, GANG = 4, SEQ = 8 }; enum OpenACCExecMapping { NONE = 0, VECTOR = 1, WORKER = 2, GANG = 4 };
} // end namespace acc } // end namespace acc
} // end namespace mlir } // end namespace mlir

View File

@ -224,6 +224,14 @@ def OpenACC_LoopOp : OpenACC_Op<"loop",
let arguments = (ins OptionalAttr<I64Attr>:$collapse, let arguments = (ins OptionalAttr<I64Attr>:$collapse,
Optional<AnyInteger>:$gangNum,
Optional<AnyInteger>:$gangStatic,
Optional<AnyInteger>:$workerNum,
Optional<AnyInteger>:$vectorLength,
UnitAttr:$loopSeq,
UnitAttr:$loopIndependent,
UnitAttr:$loopAuto,
Variadic<AnyInteger>:$tileOperands,
Variadic<AnyType>:$privateOperands, Variadic<AnyType>:$privateOperands,
OptionalAttr<OpenACC_ReductionOpAttr>:$reductionOp, OptionalAttr<OpenACC_ReductionOpAttr>:$reductionOp,
Variadic<AnyType>:$reductionOperands); Variadic<AnyType>:$reductionOperands);
@ -234,11 +242,16 @@ def OpenACC_LoopOp : OpenACC_Op<"loop",
let extraClassDeclaration = [{ let extraClassDeclaration = [{
static StringRef getCollapseAttrName() { return "collapse"; } static StringRef getCollapseAttrName() { return "collapse"; }
static StringRef getExecutionMappingAttrName() { return "exec_mapping"; }
static StringRef getGangAttrName() { return "gang"; }
static StringRef getSeqAttrName() { return "seq"; } static StringRef getSeqAttrName() { return "seq"; }
static StringRef getVectorAttrName() { return "vector"; } static StringRef getIndependentAttrName() { return "independent"; }
static StringRef getWorkerAttrName() { return "worker"; } static StringRef getAutoAttrName() { return "auto"; }
static StringRef getExecutionMappingAttrName() { return "exec_mapping"; }
static StringRef getGangKeyword() { return "gang"; }
static StringRef getGangNumKeyword() { return "num"; }
static StringRef getGangStaticKeyword() { return "static"; }
static StringRef getVectorKeyword() { return "vector"; }
static StringRef getWorkerKeyword() { return "worker"; }
static StringRef getTileKeyword() { return "tile"; }
static StringRef getPrivateKeyword() { return "private"; } static StringRef getPrivateKeyword() { return "private"; }
static StringRef getReductionKeyword() { return "reduction"; } static StringRef getReductionKeyword() { return "reduction"; }
}]; }];

View File

@ -476,32 +476,81 @@ static void print(OpAsmPrinter &printer, DataOp &op) {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
/// Parse acc.loop operation /// Parse acc.loop operation
/// operation := `acc.loop` `gang`? `vector`? `worker`? `seq`? /// operation := `acc.loop` `gang`? `vector`? `worker`?
/// `private` `(` value-list `)`? /// `private` `(` value-list `)`?
/// `reduction` `(` value-list `)`? /// `reduction` `(` value-list `)`?
/// region attr-dict? /// region attr-dict?
static ParseResult parseLoopOp(OpAsmParser &parser, OperationState &result) { static ParseResult parseLoopOp(OpAsmParser &parser, OperationState &result) {
Builder &builder = parser.getBuilder(); Builder &builder = parser.getBuilder();
unsigned executionMapping = 0; unsigned executionMapping = 0;
SmallVector<Type, 8> operandTypes; SmallVector<Type, 8> operandTypes;
SmallVector<OpAsmParser::OperandType, 8> privateOperands, reductionOperands; SmallVector<OpAsmParser::OperandType, 8> privateOperands, reductionOperands;
SmallVector<OpAsmParser::OperandType, 8> tileOperands;
bool hasWorkerNum = false, hasVectorLength = false, hasGangNum = false;
bool hasGangStatic = false;
OpAsmParser::OperandType workerNum, vectorLength, gangNum, gangStatic;
Type intType = builder.getI64Type();
// gang? // gang?
if (succeeded(parser.parseOptionalKeyword(LoopOp::getGangAttrName()))) if (succeeded(parser.parseOptionalKeyword(LoopOp::getGangKeyword())))
executionMapping |= OpenACCExecMapping::GANG; executionMapping |= OpenACCExecMapping::GANG;
// vector? // optional gang operand
if (succeeded(parser.parseOptionalKeyword(LoopOp::getVectorAttrName()))) if (succeeded(parser.parseOptionalLParen())) {
executionMapping |= OpenACCExecMapping::VECTOR; if (succeeded(parser.parseOptionalKeyword(LoopOp::getGangNumKeyword()))) {
hasGangNum = true;
parser.parseColon();
if (parser.parseOperand(gangNum) ||
parser.resolveOperand(gangNum, intType, result.operands)) {
return failure();
}
}
parser.parseOptionalComma();
if (succeeded(
parser.parseOptionalKeyword(LoopOp::getGangStaticKeyword()))) {
hasGangStatic = true;
parser.parseColon();
if (parser.parseOperand(gangStatic) ||
parser.resolveOperand(gangStatic, intType, result.operands)) {
return failure();
}
}
if (failed(parser.parseRParen()))
return failure();
}
// worker? // worker?
if (succeeded(parser.parseOptionalKeyword(LoopOp::getWorkerAttrName()))) if (succeeded(parser.parseOptionalKeyword(LoopOp::getWorkerKeyword())))
executionMapping |= OpenACCExecMapping::WORKER; executionMapping |= OpenACCExecMapping::WORKER;
// seq? // optional worker operand
if (succeeded(parser.parseOptionalKeyword(LoopOp::getSeqAttrName()))) if (succeeded(parser.parseOptionalLParen())) {
executionMapping |= OpenACCExecMapping::SEQ; hasWorkerNum = true;
if (parser.parseOperand(workerNum) ||
parser.resolveOperand(workerNum, intType, result.operands) ||
parser.parseRParen()) {
return failure();
}
}
// vector?
if (succeeded(parser.parseOptionalKeyword(LoopOp::getVectorKeyword())))
executionMapping |= OpenACCExecMapping::VECTOR;
// optional vector operand
if (succeeded(parser.parseOptionalLParen())) {
hasVectorLength = true;
if (parser.parseOperand(vectorLength) ||
parser.resolveOperand(vectorLength, intType, result.operands) ||
parser.parseRParen()) {
return failure();
}
}
// tile()?
if (failed(parseOperandList(parser, LoopOp::getTileKeyword(), tileOperands,
operandTypes, result)))
return failure();
// private()? // private()?
if (failed(parseOperandList(parser, LoopOp::getPrivateKeyword(), if (failed(parseOperandList(parser, LoopOp::getPrivateKeyword(),
@ -526,7 +575,12 @@ static ParseResult parseLoopOp(OpAsmParser &parser, OperationState &result) {
result.addAttribute(LoopOp::getOperandSegmentSizeAttr(), result.addAttribute(LoopOp::getOperandSegmentSizeAttr(),
builder.getI32VectorAttr( builder.getI32VectorAttr(
{static_cast<int32_t>(privateOperands.size()), {static_cast<int32_t>(hasGangNum ? 1 : 0),
static_cast<int32_t>(hasGangStatic ? 1 : 0),
static_cast<int32_t>(hasWorkerNum ? 1 : 0),
static_cast<int32_t>(hasVectorLength ? 1 : 0),
static_cast<int32_t>(tileOperands.size()),
static_cast<int32_t>(privateOperands.size()),
static_cast<int32_t>(reductionOperands.size())})); static_cast<int32_t>(reductionOperands.size())}));
if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
@ -544,17 +598,44 @@ static void print(OpAsmPrinter &printer, LoopOp &op) {
? op.getAttrOfType<IntegerAttr>(LoopOp::getExecutionMappingAttrName()) ? op.getAttrOfType<IntegerAttr>(LoopOp::getExecutionMappingAttrName())
.getInt() .getInt()
: 0; : 0;
if ((execMapping & OpenACCExecMapping::GANG) == OpenACCExecMapping::GANG)
printer << " " << LoopOp::getGangAttrName();
if ((execMapping & OpenACCExecMapping::WORKER) == OpenACCExecMapping::WORKER) if (execMapping & OpenACCExecMapping::GANG) {
printer << " " << LoopOp::getWorkerAttrName(); printer << " " << LoopOp::getGangKeyword();
Value gangNum = op.gangNum();
Value gangStatic = op.gangStatic();
if ((execMapping & OpenACCExecMapping::VECTOR) == OpenACCExecMapping::VECTOR) // Print optional gang operands
printer << " " << LoopOp::getVectorAttrName(); if (gangNum || gangStatic) {
printer << "(";
if (gangNum) {
printer << LoopOp::getGangNumKeyword() << ": " << gangNum;
if (gangStatic)
printer << ", ";
}
if (gangStatic)
printer << LoopOp::getGangStaticKeyword() << ": " << gangStatic;
printer << ")";
}
}
if ((execMapping & OpenACCExecMapping::SEQ) == OpenACCExecMapping::SEQ) if (execMapping & OpenACCExecMapping::WORKER) {
printer << " " << LoopOp::getSeqAttrName(); printer << " " << LoopOp::getWorkerKeyword();
// Print optional worker operand if present
if (Value workerNum = op.workerNum())
printer << "(" << workerNum << ")";
}
if (execMapping & OpenACCExecMapping::VECTOR) {
printer << " " << LoopOp::getVectorKeyword();
// Print optional vector operand if present
if (Value vectorLength = op.vectorLength())
printer << "(" << vectorLength << ")";
}
// tile()?
printOperandList(op.tileOperands(), LoopOp::getTileKeyword(), printer);
// private()? // private()?
printOperandList(op.privateOperands(), LoopOp::getPrivateKeyword(), printer); printOperandList(op.privateOperands(), LoopOp::getPrivateKeyword(), printer);

View File

@ -62,7 +62,7 @@ func @compute2(%A: memref<10x10xf32>, %B: memref<10x10xf32>, %C: memref<10x10xf3
%c1 = constant 1 : index %c1 = constant 1 : index
acc.parallel { acc.parallel {
acc.loop seq { acc.loop {
scf.for %arg3 = %c0 to %c10 step %c1 { scf.for %arg3 = %c0 to %c10 step %c1 {
scf.for %arg4 = %c0 to %c10 step %c1 { scf.for %arg4 = %c0 to %c10 step %c1 {
scf.for %arg5 = %c0 to %c10 step %c1 { scf.for %arg5 = %c0 to %c10 step %c1 {
@ -76,7 +76,7 @@ func @compute2(%A: memref<10x10xf32>, %B: memref<10x10xf32>, %C: memref<10x10xf3
} }
} }
acc.yield acc.yield
} } attributes {seq}
acc.yield acc.yield
} }
@ -88,7 +88,7 @@ func @compute2(%A: memref<10x10xf32>, %B: memref<10x10xf32>, %C: memref<10x10xf3
// CHECK-NEXT: %{{.*}} = constant 10 : index // CHECK-NEXT: %{{.*}} = constant 10 : index
// CHECK-NEXT: %{{.*}} = constant 1 : index // CHECK-NEXT: %{{.*}} = constant 1 : index
// CHECK-NEXT: acc.parallel { // CHECK-NEXT: acc.parallel {
// CHECK-NEXT: acc.loop seq { // CHECK-NEXT: acc.loop {
// CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { // CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { // CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { // CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
@ -102,7 +102,7 @@ func @compute2(%A: memref<10x10xf32>, %B: memref<10x10xf32>, %C: memref<10x10xf3
// CHECK-NEXT: } // CHECK-NEXT: }
// CHECK-NEXT: } // CHECK-NEXT: }
// CHECK-NEXT: acc.yield // CHECK-NEXT: acc.yield
// CHECK-NEXT: } // CHECK-NEXT: } attributes {seq}
// CHECK-NEXT: acc.yield // CHECK-NEXT: acc.yield
// CHECK-NEXT: } // CHECK-NEXT: }
// CHECK-NEXT: return %{{.*}} : memref<10x10xf32> // CHECK-NEXT: return %{{.*}} : memref<10x10xf32>
@ -128,7 +128,7 @@ func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10xf32>,
acc.yield acc.yield
} }
acc.loop seq { acc.loop {
// for i = 0 to 10 step 1 // for i = 0 to 10 step 1
// d[x] += c[i] // d[x] += c[i]
scf.for %i = %lb to %c10 step %st { scf.for %i = %lb to %c10 step %st {
@ -138,7 +138,7 @@ func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10xf32>,
store %z, %d[%x] : memref<10xf32> store %z, %d[%x] : memref<10xf32>
} }
acc.yield acc.yield
} } attributes {seq}
} }
acc.yield acc.yield
} }
@ -167,7 +167,7 @@ func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10xf32>,
// CHECK-NEXT: } // CHECK-NEXT: }
// CHECK-NEXT: acc.yield // CHECK-NEXT: acc.yield
// CHECK-NEXT: } // CHECK-NEXT: }
// CHECK-NEXT: acc.loop seq { // CHECK-NEXT: acc.loop {
// CHECK-NEXT: scf.for %{{.*}} = [[C0]] to [[C10]] step [[C1]] { // CHECK-NEXT: scf.for %{{.*}} = [[C0]] to [[C10]] step [[C1]] {
// CHECK-NEXT: %{{.*}} = load %{{.*}}[%{{.*}}] : memref<10xf32> // CHECK-NEXT: %{{.*}} = load %{{.*}}[%{{.*}}] : memref<10xf32>
// CHECK-NEXT: %{{.*}} = load %{{.*}}[%{{.*}}] : memref<10xf32> // CHECK-NEXT: %{{.*}} = load %{{.*}}[%{{.*}}] : memref<10xf32>
@ -175,7 +175,7 @@ func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10xf32>,
// CHECK-NEXT: store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> // CHECK-NEXT: store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
// CHECK-NEXT: } // CHECK-NEXT: }
// CHECK-NEXT: acc.yield // CHECK-NEXT: acc.yield
// CHECK-NEXT: } // CHECK-NEXT: } attributes {seq}
// CHECK-NEXT: } // CHECK-NEXT: }
// CHECK-NEXT: acc.yield // CHECK-NEXT: acc.yield
// CHECK-NEXT: } // CHECK-NEXT: }
@ -184,4 +184,51 @@ func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10xf32>,
// CHECK-NEXT: acc.terminator // CHECK-NEXT: acc.terminator
// CHECK-NEXT: } // CHECK-NEXT: }
// CHECK-NEXT: return %{{.*}} : memref<10xf32> // CHECK-NEXT: return %{{.*}} : memref<10xf32>
// CHECK-NEXT: } // CHECK-NEXT: }
func @testop() -> () {
%workerNum = constant 1 : i64
%vectorLength = constant 128 : i64
%gangNum = constant 8 : i64
%gangStatic = constant 2 : i64
%tileSize = constant 2 : i64
acc.loop gang worker vector {
}
acc.loop gang(num: %gangNum) {
}
acc.loop gang(static: %gangStatic) {
}
acc.loop worker(%workerNum) {
}
acc.loop vector(%vectorLength) {
}
acc.loop gang(num: %gangNum) worker vector {
}
acc.loop gang(num: %gangNum, static: %gangStatic) worker(%workerNum) vector(%vectorLength) {
}
acc.loop tile(%tileSize : i64, %tileSize : i64) {
}
return
}
// CHECK: [[WORKERNUM:%.*]] = constant 1 : i64
// CHECK-NEXT: [[VECTORLENGTH:%.*]] = constant 128 : i64
// CHECK-NEXT: [[GANGNUM:%.*]] = constant 8 : i64
// CHECK-NEXT: [[GANGSTATIC:%.*]] = constant 2 : i64
// CHECK-NEXT: [[TILESIZE:%.*]] = constant 2 : i64
// CHECK-NEXT: acc.loop gang worker vector {
// CHECK-NEXT: }
// CHECK-NEXT: acc.loop gang(num: [[GANGNUM]]) {
// CHECK-NEXT: }
// CHECK-NEXT: acc.loop gang(static: [[GANGSTATIC]]) {
// CHECK-NEXT: }
// CHECK-NEXT: acc.loop worker([[WORKERNUM]]) {
// CHECK-NEXT: }
// CHECK-NEXT: acc.loop vector([[VECTORLENGTH]]) {
// CHECK-NEXT: }
// CHECK-NEXT: acc.loop gang(num: [[GANGNUM]]) worker vector {
// CHECK-NEXT: }
// CHECK-NEXT: acc.loop gang(num: [[GANGNUM]], static: [[GANGSTATIC]]) worker([[WORKERNUM]]) vector([[VECTORLENGTH]]) {
// CHECK-NEXT: }
// CHECK-NEXT: acc.loop tile([[TILESIZE]]: i64, [[TILESIZE]]: i64) {
// CHECK-NEXT: }