forked from OSchip/llvm-project
[mlir][openacc] Add missing attributes and operands for acc.loop
This patch add the missing operands to the acc.loop operation. Only the device_type information is not part of the operation for now. Reviewed By: rriddle, kiranchandramohan Differential Revision: https://reviews.llvm.org/D86753
This commit is contained in:
parent
f862d85807
commit
2bbbcae782
|
@ -31,12 +31,11 @@ namespace acc {
|
|||
/// 2.9.2. gang
|
||||
/// 2.9.3. worker
|
||||
/// 2.9.4. vector
|
||||
/// 2.9.5. seq
|
||||
///
|
||||
/// Value can be combined bitwise to reflect the mapping applied to the
|
||||
/// construct. e.g. `acc.loop gang vector`, the `gang` and `vector` could be
|
||||
/// combined and the final mapping value would be 5 (4 | 1).
|
||||
enum OpenACCExecMapping { NONE = 0, VECTOR = 1, WORKER = 2, GANG = 4, SEQ = 8 };
|
||||
enum OpenACCExecMapping { NONE = 0, VECTOR = 1, WORKER = 2, GANG = 4 };
|
||||
|
||||
} // end namespace acc
|
||||
} // end namespace mlir
|
||||
|
|
|
@ -224,6 +224,14 @@ def OpenACC_LoopOp : OpenACC_Op<"loop",
|
|||
|
||||
|
||||
let arguments = (ins OptionalAttr<I64Attr>:$collapse,
|
||||
Optional<AnyInteger>:$gangNum,
|
||||
Optional<AnyInteger>:$gangStatic,
|
||||
Optional<AnyInteger>:$workerNum,
|
||||
Optional<AnyInteger>:$vectorLength,
|
||||
UnitAttr:$loopSeq,
|
||||
UnitAttr:$loopIndependent,
|
||||
UnitAttr:$loopAuto,
|
||||
Variadic<AnyInteger>:$tileOperands,
|
||||
Variadic<AnyType>:$privateOperands,
|
||||
OptionalAttr<OpenACC_ReductionOpAttr>:$reductionOp,
|
||||
Variadic<AnyType>:$reductionOperands);
|
||||
|
@ -234,11 +242,16 @@ def OpenACC_LoopOp : OpenACC_Op<"loop",
|
|||
|
||||
let extraClassDeclaration = [{
|
||||
static StringRef getCollapseAttrName() { return "collapse"; }
|
||||
static StringRef getExecutionMappingAttrName() { return "exec_mapping"; }
|
||||
static StringRef getGangAttrName() { return "gang"; }
|
||||
static StringRef getSeqAttrName() { return "seq"; }
|
||||
static StringRef getVectorAttrName() { return "vector"; }
|
||||
static StringRef getWorkerAttrName() { return "worker"; }
|
||||
static StringRef getIndependentAttrName() { return "independent"; }
|
||||
static StringRef getAutoAttrName() { return "auto"; }
|
||||
static StringRef getExecutionMappingAttrName() { return "exec_mapping"; }
|
||||
static StringRef getGangKeyword() { return "gang"; }
|
||||
static StringRef getGangNumKeyword() { return "num"; }
|
||||
static StringRef getGangStaticKeyword() { return "static"; }
|
||||
static StringRef getVectorKeyword() { return "vector"; }
|
||||
static StringRef getWorkerKeyword() { return "worker"; }
|
||||
static StringRef getTileKeyword() { return "tile"; }
|
||||
static StringRef getPrivateKeyword() { return "private"; }
|
||||
static StringRef getReductionKeyword() { return "reduction"; }
|
||||
}];
|
||||
|
|
|
@ -476,32 +476,81 @@ static void print(OpAsmPrinter &printer, DataOp &op) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Parse acc.loop operation
|
||||
/// operation := `acc.loop` `gang`? `vector`? `worker`? `seq`?
|
||||
/// operation := `acc.loop` `gang`? `vector`? `worker`?
|
||||
/// `private` `(` value-list `)`?
|
||||
/// `reduction` `(` value-list `)`?
|
||||
/// region attr-dict?
|
||||
static ParseResult parseLoopOp(OpAsmParser &parser, OperationState &result) {
|
||||
|
||||
Builder &builder = parser.getBuilder();
|
||||
unsigned executionMapping = 0;
|
||||
SmallVector<Type, 8> operandTypes;
|
||||
SmallVector<OpAsmParser::OperandType, 8> privateOperands, reductionOperands;
|
||||
SmallVector<OpAsmParser::OperandType, 8> tileOperands;
|
||||
bool hasWorkerNum = false, hasVectorLength = false, hasGangNum = false;
|
||||
bool hasGangStatic = false;
|
||||
OpAsmParser::OperandType workerNum, vectorLength, gangNum, gangStatic;
|
||||
Type intType = builder.getI64Type();
|
||||
|
||||
// gang?
|
||||
if (succeeded(parser.parseOptionalKeyword(LoopOp::getGangAttrName())))
|
||||
if (succeeded(parser.parseOptionalKeyword(LoopOp::getGangKeyword())))
|
||||
executionMapping |= OpenACCExecMapping::GANG;
|
||||
|
||||
// vector?
|
||||
if (succeeded(parser.parseOptionalKeyword(LoopOp::getVectorAttrName())))
|
||||
executionMapping |= OpenACCExecMapping::VECTOR;
|
||||
// optional gang operand
|
||||
if (succeeded(parser.parseOptionalLParen())) {
|
||||
if (succeeded(parser.parseOptionalKeyword(LoopOp::getGangNumKeyword()))) {
|
||||
hasGangNum = true;
|
||||
parser.parseColon();
|
||||
if (parser.parseOperand(gangNum) ||
|
||||
parser.resolveOperand(gangNum, intType, result.operands)) {
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
parser.parseOptionalComma();
|
||||
if (succeeded(
|
||||
parser.parseOptionalKeyword(LoopOp::getGangStaticKeyword()))) {
|
||||
hasGangStatic = true;
|
||||
parser.parseColon();
|
||||
if (parser.parseOperand(gangStatic) ||
|
||||
parser.resolveOperand(gangStatic, intType, result.operands)) {
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
if (failed(parser.parseRParen()))
|
||||
return failure();
|
||||
}
|
||||
|
||||
// worker?
|
||||
if (succeeded(parser.parseOptionalKeyword(LoopOp::getWorkerAttrName())))
|
||||
if (succeeded(parser.parseOptionalKeyword(LoopOp::getWorkerKeyword())))
|
||||
executionMapping |= OpenACCExecMapping::WORKER;
|
||||
|
||||
// seq?
|
||||
if (succeeded(parser.parseOptionalKeyword(LoopOp::getSeqAttrName())))
|
||||
executionMapping |= OpenACCExecMapping::SEQ;
|
||||
// optional worker operand
|
||||
if (succeeded(parser.parseOptionalLParen())) {
|
||||
hasWorkerNum = true;
|
||||
if (parser.parseOperand(workerNum) ||
|
||||
parser.resolveOperand(workerNum, intType, result.operands) ||
|
||||
parser.parseRParen()) {
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
|
||||
// vector?
|
||||
if (succeeded(parser.parseOptionalKeyword(LoopOp::getVectorKeyword())))
|
||||
executionMapping |= OpenACCExecMapping::VECTOR;
|
||||
|
||||
// optional vector operand
|
||||
if (succeeded(parser.parseOptionalLParen())) {
|
||||
hasVectorLength = true;
|
||||
if (parser.parseOperand(vectorLength) ||
|
||||
parser.resolveOperand(vectorLength, intType, result.operands) ||
|
||||
parser.parseRParen()) {
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
|
||||
// tile()?
|
||||
if (failed(parseOperandList(parser, LoopOp::getTileKeyword(), tileOperands,
|
||||
operandTypes, result)))
|
||||
return failure();
|
||||
|
||||
// private()?
|
||||
if (failed(parseOperandList(parser, LoopOp::getPrivateKeyword(),
|
||||
|
@ -526,7 +575,12 @@ static ParseResult parseLoopOp(OpAsmParser &parser, OperationState &result) {
|
|||
|
||||
result.addAttribute(LoopOp::getOperandSegmentSizeAttr(),
|
||||
builder.getI32VectorAttr(
|
||||
{static_cast<int32_t>(privateOperands.size()),
|
||||
{static_cast<int32_t>(hasGangNum ? 1 : 0),
|
||||
static_cast<int32_t>(hasGangStatic ? 1 : 0),
|
||||
static_cast<int32_t>(hasWorkerNum ? 1 : 0),
|
||||
static_cast<int32_t>(hasVectorLength ? 1 : 0),
|
||||
static_cast<int32_t>(tileOperands.size()),
|
||||
static_cast<int32_t>(privateOperands.size()),
|
||||
static_cast<int32_t>(reductionOperands.size())}));
|
||||
|
||||
if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
|
||||
|
@ -544,17 +598,44 @@ static void print(OpAsmPrinter &printer, LoopOp &op) {
|
|||
? op.getAttrOfType<IntegerAttr>(LoopOp::getExecutionMappingAttrName())
|
||||
.getInt()
|
||||
: 0;
|
||||
if ((execMapping & OpenACCExecMapping::GANG) == OpenACCExecMapping::GANG)
|
||||
printer << " " << LoopOp::getGangAttrName();
|
||||
|
||||
if ((execMapping & OpenACCExecMapping::WORKER) == OpenACCExecMapping::WORKER)
|
||||
printer << " " << LoopOp::getWorkerAttrName();
|
||||
if (execMapping & OpenACCExecMapping::GANG) {
|
||||
printer << " " << LoopOp::getGangKeyword();
|
||||
Value gangNum = op.gangNum();
|
||||
Value gangStatic = op.gangStatic();
|
||||
|
||||
if ((execMapping & OpenACCExecMapping::VECTOR) == OpenACCExecMapping::VECTOR)
|
||||
printer << " " << LoopOp::getVectorAttrName();
|
||||
// Print optional gang operands
|
||||
if (gangNum || gangStatic) {
|
||||
printer << "(";
|
||||
if (gangNum) {
|
||||
printer << LoopOp::getGangNumKeyword() << ": " << gangNum;
|
||||
if (gangStatic)
|
||||
printer << ", ";
|
||||
}
|
||||
if (gangStatic)
|
||||
printer << LoopOp::getGangStaticKeyword() << ": " << gangStatic;
|
||||
printer << ")";
|
||||
}
|
||||
}
|
||||
|
||||
if ((execMapping & OpenACCExecMapping::SEQ) == OpenACCExecMapping::SEQ)
|
||||
printer << " " << LoopOp::getSeqAttrName();
|
||||
if (execMapping & OpenACCExecMapping::WORKER) {
|
||||
printer << " " << LoopOp::getWorkerKeyword();
|
||||
|
||||
// Print optional worker operand if present
|
||||
if (Value workerNum = op.workerNum())
|
||||
printer << "(" << workerNum << ")";
|
||||
}
|
||||
|
||||
if (execMapping & OpenACCExecMapping::VECTOR) {
|
||||
printer << " " << LoopOp::getVectorKeyword();
|
||||
|
||||
// Print optional vector operand if present
|
||||
if (Value vectorLength = op.vectorLength())
|
||||
printer << "(" << vectorLength << ")";
|
||||
}
|
||||
|
||||
// tile()?
|
||||
printOperandList(op.tileOperands(), LoopOp::getTileKeyword(), printer);
|
||||
|
||||
// private()?
|
||||
printOperandList(op.privateOperands(), LoopOp::getPrivateKeyword(), printer);
|
||||
|
|
|
@ -62,7 +62,7 @@ func @compute2(%A: memref<10x10xf32>, %B: memref<10x10xf32>, %C: memref<10x10xf3
|
|||
%c1 = constant 1 : index
|
||||
|
||||
acc.parallel {
|
||||
acc.loop seq {
|
||||
acc.loop {
|
||||
scf.for %arg3 = %c0 to %c10 step %c1 {
|
||||
scf.for %arg4 = %c0 to %c10 step %c1 {
|
||||
scf.for %arg5 = %c0 to %c10 step %c1 {
|
||||
|
@ -76,7 +76,7 @@ func @compute2(%A: memref<10x10xf32>, %B: memref<10x10xf32>, %C: memref<10x10xf3
|
|||
}
|
||||
}
|
||||
acc.yield
|
||||
}
|
||||
} attributes {seq}
|
||||
acc.yield
|
||||
}
|
||||
|
||||
|
@ -88,7 +88,7 @@ func @compute2(%A: memref<10x10xf32>, %B: memref<10x10xf32>, %C: memref<10x10xf3
|
|||
// CHECK-NEXT: %{{.*}} = constant 10 : index
|
||||
// CHECK-NEXT: %{{.*}} = constant 1 : index
|
||||
// CHECK-NEXT: acc.parallel {
|
||||
// CHECK-NEXT: acc.loop seq {
|
||||
// CHECK-NEXT: acc.loop {
|
||||
// CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
|
||||
// CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
|
||||
// CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
|
||||
|
@ -102,7 +102,7 @@ func @compute2(%A: memref<10x10xf32>, %B: memref<10x10xf32>, %C: memref<10x10xf3
|
|||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: acc.yield
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: } attributes {seq}
|
||||
// CHECK-NEXT: acc.yield
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return %{{.*}} : memref<10x10xf32>
|
||||
|
@ -128,7 +128,7 @@ func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10xf32>,
|
|||
acc.yield
|
||||
}
|
||||
|
||||
acc.loop seq {
|
||||
acc.loop {
|
||||
// for i = 0 to 10 step 1
|
||||
// d[x] += c[i]
|
||||
scf.for %i = %lb to %c10 step %st {
|
||||
|
@ -138,7 +138,7 @@ func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10xf32>,
|
|||
store %z, %d[%x] : memref<10xf32>
|
||||
}
|
||||
acc.yield
|
||||
}
|
||||
} attributes {seq}
|
||||
}
|
||||
acc.yield
|
||||
}
|
||||
|
@ -167,7 +167,7 @@ func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10xf32>,
|
|||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: acc.yield
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: acc.loop seq {
|
||||
// CHECK-NEXT: acc.loop {
|
||||
// CHECK-NEXT: scf.for %{{.*}} = [[C0]] to [[C10]] step [[C1]] {
|
||||
// CHECK-NEXT: %{{.*}} = load %{{.*}}[%{{.*}}] : memref<10xf32>
|
||||
// CHECK-NEXT: %{{.*}} = load %{{.*}}[%{{.*}}] : memref<10xf32>
|
||||
|
@ -175,7 +175,7 @@ func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10xf32>,
|
|||
// CHECK-NEXT: store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: acc.yield
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: } attributes {seq}
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: acc.yield
|
||||
// CHECK-NEXT: }
|
||||
|
@ -184,4 +184,51 @@ func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10xf32>,
|
|||
// CHECK-NEXT: acc.terminator
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return %{{.*}} : memref<10xf32>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
|
||||
func @testop() -> () {
|
||||
%workerNum = constant 1 : i64
|
||||
%vectorLength = constant 128 : i64
|
||||
%gangNum = constant 8 : i64
|
||||
%gangStatic = constant 2 : i64
|
||||
%tileSize = constant 2 : i64
|
||||
acc.loop gang worker vector {
|
||||
}
|
||||
acc.loop gang(num: %gangNum) {
|
||||
}
|
||||
acc.loop gang(static: %gangStatic) {
|
||||
}
|
||||
acc.loop worker(%workerNum) {
|
||||
}
|
||||
acc.loop vector(%vectorLength) {
|
||||
}
|
||||
acc.loop gang(num: %gangNum) worker vector {
|
||||
}
|
||||
acc.loop gang(num: %gangNum, static: %gangStatic) worker(%workerNum) vector(%vectorLength) {
|
||||
}
|
||||
acc.loop tile(%tileSize : i64, %tileSize : i64) {
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: [[WORKERNUM:%.*]] = constant 1 : i64
|
||||
// CHECK-NEXT: [[VECTORLENGTH:%.*]] = constant 128 : i64
|
||||
// CHECK-NEXT: [[GANGNUM:%.*]] = constant 8 : i64
|
||||
// CHECK-NEXT: [[GANGSTATIC:%.*]] = constant 2 : i64
|
||||
// CHECK-NEXT: [[TILESIZE:%.*]] = constant 2 : i64
|
||||
// CHECK-NEXT: acc.loop gang worker vector {
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: acc.loop gang(num: [[GANGNUM]]) {
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: acc.loop gang(static: [[GANGSTATIC]]) {
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: acc.loop worker([[WORKERNUM]]) {
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: acc.loop vector([[VECTORLENGTH]]) {
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: acc.loop gang(num: [[GANGNUM]]) worker vector {
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: acc.loop gang(num: [[GANGNUM]], static: [[GANGSTATIC]]) worker([[WORKERNUM]]) vector([[VECTORLENGTH]]) {
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: acc.loop tile([[TILESIZE]]: i64, [[TILESIZE]]: i64) {
|
||||
// CHECK-NEXT: }
|
||||
|
|
Loading…
Reference in New Issue