forked from OSchip/llvm-project
[mlir][linalg] Adapt yaml codegen to support scalar parameters.
The patch updates the C++ yaml code generation to support scalar operands as added in https://reviews.llvm.org/D104220. Differential Revision: https://reviews.llvm.org/D104224
This commit is contained in:
parent
073e7a08e8
commit
ff2ef4d684
|
@ -19,7 +19,7 @@ package, if available, to avoid building.
|
|||
|
||||
```shell
|
||||
# Dump the `core_named_ops.py` module as YAML.
|
||||
python -m python -m mlir.tools.linalg_opdsl.dump_oplib .ops.core_named_ops
|
||||
python -m mlir.dialects.linalg.opdsl.dump_oplib .ops.core_named_ops
|
||||
```
|
||||
|
||||
The tool is meant for use during both development and runtime, but not as
|
||||
|
|
|
@ -11,21 +11,21 @@ metadata: !LinalgOpMetadata
|
|||
- LinalgContractionOpInterface
|
||||
structured_op: !LinalgStructuredOpConfig
|
||||
args:
|
||||
- !<LinalgTensorDef>
|
||||
- !LinalgOperandDefConfig
|
||||
name: A
|
||||
usage: input
|
||||
shape: affine_map<()[s0, s1, s2] -> (s0, s2)>
|
||||
element_type_var: T1
|
||||
- !<LinalgTensorDef>
|
||||
type_var: T1
|
||||
- !LinalgOperandDefConfig
|
||||
name: B
|
||||
usage: input
|
||||
shape: affine_map<()[s0, s1, s2] -> (s2, s1)>
|
||||
element_type_var: T2
|
||||
- !<LinalgTensorDef>
|
||||
type_var: T2
|
||||
- !LinalgOperandDefConfig
|
||||
name: C
|
||||
usage: output
|
||||
shape: affine_map<()[s0, s1, s2] -> (s0, s1)>
|
||||
element_type_var: U
|
||||
type_var: U
|
||||
indexing_maps: !LinalgIndexingMapsConfig
|
||||
static_indexing_maps:
|
||||
- affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
|
||||
|
@ -73,21 +73,21 @@ metadata: !LinalgOpMetadata
|
|||
- LinalgContractionOpInterface
|
||||
structured_op: !LinalgStructuredOpConfig
|
||||
args:
|
||||
- !<LinalgTensorDef>
|
||||
- !LinalgOperandDefConfig
|
||||
name: A
|
||||
usage: input
|
||||
shape: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
|
||||
element_type_var: T1
|
||||
- !<LinalgTensorDef>
|
||||
type_var: T1
|
||||
- !LinalgOperandDefConfig
|
||||
name: B
|
||||
usage: input
|
||||
shape: affine_map<()[s0, s1, s2, s3] -> (s0, s3, s2)>
|
||||
element_type_var: T2
|
||||
- !<LinalgTensorDef>
|
||||
type_var: T2
|
||||
- !LinalgOperandDefConfig
|
||||
name: C
|
||||
usage: output
|
||||
shape: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
|
||||
element_type_var: U
|
||||
type_var: U
|
||||
indexing_maps: !LinalgIndexingMapsConfig
|
||||
static_indexing_maps:
|
||||
- affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)>
|
||||
|
@ -136,21 +136,21 @@ metadata: !LinalgOpMetadata
|
|||
- LinalgContractionOpInterface
|
||||
structured_op: !LinalgStructuredOpConfig
|
||||
args:
|
||||
- !<LinalgTensorDef>
|
||||
- !LinalgOperandDefConfig
|
||||
name: A
|
||||
usage: input
|
||||
shape: affine_map<()[s0, s1] -> (s0, s1)>
|
||||
element_type_var: T1
|
||||
- !<LinalgTensorDef>
|
||||
type_var: T1
|
||||
- !LinalgOperandDefConfig
|
||||
name: y
|
||||
usage: input
|
||||
shape: affine_map<()[s0, s1] -> (s1)>
|
||||
element_type_var: T2
|
||||
- !<LinalgTensorDef>
|
||||
type_var: T2
|
||||
- !LinalgOperandDefConfig
|
||||
name: x
|
||||
usage: output
|
||||
shape: affine_map<()[s0, s1] -> (s0)>
|
||||
element_type_var: U
|
||||
type_var: U
|
||||
indexing_maps: !LinalgIndexingMapsConfig
|
||||
static_indexing_maps:
|
||||
- affine_map<(d0, d1)[s0, s1] -> (d0, d1)>
|
||||
|
@ -197,21 +197,21 @@ metadata: !LinalgOpMetadata
|
|||
- LinalgContractionOpInterface
|
||||
structured_op: !LinalgStructuredOpConfig
|
||||
args:
|
||||
- !<LinalgTensorDef>
|
||||
- !LinalgOperandDefConfig
|
||||
name: y
|
||||
usage: input
|
||||
shape: affine_map<()[s0, s1] -> (s1)>
|
||||
element_type_var: T1
|
||||
- !<LinalgTensorDef>
|
||||
type_var: T1
|
||||
- !LinalgOperandDefConfig
|
||||
name: A
|
||||
usage: input
|
||||
shape: affine_map<()[s0, s1] -> (s1, s0)>
|
||||
element_type_var: T2
|
||||
- !<LinalgTensorDef>
|
||||
type_var: T2
|
||||
- !LinalgOperandDefConfig
|
||||
name: x
|
||||
usage: output
|
||||
shape: affine_map<()[s0, s1] -> (s0)>
|
||||
element_type_var: U
|
||||
type_var: U
|
||||
indexing_maps: !LinalgIndexingMapsConfig
|
||||
static_indexing_maps:
|
||||
- affine_map<(d0, d1)[s0, s1] -> (d1)>
|
||||
|
@ -258,21 +258,21 @@ metadata: !LinalgOpMetadata
|
|||
- LinalgContractionOpInterface
|
||||
structured_op: !LinalgStructuredOpConfig
|
||||
args:
|
||||
- !<LinalgTensorDef>
|
||||
- !LinalgOperandDefConfig
|
||||
name: A
|
||||
usage: input
|
||||
shape: affine_map<()[s0] -> (s0)>
|
||||
element_type_var: T1
|
||||
- !<LinalgTensorDef>
|
||||
type_var: T1
|
||||
- !LinalgOperandDefConfig
|
||||
name: B
|
||||
usage: input
|
||||
shape: affine_map<()[s0] -> (s0)>
|
||||
element_type_var: T2
|
||||
- !<LinalgTensorDef>
|
||||
type_var: T2
|
||||
- !LinalgOperandDefConfig
|
||||
name: C
|
||||
usage: output
|
||||
shape: affine_map<()[s0] -> ()>
|
||||
element_type_var: U
|
||||
type_var: U
|
||||
indexing_maps: !LinalgIndexingMapsConfig
|
||||
static_indexing_maps:
|
||||
- affine_map<(d0)[s0] -> (d0)>
|
||||
|
@ -319,18 +319,30 @@ metadata: !LinalgOpMetadata
|
|||
and runs them in parallel. The seed operand and the indices of the data
|
||||
element seed the random number generation. The min and max operands limit
|
||||
the range of the generated random numbers.
|
||||
|
||||
Note: The captures are hard-coded till there is capture support on the C++
|
||||
side.
|
||||
structured_op: !LinalgStructuredOpConfig
|
||||
args:
|
||||
- !<LinalgTensorDef>
|
||||
- !LinalgOperandDefConfig
|
||||
name: min
|
||||
usage: input
|
||||
type_var: F64
|
||||
- !LinalgOperandDefConfig
|
||||
name: max
|
||||
usage: input
|
||||
type_var: F64
|
||||
- !LinalgOperandDefConfig
|
||||
name: seed
|
||||
usage: input
|
||||
type_var: I32
|
||||
- !LinalgOperandDefConfig
|
||||
name: O
|
||||
usage: output
|
||||
shape: affine_map<()[s0, s1] -> (s0, s1)>
|
||||
element_type_var: T
|
||||
type_var: T
|
||||
indexing_maps: !LinalgIndexingMapsConfig
|
||||
static_indexing_maps:
|
||||
- affine_map<(d0, d1)[s0, s1] -> ()>
|
||||
- affine_map<(d0, d1)[s0, s1] -> ()>
|
||||
- affine_map<(d0, d1)[s0, s1] -> ()>
|
||||
- affine_map<(d0, d1)[s0, s1] -> (d0, d1)>
|
||||
iterator_types:
|
||||
- parallel
|
||||
|
@ -401,11 +413,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_index: 0
|
||||
- !ScalarExpression
|
||||
symbolic_cast:
|
||||
type_var: I32
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_const: '42 : i64'
|
||||
scalar_arg: seed
|
||||
- !ScalarExpression
|
||||
symbolic_cast:
|
||||
type_var: I32
|
||||
|
@ -439,17 +447,9 @@ structured_op: !LinalgStructuredOpConfig
|
|||
fn_name: sub
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
symbolic_cast:
|
||||
type_var: F64
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_const: '1000 : i64'
|
||||
scalar_arg: max
|
||||
- !ScalarExpression
|
||||
symbolic_cast:
|
||||
type_var: F64
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_const: '-1000 : i64'
|
||||
scalar_arg: min
|
||||
- !ScalarExpression
|
||||
symbolic_cast:
|
||||
type_var: F64
|
||||
|
@ -457,8 +457,4 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_const: '2.3283063999999999E-10 : f64'
|
||||
- !ScalarExpression
|
||||
symbolic_cast:
|
||||
type_var: F64
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_const: '-1000 : i64'
|
||||
scalar_arg: min
|
||||
|
|
|
@ -30,16 +30,13 @@ func @generalize_matmul_tensor_i32(%A : tensor<16x8xi32>, %B: tensor<8x32xi32>,
|
|||
|
||||
// -----
|
||||
|
||||
func @generalize_fill_rng_2d_f32(%O: tensor<16x32xf32>) -> tensor<16x32xf32> {
|
||||
%0 = linalg.fill_rng_2d outs(%O : tensor<16x32xf32>) -> tensor<16x32xf32>
|
||||
func @generalize_fill_rng_2d_f32(%min: f64, %max: f64, %seed: i32, %O: tensor<16x32xf32>) -> tensor<16x32xf32> {
|
||||
%0 = linalg.fill_rng_2d ins(%min, %max, %seed: f64, f64, i32) outs(%O : tensor<16x32xf32>) -> tensor<16x32xf32>
|
||||
return %0: tensor<16x32xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @generalize_fill_rng_2d_f32
|
||||
// CHECK-SAME: (%[[O:.+]]: tensor<16x32xf32>)
|
||||
// CHECK-DAG: %[[MIN:.+]] = constant -1000 : i64
|
||||
// CHECK-DAG: %[[MAX:.+]] = constant 1000 : i64
|
||||
// CHECK-DAG: %[[SEED:.+]] = constant 42 : i32
|
||||
// CHECK-DAG: ^{{.*}}(%[[MIN:.+]]: f64, %[[MAX:.+]]: f64, %[[SEED:.+]]: i32, %[[O:.+]]: f32
|
||||
// CHECK-DAG: %[[IDX0:.+]] = linalg.index 0 : index
|
||||
// CHECK-DAG: %[[IDX1:.+]] = linalg.index 1 : index
|
||||
// CHECK-DAG: %[[IDX0_CAST:.+]] = index_cast %[[IDX0]] : index to i32
|
||||
|
@ -50,27 +47,24 @@ func @generalize_fill_rng_2d_f32(%O: tensor<16x32xf32>) -> tensor<16x32xf32> {
|
|||
// CHECK-DAG: %[[VAL1:.+]] = muli %[[VAL0]], %[[CST0]] : i32
|
||||
// CHECK-DAG: %[[VAL2:.+]] = addi %[[VAL1]], %[[CST1]] : i32
|
||||
// Skip random number computation for the second index.
|
||||
// CHECK-DAG: %[[MIN_CAST1:.+]] = sitofp %[[MIN]] : i64 to f64
|
||||
// CHECK-DAG: %[[MAX_CAST:.+]] = sitofp %[[MAX]] : i64 to f64
|
||||
// CHECK-DAG: %[[DIFF:.+]] = subf %[[MAX_CAST]], %[[MIN_CAST1]] : f64
|
||||
// CHECK-DAG: %[[DIFF:.+]] = subf %[[MAX]], %[[MIN]] : f64
|
||||
// CHECK-DAG: %[[CST2:.+]] = constant 2.3283063999999999E-10 : f64
|
||||
// CHECK-DAG: %[[FACT:.+]] = mulf %[[DIFF]], %[[CST2]] : f64
|
||||
// CHECK-DAG: %[[VAL4:.+]] = mulf %{{.+}}, %[[FACT]] : f64
|
||||
// CHECK-DAG: %[[MIN_CAST2:.+]] = sitofp %[[MIN]] : i64 to f64
|
||||
// CHECK-DAG: %[[VAL5:.+]] = addf %[[VAL4]], %[[MIN_CAST2]] : f64
|
||||
// CHECK-DAG: %[[VAL5:.+]] = addf %[[VAL4]], %[[MIN]] : f64
|
||||
// CHECK-DAG: %[[VAL6:.+]] = fptrunc %[[VAL5]] : f64 to f32
|
||||
// CHECK-NEXT: linalg.yield %[[VAL6]] : f32
|
||||
// CHECK-NEXT: -> tensor<16x32xf32>
|
||||
|
||||
// -----
|
||||
|
||||
func @generalize_fill_rng_2d_i32(%O: tensor<16x32xi32>) -> tensor<16x32xi32> {
|
||||
%0 = linalg.fill_rng_2d outs(%O : tensor<16x32xi32>) -> tensor<16x32xi32>
|
||||
func @generalize_fill_rng_2d_i32(%min: f64, %max: f64, %seed: i32, %O: tensor<16x32xi32>) -> tensor<16x32xi32> {
|
||||
%0 = linalg.fill_rng_2d ins(%min, %max, %seed: f64, f64, i32) outs(%O : tensor<16x32xi32>) -> tensor<16x32xi32>
|
||||
return %0: tensor<16x32xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @generalize_fill_rng_2d_i32
|
||||
// CHECK-SAME: (%[[O:.+]]: tensor<16x32xi32>)
|
||||
// CHECK: ^{{.*}}(%[[MIN:.+]]: f64, %[[MAX:.+]]: f64, %[[SEED:.+]]: i32, %[[O:.+]]: i32
|
||||
// Verifies floating point to integer cast.
|
||||
// CHECK: %[[VAL6:.+]] = fptosi %{{.+}} : f64 to i32
|
||||
// CHECK-NEXT: linalg.yield %[[VAL6]] : i32
|
||||
|
|
|
@ -19,11 +19,11 @@ metadata: !LinalgOpMetadata
|
|||
Detailed description.
|
||||
structured_op: !LinalgStructuredOpConfig
|
||||
args:
|
||||
- !<LinalgTensorDef>
|
||||
- !LinalgOperandDefConfig
|
||||
name: O
|
||||
usage: output
|
||||
shape: affine_map<()[s0, s1] -> (s0, s1)>
|
||||
element_type_var: T
|
||||
type_var: T
|
||||
indexing_maps: !LinalgIndexingMapsConfig
|
||||
static_indexing_maps:
|
||||
- affine_map<(d0, d1)[s0, s1] -> (d0, d1)>
|
||||
|
@ -58,7 +58,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
# ODS-NEXT: }];
|
||||
|
||||
# ODS: let arguments =
|
||||
# ODS-NEXT: Variadic<AnyShaped>:$inputs,
|
||||
# ODS-NEXT: Variadic<AnyType>:$inputs,
|
||||
# ODS-NEXT: Variadic<AnyShaped>:$outputs
|
||||
|
||||
# ODS: let builders =
|
||||
|
@ -103,18 +103,23 @@ metadata: !LinalgOpMetadata
|
|||
Detailed description.
|
||||
structured_op: !LinalgStructuredOpConfig
|
||||
args:
|
||||
- !<LinalgTensorDef>
|
||||
- !LinalgOperandDefConfig
|
||||
name: value
|
||||
usage: input
|
||||
type_var: T
|
||||
- !LinalgOperandDefConfig
|
||||
name: I
|
||||
usage: input
|
||||
shape: affine_map<()[s0, s1] -> (s0, s1)>
|
||||
element_type_var: T
|
||||
- !<LinalgTensorDef>
|
||||
shape: affine_map<()[s0, s1] -> (s1, s0)>
|
||||
type_var: T
|
||||
- !LinalgOperandDefConfig
|
||||
name: O
|
||||
usage: output
|
||||
shape: affine_map<()[s0, s1] -> (s0, s1)>
|
||||
element_type_var: T
|
||||
type_var: T
|
||||
indexing_maps: !LinalgIndexingMapsConfig
|
||||
static_indexing_maps:
|
||||
- affine_map<(d0, d1)[s0, s1] -> ()>
|
||||
- affine_map<(d0, d1)[s0, s1] -> (d1, d0)>
|
||||
- affine_map<(d0, d1)[s0, s1] -> (d0, d1)>
|
||||
iterator_types:
|
||||
|
@ -124,15 +129,23 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarAssign
|
||||
arg: O
|
||||
value: !ScalarExpression
|
||||
scalar_arg: I
|
||||
scalar_apply:
|
||||
fn_name: add
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_arg: value
|
||||
- !ScalarExpression
|
||||
scalar_arg: I
|
||||
|
||||
# IMPL-LABEL: Test2Op::iterator_types()
|
||||
# IMPL-NEXT: { getParallelIteratorTypeName(), getParallelIteratorTypeName() }
|
||||
|
||||
# IMPL: Test2Op::indexing_maps()
|
||||
# IMPL: "affine_map<(d0, d1)[s0, s1] -> ()>"
|
||||
# IMPL: "affine_map<(d0, d1)[s0, s1] -> (d1, d0)>"
|
||||
# IMPL: "affine_map<(d0, d1)[s0, s1] -> (d0, d1)>"
|
||||
|
||||
# IMPL: void Test2Op::regionBuilder(
|
||||
# IMPL: ImplicitLocOpBuilder &b, Block &block, ValueRange captures)
|
||||
# IMPL: yields.push_back(block.getArgument(0));
|
||||
|
||||
# IMPL: = helper.applyfn__add(block.getArgument(0), block.getArgument(1));
|
||||
|
|
|
@ -131,6 +131,33 @@ def test_matmul_generic():
|
|||
test_matmul_generic()
|
||||
|
||||
|
||||
def test_fill_builtin():
|
||||
with Context() as ctx, Location.unknown():
|
||||
module = Module.create()
|
||||
f64 = F64Type.get()
|
||||
i32 = IntegerType.get_signless(32)
|
||||
with InsertionPoint(module.body):
|
||||
|
||||
@builtin.FuncOp.from_py_func(f64, f64, i32, MemRefType.get((4, 16), i32))
|
||||
def fill_on_buffers(min, max, seed, out):
|
||||
linalg.fill_rng_2d(min, max, seed, outs=[out])
|
||||
|
||||
execution_engine = ExecutionEngine(transform(module, fill_boiler))
|
||||
|
||||
# TODO: FFI-based solution to allow testing and printing with python code.
|
||||
# Prepare arguments: one result i32.
|
||||
# Arguments must be passed as pointers.
|
||||
c_int_p = ctypes.c_int * 1
|
||||
res = c_int_p(-1)
|
||||
execution_engine.invoke("main", res)
|
||||
|
||||
log("RESULT: ", res[0])
|
||||
# CHECK: RESULT: -480
|
||||
|
||||
|
||||
test_fill_builtin()
|
||||
|
||||
|
||||
def test_fill_generic():
|
||||
with Context() as ctx, Location.unknown():
|
||||
module = Module.create()
|
||||
|
|
|
@ -62,17 +62,13 @@ struct SerializedAffineMap {
|
|||
AffineMap affineMap() { return affineMapAttr.getValue(); }
|
||||
};
|
||||
|
||||
enum class LinalgTensorUsageDef {
|
||||
input,
|
||||
output,
|
||||
temporary,
|
||||
};
|
||||
enum class LinalgOperandDefUsage { input, output };
|
||||
|
||||
struct LinalgTensorDef {
|
||||
struct LinalgOperandDef {
|
||||
std::string name;
|
||||
LinalgTensorUsageDef usage;
|
||||
SerializedAffineMap shape;
|
||||
std::string elementTypeVar;
|
||||
LinalgOperandDefUsage usage;
|
||||
Optional<SerializedAffineMap> shape;
|
||||
std::string typeVar;
|
||||
};
|
||||
|
||||
enum class LinalgIteratorTypeDef {
|
||||
|
@ -114,10 +110,10 @@ struct ScalarAssign {
|
|||
};
|
||||
|
||||
struct LinalgStructuredOpConfig {
|
||||
SmallVector<LinalgTensorDef> args;
|
||||
SmallVector<LinalgOperandDef> args;
|
||||
LinalgIndexingMapsConfig indexingMaps;
|
||||
SmallVector<LinalgIteratorTypeDef> iteratorTypes;
|
||||
SmallVector<ScalarAssign, 2> assignments;
|
||||
std::vector<ScalarAssign> assignments;
|
||||
};
|
||||
|
||||
struct LinalgOpConfig {
|
||||
|
@ -131,7 +127,7 @@ struct LinalgOpConfig {
|
|||
// Mapping traits.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LLVM_YAML_IS_SEQUENCE_VECTOR(LinalgTensorDef)
|
||||
LLVM_YAML_IS_SEQUENCE_VECTOR(LinalgOperandDef)
|
||||
LLVM_YAML_IS_SEQUENCE_VECTOR(SerializedAffineMap)
|
||||
LLVM_YAML_IS_SEQUENCE_VECTOR(LinalgIteratorTypeDef)
|
||||
LLVM_YAML_IS_SEQUENCE_VECTOR(ScalarAssign)
|
||||
|
@ -153,8 +149,8 @@ struct MappingTraits<LinalgOpConfig> {
|
|||
};
|
||||
|
||||
/// A structured op models (at most) a single contraction by modeling
|
||||
/// - A list of named arguments (`LinalgTensorDef`), which can be inputs,
|
||||
/// outputs, or temporaries.
|
||||
/// - A list of named arguments (`LinalgOperandDef`), which can be inputs or
|
||||
/// outputs.
|
||||
/// - List of indexing maps (see `LinalgIndexingMaps`).
|
||||
/// - Iterator types (see `LinalgIteratorTypeDef`).
|
||||
/// - List of scalar level assignment (see `ScalarAssign`).
|
||||
|
@ -168,31 +164,30 @@ struct MappingTraits<LinalgStructuredOpConfig> {
|
|||
}
|
||||
};
|
||||
|
||||
/// Maps a named tensor-argument to an operation, consisting of:
|
||||
/// Maps a named tensor- or scalar-argument to an operation, consisting of:
|
||||
/// - `name`: Must be unique within the operation.
|
||||
/// - `usage`: How the argument is used (input, output, etc).
|
||||
/// - `shape`: An AffineMap from all op symbols to the specific shape
|
||||
/// of this argument. Each shape must be normalized over the same list of
|
||||
/// symbols and have no dimension inputs.
|
||||
/// - `element_type_var`: The symbolic type variable that binds to the scalar
|
||||
/// element type of this TensorDef.
|
||||
/// - `shape`: An optional AffineMap from all op symbols to the shape of the
|
||||
/// argument. Only tensor-arguments have a shape. Each shape must be
|
||||
/// normalized over the same list of symbols and have no dimension inputs.
|
||||
/// - `type_var`: The symbolic type variable that binds to the element or self
|
||||
/// type of the tensor- or scalar-argument, respectively.
|
||||
template <>
|
||||
struct MappingTraits<LinalgTensorDef> {
|
||||
static void mapping(IO &io, LinalgTensorDef &info) {
|
||||
struct MappingTraits<LinalgOperandDef> {
|
||||
static void mapping(IO &io, LinalgOperandDef &info) {
|
||||
io.mapRequired("name", info.name);
|
||||
io.mapRequired("usage", info.usage);
|
||||
io.mapRequired("shape", info.shape);
|
||||
io.mapRequired("element_type_var", info.elementTypeVar);
|
||||
io.mapOptional("shape", info.shape);
|
||||
io.mapRequired("type_var", info.typeVar);
|
||||
}
|
||||
};
|
||||
|
||||
/// Usage enum for a named argument.
|
||||
template <>
|
||||
struct ScalarEnumerationTraits<LinalgTensorUsageDef> {
|
||||
static void enumeration(IO &io, LinalgTensorUsageDef &value) {
|
||||
io.enumCase(value, "input", LinalgTensorUsageDef::input);
|
||||
io.enumCase(value, "output", LinalgTensorUsageDef::output);
|
||||
io.enumCase(value, "temporary", LinalgTensorUsageDef::temporary);
|
||||
struct ScalarEnumerationTraits<LinalgOperandDefUsage> {
|
||||
static void enumeration(IO &io, LinalgOperandDefUsage &value) {
|
||||
io.enumCase(value, "input", LinalgOperandDefUsage::input);
|
||||
io.enumCase(value, "output", LinalgOperandDefUsage::output);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -229,7 +224,7 @@ struct MappingTraits<LinalgIndexingMapsConfig> {
|
|||
};
|
||||
|
||||
/// Models an assignment to a named output.
|
||||
/// - The `arg` name must match a named output or temporary.
|
||||
/// - The `arg` name must match a named output.
|
||||
/// - The `value` is a scalar expression for computing the value to
|
||||
/// assign (see `ScalarExpression`).
|
||||
template <>
|
||||
|
@ -366,7 +361,7 @@ static std::string interleaveToString(Container &container,
|
|||
}
|
||||
|
||||
static Optional<int>
|
||||
findTensorDefArgIndex(StringRef name, SmallVectorImpl<LinalgTensorDef> &args) {
|
||||
findTensorDefArgIndex(StringRef name, SmallVectorImpl<LinalgOperandDef> &args) {
|
||||
for (auto it : llvm::enumerate(args)) {
|
||||
if (it.value().name == name)
|
||||
return it.index();
|
||||
|
@ -376,7 +371,7 @@ findTensorDefArgIndex(StringRef name, SmallVectorImpl<LinalgTensorDef> &args) {
|
|||
|
||||
// Try to map the TypeVar to a predefined or an argument type.
|
||||
static Optional<std::string>
|
||||
findTypeValue(StringRef typeVar, SmallVectorImpl<LinalgTensorDef> &args) {
|
||||
findTypeValue(StringRef typeVar, SmallVectorImpl<LinalgOperandDef> &args) {
|
||||
// Handle all predefined types.
|
||||
if (typeVar == "I32")
|
||||
return std::string("helper.getIntegerType(32)");
|
||||
|
@ -389,7 +384,7 @@ findTypeValue(StringRef typeVar, SmallVectorImpl<LinalgTensorDef> &args) {
|
|||
|
||||
// Search all argument types.
|
||||
for (auto it : llvm::enumerate(args)) {
|
||||
if (it.value().elementTypeVar == typeVar)
|
||||
if (it.value().typeVar == typeVar)
|
||||
return llvm::formatv("block.getArgument({0}).getType()", it.index())
|
||||
.str();
|
||||
}
|
||||
|
@ -397,8 +392,8 @@ findTypeValue(StringRef typeVar, SmallVectorImpl<LinalgTensorDef> &args) {
|
|||
return None;
|
||||
}
|
||||
|
||||
static ScalarAssign *
|
||||
findAssignment(StringRef name, SmallVectorImpl<ScalarAssign> &assignments) {
|
||||
static ScalarAssign *findAssignment(StringRef name,
|
||||
std::vector<ScalarAssign> &assignments) {
|
||||
for (auto &assign : assignments) {
|
||||
if (assign.arg == name)
|
||||
return &assign;
|
||||
|
@ -445,7 +440,7 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([
|
|||
/*extraInterfaces=*/[{2}])> {
|
||||
{3}
|
||||
let arguments = (ins
|
||||
Variadic<AnyShaped>:$inputs,
|
||||
Variadic<AnyType>:$inputs,
|
||||
Variadic<AnyShaped>:$outputs{4}
|
||||
);
|
||||
let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
|
||||
|
@ -467,7 +462,7 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([
|
|||
$_builder,
|
||||
$_state,
|
||||
TypeRange(inputs),
|
||||
TypeRange(outputs)/*, TODO: support captures*/);
|
||||
TypeRange(outputs));
|
||||
}]>,
|
||||
OpBuilder<
|
||||
(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
|
||||
|
@ -485,7 +480,7 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([
|
|||
$_builder,
|
||||
$_state,
|
||||
TypeRange(inputs),
|
||||
TypeRange(outputs)/*, TODO: support captures*/);
|
||||
TypeRange(outputs));
|
||||
}]>,
|
||||
OpBuilder<
|
||||
(ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
|
||||
|
@ -500,7 +495,7 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([
|
|||
];
|
||||
let printer = [{{ return ::printNamedStructuredOp(p, *this); }];
|
||||
let parser = [{{
|
||||
return ::parseNamedStructuredOp<{0}>(parser, result/*TODO:, captures*/);
|
||||
return ::parseNamedStructuredOp<{0}>(parser, result);
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
|
||||
|
@ -768,9 +763,8 @@ void {0}::regionBuilder(
|
|||
size_t generatedAssignmentCount = 0;
|
||||
int localCounter = 0;
|
||||
SmallVector<std::string> stmts;
|
||||
for (LinalgTensorDef &arg : args) {
|
||||
if (arg.usage != LinalgTensorUsageDef::output &&
|
||||
arg.usage != LinalgTensorUsageDef::temporary)
|
||||
for (LinalgOperandDef &arg : args) {
|
||||
if (arg.usage != LinalgOperandDefUsage::output)
|
||||
continue;
|
||||
|
||||
// Find the assignment that correlates with the argument.
|
||||
|
|
Loading…
Reference in New Issue