forked from OSchip/llvm-project
Standardize `linalg.generic` on `args_in`/`args_out` instead of `inputCount`/`outputCount`
This also fixes the outdated use of `n_views` in the documentation. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D83795
This commit is contained in:
parent
911fcf382f
commit
941fecc536
|
@ -562,7 +562,8 @@ def GenericOp : GenericOpBase<"generic"> {
|
|||
doc = "C(m, n) += A(m, k) * B(k, n)",
|
||||
indexing_maps = #matmul_accesses,
|
||||
library_call = "linalg_matmul",
|
||||
n_views = [2, 1],
|
||||
args_in = 2,
|
||||
args_out = 1,
|
||||
iterator_types = ["parallel", "parallel", "reduction"]
|
||||
}
|
||||
```
|
||||
|
@ -634,7 +635,7 @@ def GenericOp : GenericOpBase<"generic"> {
|
|||
let builders = [
|
||||
OpBuilder<
|
||||
"OpBuilder &builder, OperationState &result, ArrayRef<Type> resultTypes, "
|
||||
"ValueRange args, int64_t inputCount, int64_t outputCount, "
|
||||
"ValueRange args, int64_t argsIn, int64_t argsOut, "
|
||||
"ArrayRef<AffineMap> indexingMaps, ArrayRef<StringRef> iteratorTypes, "
|
||||
"function_ref<void(OpBuilder &, Location, ValueRange)> = nullptr">
|
||||
];
|
||||
|
@ -689,7 +690,8 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
|
|||
doc = "C(m, n) += A(m, k) * B(k, n)",
|
||||
indexing_maps = #matmul_accesses,
|
||||
library_call = "linalg_matmul",
|
||||
n_views = [2, 1],
|
||||
args_in = 2,
|
||||
args_out = 1,
|
||||
iterator_types = ["parallel", "parallel", "reduction"]
|
||||
}
|
||||
```
|
||||
|
@ -768,7 +770,7 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
|
|||
let builders = [
|
||||
OpBuilder<
|
||||
"OpBuilder &builder, OperationState &result, ArrayRef<Type> resultTypes, "
|
||||
"ValueRange args, int64_t inputCount, int64_t outputCount, "
|
||||
"ValueRange args, int64_t argsIn, int64_t argsOut, "
|
||||
"ArrayRef<AffineMap> indexingMaps, ArrayRef<StringRef> iteratorTypes, "
|
||||
"function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)> "
|
||||
"= nullptr">
|
||||
|
|
|
@ -72,12 +72,11 @@ static LogicalResult foldMemRefCast(Operation *op) {
|
|||
|
||||
void GenericOp::build(
|
||||
OpBuilder &builder, OperationState &result, ArrayRef<Type> resultTypes,
|
||||
ValueRange args, int64_t inputCount, int64_t outputCount,
|
||||
ValueRange args, int64_t argsIn, int64_t argsOut,
|
||||
ArrayRef<AffineMap> indexingMaps, ArrayRef<StringRef> iteratorTypes,
|
||||
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
|
||||
build(builder, result, resultTypes, args,
|
||||
builder.getI64IntegerAttr(inputCount),
|
||||
builder.getI64IntegerAttr(outputCount),
|
||||
build(builder, result, resultTypes, args, builder.getI64IntegerAttr(argsIn),
|
||||
builder.getI64IntegerAttr(argsOut),
|
||||
builder.getAffineMapArrayAttr(indexingMaps),
|
||||
builder.getStrArrayAttr(iteratorTypes),
|
||||
/*doc=*/nullptr, /*library_call=*/nullptr);
|
||||
|
@ -96,13 +95,12 @@ void GenericOp::build(
|
|||
|
||||
void IndexedGenericOp::build(
|
||||
OpBuilder &builder, OperationState &result, ArrayRef<Type> resultTypes,
|
||||
ValueRange args, int64_t inputCount, int64_t outputCount,
|
||||
ValueRange args, int64_t argsIn, int64_t argsOut,
|
||||
ArrayRef<AffineMap> indexingMaps, ArrayRef<StringRef> iteratorTypes,
|
||||
function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
|
||||
bodyBuild) {
|
||||
build(builder, result, resultTypes, args,
|
||||
builder.getI64IntegerAttr(inputCount),
|
||||
builder.getI64IntegerAttr(outputCount),
|
||||
build(builder, result, resultTypes, args, builder.getI64IntegerAttr(argsIn),
|
||||
builder.getI64IntegerAttr(argsOut),
|
||||
builder.getAffineMapArrayAttr(indexingMaps),
|
||||
builder.getStrArrayAttr(iteratorTypes),
|
||||
/*doc=*/nullptr, /*library_call=*/nullptr);
|
||||
|
|
Loading…
Reference in New Issue