[mlir] Add shaped container component type interface

Summary:
* Add shaped container type interface which allows inferring the shape, element
  type and attribute of a shaped container type separately. Show usage by way
  of a tensor type inference trait which combines the shape & element type in
  inferring a tensor type;
  - All components need not be specified;
  - Attribute is added to allow for the layout attribute that was previously
    discussed;
* Expand the test driver to make it easier to test new creation instances
  (adding new operands or ops with attributes or regions would trigger build
  functions/type inference methods);
  - The verification part will be moved out of the test and to verify method
    instead of ops implementing the type inference interface in a follow up;
* Add MLIRContext as arg to make it possible to create types for ops without
  arguments, regions or a location;
* Also move the section out of the OpDefinitions doc to a separate
  ShapeInference doc where the shape function requirements can be captured;
  - Part of this would move to the shape dialect, and/or the shape dialect ops
    would be included as a subsection of this doc;
* Update ODS's variable usage to match camelBack format for builder,
  state and arg variables;
  - I could have split this out, but I had to make some changes around
    these and the inconsistency bugged me :)

Differential Revision: https://reviews.llvm.org/D72432
Jacques Pienaar 2020-01-08 18:48:38 -08:00
parent 34ba96a3d4
commit fa26a37d36
15 changed files with 397 additions and 164 deletions


@@ -429,14 +429,14 @@ The following builders are generated:
```c++
// All result-types/operands/attributes have one aggregate parameter.
-static void build(Builder *tblgen_builder, OperationState &tblgen_state,
+static void build(Builder *odsBuilder, OperationState &odsState,
ArrayRef<Type> resultTypes,
ValueRange operands,
ArrayRef<NamedAttribute> attributes);
// Each result-type/operand/attribute has a separate parameter. The parameters
// for attributes are of mlir::Attribute types.
-static void build(Builder *tblgen_builder, OperationState &tblgen_state,
+static void build(Builder *odsBuilder, OperationState &odsState,
Type i32_result, Type f32_result, ...,
Value i32_operand, Value f32_operand, ...,
IntegerAttr i32_attr, FloatAttr f32_attr, ...);
@@ -445,20 +445,20 @@ static void build(Builder *tblgen_builder, OperationState &tblgen_state,
// for attributes are raw values unwrapped with mlir::Attribute instances.
// (Note that this builder will not always be generated. See the following
// explanation for more details.)
-static void build(Builder *tblgen_builder, OperationState &tblgen_state,
+static void build(Builder *odsBuilder, OperationState &odsState,
Type i32_result, Type f32_result, ...,
Value i32_operand, Value f32_operand, ...,
APInt i32_attr, StringRef f32_attr, ...);
// Each operand/attribute has a separate parameter but result type is aggregate.
-static void build(Builder *tblgen_builder, OperationState &tblgen_state,
+static void build(Builder *odsBuilder, OperationState &odsState,
ArrayRef<Type> resultTypes,
Value i32_operand, Value f32_operand, ...,
IntegerAttr i32_attr, FloatAttr f32_attr, ...);
// All operands/attributes have aggregate parameters.
// Generated if InferTypeOpInterface interface is specified.
-static void build(Builder *tblgen_builder, OperationState &tblgen_state,
+static void build(Builder *odsBuilder, OperationState &odsState,
ValueRange operands,
ArrayRef<NamedAttribute> attributes);
@@ -1099,7 +1099,7 @@ requirements that were desirable:
* The op's traits (e.g., commutative) are modelled along with the op in the
registry.
* The op's operand/return type constraints are modelled along with the op in
-the registry (see [Shape inference](#shape-inference) discussion below),
+the registry (see [Shape inference](ShapeInference.md) discussion below),
this allows (e.g.) optimized concise syntax in textual dumps.
* Behavior of the op is documented along with the op with a summary and a
description. The description is written in markdown and extracted for
@@ -1156,49 +1156,6 @@ tfl.add $lhs, $rhs {fused_activation_function: $fused_activation_function}: ${ty
Printing is effectively the inverse of the parsing function generated with the
mnemonic string serving as a template.
### Shape inference
Type constraints are along (at least) three axes: 1) elemental type, 2) rank
(including static or dynamic), 3) dimensions. While some ops have no compile
time fixed shape (e.g., output shape is dictated by data) we could still have
some knowledge of constraints/bounds in the system for that op (e.g., the output
of a `tf.where` is at most the size of the input data). And so there are
additional valuable constraints that could be captured even without full
knowledge.
Initially the shape inference will be declaratively specified using:
* Constraint on the operands of an operation directly. For example
constraining the input type to be tensor/vector elements or that the
elemental type be of a specific type (e.g., output of sign is of elemental
type `i1`) or class (e.g., float like).
* Constraints across operands and results of an operation. For example,
enabling specifying equality constraints on type/constituents of a type
(shape and elemental type) between operands and results (e.g., the output
type of an add is the same as those of the input operands).
In general there is an input/output transfer function which maps the inputs to
the outputs (e.g., given input X and Y [or slices thereof] with these sizes, the
output is Z [or this slice thereof]). Such a function could be used to determine
the output type (shape) for given input type (shape).
But shape functions are determined by attributes and could be arbitrarily
complicated with a wide range of specification possibilities. Equality
relationships are common (e.g., the elemental type of the output matches the
primitive type of the inputs, both inputs have exactly the same type [primitive
type and shape]) and so these should be easy to specify. Algebraic relationships
would also be common (e.g., a concat of `[n,m]` and `[n,m]` matrices along axis
0 is an `[n+n, m]` matrix), while some ops only have defined shapes under certain
cases (e.g., matrix multiplication of `[a,b]` and `[c,d]` is only defined if
`b == c`). As ops are also verified, the shape inference need only specify rules
for the allowed cases (e.g., shape inference for matmul can ignore the case
where `b != c`), which would simplify type constraint specification.
Instead of specifying an additional mechanism to specify a shape transfer
function, the reference implementation of the operation will be used to derive
the shape function. The reference implementation is general and can support the
arbitrary computations needed to specify output shapes.
[TableGen]: https://llvm.org/docs/TableGen/index.html
[TableGenIntro]: https://llvm.org/docs/TableGen/LangIntro.html
[TableGenRef]: https://llvm.org/docs/TableGen/LangRef.html


@@ -0,0 +1,72 @@
# Shape inference
Shape inference as discussed here is considered a specific instance of type
inference for [ShapedType][ShapedType]. Type constraints are along (at least)
three axes: 1) elemental type, 2) rank (including static or dynamic), 3)
dimensions. While some operations have no compile time fixed shape (e.g., output
shape is dictated by data) we could still have some knowledge of
constraints/bounds in the system for that operation (e.g., the output of a
`tf.where` is at most the size of the input data). That is, there are additional
valuable constraints that could be captured even without full knowledge of the
shape.
Type inference is currently modelled executionally for op creation using the
[`InferTypeOpInterface`][InferTypeOpInterface], while
`InferShapedTypeOpInterface` is used to implement the shape and element type
inference. The return type can often be deduced from the deduced return shape
and elemental type (queryable from `InferShapedTypeOpInterface`) and so type
inference for tensor types can be implemented with `InferShapedTypeOpInterface`.
## Shape functions
The C++ interfaces are the base mechanism whereby shape inference is queried and
executed, but not the intended way to specify shape constraints in general.
Initially the shape inference will be declaratively specified using:
* Constraints on the operands of an operation directly. For example
constraining the input type to be tensor/vector elements or that the
elemental type be of a specific type (e.g., output of computing the size
of a value is of elemental type `i1`) or class (e.g., float like).
* Constraints across operands and results of an operation.
- For example, specifying equality constraints on type/constituents of a
type (shape and elemental type) between operands and results (e.g., the
output type of an add is the same as those of the input operands).
NOTE: The C++ shape functions are an intermediate step until the shape dialect
is more full-fledged, at which point the C++ functions should become the
exceptional case.
## Testing
Shape inference is currently tested alongside type inference by
`TestReturnTypeDriver` in the test dialect. The driver performs two checks:
1. Verification that the specified return types match the inferred types. This
explicit check will be removed and made part of Op verification instead.
2. Test the creation of Ops without specifying the return type explicitly in
function `testCreateFunctions` by creating new binary Ops (Op classes
specified in `TestReturnTypeDriver`) using 1) all operands to
`testCreateFunctions` as both operands, and 2) using combinations of input
operands of the function.
## WIP/Future considerations
Shape functions are determined by attributes and could be arbitrarily
complicated with a wide range of specification possibilities. Equality
relationships are common (e.g., the elemental type of the output matches the
primitive type of the inputs, both inputs have exactly the same type [primitive
type and shape]) and so these should be easy to specify. Algebraic relationships
would also be common (e.g., a concat of `[n,m]` and `[n,m]` matrices along axis 0
is an `[n+n, m]` matrix), while some ops only have defined shapes under certain
cases (e.g., matrix multiplication of `[a,b]` and `[c,d]` is only defined if `b
== c`).
Instead of specifying an additional mechanism to specify a shape transfer
function, the reference implementation of the operation will be used to derive
the shape function. The reference implementation is general and can support the
arbitrary computations needed to specify output shapes.
[InferTypeOpInterface]: https://github.com/llvm/llvm-project/tree/master/mlir/include/mlir/Analysis/InferTypeOpInterface.td
[ShapedType]: https://github.com/llvm/llvm-project/tree/master/mlir/include/mlir/IR/StandardTypes.h


@@ -17,28 +17,100 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/SmallVector.h"
namespace mlir {
/// ShapedTypeComponents represents the components of a ShapedType.
/// The components consist of
/// - A ranked or unranked shape with the dimension specification matching that
/// of ShapedType's getShape() (e.g., dynamic dimensions represented using
/// ShapedType::kDynamicSize)
/// - An element type, may be unset (nullptr)
/// - An attribute, may be unset (nullptr)
/// Used by ShapedType type inference.
class ShapedTypeComponents {
/// Internal storage type for shape.
using ShapeStorageT = SmallVector<int64_t, 3>;
public:
/// Default construction is an unranked shape.
ShapedTypeComponents() : ranked(false), elementType(nullptr), attr(nullptr){};
template <typename Arg, typename = typename std::enable_if_t<
std::is_constructible<ShapeStorageT, Arg>::value>>
ShapedTypeComponents(Arg &&arg, Type elementType = nullptr,
Attribute attr = nullptr)
: dims(std::forward<Arg>(arg)), ranked(true), elementType(elementType),
attr(attr) {}
ShapedTypeComponents(ArrayRef<int64_t> vec, Type elementType = nullptr,
Attribute attr = nullptr)
: dims(vec.begin(), vec.end()), ranked(true), elementType(elementType),
attr(attr) {}
/// Return the dimensions of the shape.
/// Requires: shape is ranked.
ArrayRef<int64_t> getDims() const {
assert(ranked && "requires ranked shape");
return dims;
}
/// Return whether the shape has a rank.
bool hasRank() const { return ranked; };
/// Return the element type component.
Type getElementType() const { return elementType; };
/// Return the raw attribute component.
Attribute getAttribute() const { return attr; };
private:
ShapeStorageT dims;
bool ranked;
Type elementType;
Attribute attr;
};
#include "mlir/Analysis/InferTypeOpInterface.h.inc"
namespace OpTrait {
template <typename ConcreteType>
class TypeOpInterfaceDefault
: public TraitBase<ConcreteType, TypeOpInterfaceDefault> {
public:
/// Returns whether two arrays are equal as strongest check for compatibility
/// by default.
static bool isCompatibleReturnTypes(ArrayRef<Type> lhs, ArrayRef<Type> rhs) {
return lhs == rhs;
};
};
} // namespace OpTrait
namespace detail {
// Helper function to infer returned tensor types given an element and shape
// inference function.
//
// TODO: Consider generating typedefs for trait member functions if this usage
// becomes more common.
LogicalResult inferReturnTensorTypes(
function_ref<LogicalResult(
MLIRContext *, Optional<Location> location, ValueRange operands,
ArrayRef<NamedAttribute> attributes, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &retComponents)>
componentTypeFn,
MLIRContext *context, Optional<Location> location, ValueRange operands,
ArrayRef<NamedAttribute> attributes, RegionRange regions,
SmallVectorImpl<Type> &inferedReturnTypes);
} // namespace detail
namespace OpTrait {
/// Tensor type inference trait that constructs a tensor from the inferred
/// shape and elemental types.
/// Requires: Op implements functions of InferShapedTypeOpInterface.
template <typename ConcreteType>
class InferTensorType : public TraitBase<ConcreteType, InferTensorType> {
public:
static LogicalResult
inferReturnTypes(MLIRContext *context, Optional<Location> location,
ValueRange operands, ArrayRef<NamedAttribute> attributes,
RegionRange regions,
SmallVectorImpl<Type> &inferedReturnTypes) {
return ::mlir::detail::inferReturnTensorTypes(
ConcreteType::inferReturnTypeComponents, context, location, operands,
attributes, regions, inferedReturnTypes);
}
};
} // namespace OpTrait
} // namespace mlir
#endif // MLIR_ANALYSIS_INFERTYPEOPINTERFACE_H_


@@ -22,9 +22,8 @@ include "mlir/IR/OpBase.td"
// mismatch).
def InferTypeOpInterface : OpInterface<"InferTypeOpInterface"> {
let description = [{
-Interface to access a registered method to infer the return types for an
-operation that could be used during op construction, verification or
-type inference.
+Interface to infer the return types for an operation that could be used
+during op construction, verification or type inference.
}];
let methods = [
@@ -38,7 +37,8 @@ def InferTypeOpInterface : OpInterface<"InferTypeOpInterface"> {
}],
/*retTy=*/"LogicalResult",
/*methodName=*/"inferReturnTypes",
-/*args=*/(ins "Optional<Location>":$location,
+/*args=*/(ins "MLIRContext*":$context,
+"Optional<Location>":$location,
"ValueRange":$operands,
"ArrayRef<NamedAttribute>":$attributes,
"RegionRange":$regions,
@@ -62,4 +62,38 @@ def InferTypeOpInterface : OpInterface<"InferTypeOpInterface"> {
];
}
def InferShapedTypeOpInterface : OpInterface<"InferShapedTypeOpInterface"> {
let description = [{
Interface to infer the components of a ShapedType returned by an operation
that could be used during op construction, verification or shape inference.
The components consist of element type, shape and raw attribute.
}];
let methods = [
StaticInterfaceMethod<
/*desc=*/[{Infer the components of the return type of a shaped container.
The method takes an optional location which, if set, will be used to
report errors on. The operands and attributes correspond to those with
which an Operation would be created (e.g., as used in Operation::create)
and the regions of the op.
Unknown (e.g., unranked) shape and nullptrs for element type and attribute
may be returned by this function while returning success. That is, partial
population of the components is not an error condition.
}],
/*retTy=*/"LogicalResult",
/*methodName=*/"inferReturnTypeComponents",
/*args=*/(ins "MLIRContext*":$context,
"Optional<Location>":$location,
"ValueRange":$operands,
"ArrayRef<NamedAttribute>":$attributes,
"RegionRange":$regions,
"SmallVectorImpl<ShapedTypeComponents>&":
$inferedReturnShapes)
>,
];
}
#endif // MLIR_INFERTYPEOPINTERFACE


@@ -1539,7 +1539,7 @@ class Op<Dialect dialect, string mnemonic, list<OpTrait> props = []> {
// following signatures:
//
// ```c++
-// static void build(Builder *, OperationState &tblgen_state,
+// static void build(Builder *, OperationState &odsState,
// Type <result0-name>, Type <result1-name>, ...,
// Value <arg0-name>, Value <arg1-name>, ...,
// Attribute <attr0-name>, Attribute <attr1-name>, ...);
@@ -1547,7 +1547,7 @@ class Op<Dialect dialect, string mnemonic, list<OpTrait> props = []> {
// * where the attributes follow the same declaration order as in the op.
//
// ```c++
-// static void build(Builder *, OperationState &tblgen_state,
+// static void build(Builder *, OperationState &odsState,
// ArrayRef<Type> resultTypes,
// ArrayRef<Value> operands,
// ArrayRef<NamedAttribute> attributes);


@@ -12,11 +12,36 @@
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/InferTypeOpInterface.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h"
-#include "llvm/ADT/SmallVector.h"
+#include "mlir/IR/StandardTypes.h"
using namespace mlir;
namespace mlir {
#include "mlir/Analysis/InferTypeOpInterface.cpp.inc"
} // namespace mlir
LogicalResult mlir::detail::inferReturnTensorTypes(
function_ref<LogicalResult(
MLIRContext *, Optional<Location> location, ValueRange operands,
ArrayRef<NamedAttribute> attributes, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &retComponents)>
componentTypeFn,
MLIRContext *context, Optional<Location> location, ValueRange operands,
ArrayRef<NamedAttribute> attributes, RegionRange regions,
SmallVectorImpl<Type> &inferedReturnTypes) {
SmallVector<ShapedTypeComponents, 2> retComponents;
if (failed(componentTypeFn(context, location, operands, attributes, regions,
retComponents)))
return failure();
for (auto shapeAndType : retComponents) {
assert(shapeAndType.getAttribute() == nullptr && "attribute not supported");
if (shapeAndType.hasRank())
inferedReturnTypes.push_back(RankedTensorType::get(
shapeAndType.getDims(), shapeAndType.getElementType()));
else
inferedReturnTypes.push_back(
UnrankedTensorType::get(shapeAndType.getElementType()));
}
return success();
}


@@ -295,7 +295,7 @@ LogicalResult TestOpWithVariadicResultsAndFolder::fold(
}
LogicalResult mlir::OpWithInferTypeInterfaceOp::inferReturnTypes(
-llvm::Optional<Location> location, ValueRange operands,
+MLIRContext *, Optional<Location> location, ValueRange operands,
ArrayRef<NamedAttribute> attributes, RegionRange regions,
SmallVectorImpl<Type> &inferedReturnTypes) {
if (operands[0].getType() != operands[1].getType()) {
@@ -307,6 +307,30 @@ LogicalResult mlir::OpWithInferTypeInterfaceOp::inferReturnTypes(
return success();
}
LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
MLIRContext *context, Optional<Location> location, ValueRange operands,
ArrayRef<NamedAttribute> attributes, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferedComponents) {
// Create return type consisting of the first element of each shape of the
// input operands or unknown for unranked operand.
std::vector<int64_t> shape;
shape.reserve(operands.size());
for (auto operandType : operands.getTypes()) {
if (auto sval = operandType.dyn_cast<ShapedType>()) {
if (sval.hasRank())
shape.push_back(sval.getShape().front());
else
shape.push_back(ShapedType::kDynamicSize);
} else {
return emitOptionalError(location, "only shaped type operands allowed");
}
}
inferedComponents.reserve(1);
auto type = IntegerType::get(17, context);
inferedComponents.emplace_back(shape, type);
return success();
}
// Static initialization for Test dialect registration.
static mlir::DialectRegistration<mlir::TestDialect> testDialect;


@@ -402,6 +402,21 @@ def OpWithInferTypeInterfaceOp : TEST_Op<"op_with_infer_type_if", [
let results = (outs AnyTensor);
}
def InferTensorType : NativeOpTrait<"InferTensorType">;
def OpWithShapedTypeInferTypeInterfaceOp : TEST_Op<"op_with_shaped_type_infer_type_if",
[
// Op implements infer type op interface.
InferTypeOpInterface,
// The op will have methods implementing the ShapedType type infer interface.
DeclareOpInterfaceMethods<InferShapedTypeOpInterface>,
// The op produces tensors and will use the ShapedType type infer interface
// along with knowledge that it is producing Tensors to infer shape.
InferTensorType
]> {
let arguments = (ins AnyTensor, AnyTensor);
let results = (outs AnyTensor);
}
def IsNotScalar : Constraint<CPred<"$0.getType().getRank() != 0">>;
def UpdateAttr : Pat<(I32ElementsAttrOp $attr),


@@ -58,50 +58,71 @@ static mlir::PassRegistration<TestPatternDriver>
//===----------------------------------------------------------------------===//
namespace {
struct ReturnTypeOpMatch : public RewritePattern {
ReturnTypeOpMatch(MLIRContext *ctx)
: RewritePattern(OpWithInferTypeInterfaceOp::getOperationName(), 1, ctx) {
}
// Generate ops for each instance where the type can be successfully inferred.
template <typename OpTy>
static void invokeCreateWithInferedReturnType(Operation *op) {
auto *context = op->getContext();
auto fop = op->getParentOfType<FuncOp>();
auto location = UnknownLoc::get(context);
OpBuilder b(op);
b.setInsertionPointAfter(op);
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const final {
if (auto retTypeFn = dyn_cast<InferTypeOpInterface>(op)) {
SmallVector<Value, 4> values(op->getOperands());
// Use permutations of 2 args as operands.
assert(fop.getNumArguments() >= 2);
for (int i = 0, e = fop.getNumArguments(); i < e; ++i) {
for (int j = 0; j < e; ++j) {
std::array<Value, 2> values = {fop.getArgument(i), fop.getArgument(j)};
SmallVector<Type, 2> inferedReturnTypes;
if (failed(retTypeFn.inferReturnTypes(op->getLoc(), values,
op->getAttrs(), op->getRegions(),
inferedReturnTypes)))
return matchFailure();
SmallVector<Type, 1> resultTypes(op->getResultTypes());
if (!retTypeFn.isCompatibleReturnTypes(inferedReturnTypes, resultTypes))
return op->emitOpError(
"inferred type incompatible with return type of operation"),
matchFailure();
// TODO(jpienaar): Split this out to make the test more focused.
// Create new op with unknown location to verify building with
// InferTypeOpInterface is triggered.
auto fop = op->getParentOfType<FuncOp>();
if (values[0] == fop.getArgument(0)) {
// Use the 2nd function argument if the first function argument is used
// when constructing the new op so that a new return type is inferred.
values[0] = fop.getArgument(1);
values[1] = fop.getArgument(1);
if (succeeded(OpTy::inferReturnTypes(context, llvm::None, values,
op->getAttrs(), op->getRegions(),
inferedReturnTypes))) {
OperationState state(location, OpTy::getOperationName());
// TODO(jpienaar): Expand to regions.
rewriter.create<OpWithInferTypeInterfaceOp>(
UnknownLoc::get(op->getContext()), values, op->getAttrs());
OpTy::build(&b, state, values, op->getAttrs());
(void)b.createOperation(state);
}
}
return matchFailure();
}
};
}
struct TestReturnTypeDriver : public FunctionPass<TestReturnTypeDriver> {
void runOnFunction() override {
mlir::OwningRewritePatternList patterns;
populateWithGenerated(&getContext(), &patterns);
patterns.insert<ReturnTypeOpMatch>(&getContext());
applyPatternsGreedily(getFunction(), patterns);
if (getFunction().getName() == "testCreateFunctions") {
std::vector<Operation *> ops;
// Collect ops to avoid triggering on inserted ops.
for (auto &op : getFunction().getBody().front())
ops.push_back(&op);
// Generate test patterns for each, but skip terminator.
for (auto *op : llvm::makeArrayRef(ops).drop_back()) {
// Test the create method of each of the Op classes below. The resulting
// ops are inserted in reverse order after `op`, whose attributes and
// regions are reused.
invokeCreateWithInferedReturnType<OpWithInferTypeInterfaceOp>(op);
invokeCreateWithInferedReturnType<OpWithShapedTypeInferTypeInterfaceOp>(
op);
};
return;
}
// Verification check.
// TODO: Move to ops that implement type infer interface.
getFunction().walk([this](Operation *op) -> void {
auto retTypeFn = dyn_cast<InferTypeOpInterface>(op);
if (!retTypeFn)
return;
auto *context = &getContext();
SmallVector<Type, 2> inferedReturnTypes;
if (failed(retTypeFn.inferReturnTypes(
context, op->getLoc(), op->getOperands(), op->getAttrs(),
op->getRegions(), inferedReturnTypes)))
return;
SmallVector<Type, 1> resultTypes(op->getResultTypes());
if (!retTypeFn.isCompatibleReturnTypes(inferedReturnTypes, resultTypes)) {
op->emitOpError(
"inferred type incompatible with return type of operation");
return;
}
});
}
};
} // end anonymous namespace


@@ -56,18 +56,18 @@ def AOp : NS_Op<"a_op", []> {
// ---
// DEF: void AOp::build(
-// DEF: tblgen_state.addAttribute("aAttr", aAttr);
-// DEF: tblgen_state.addAttribute("bAttr", bAttr);
+// DEF: odsState.addAttribute("aAttr", aAttr);
+// DEF: odsState.addAttribute("bAttr", bAttr);
// DEF: if (cAttr) {
-// DEF-NEXT: tblgen_state.addAttribute("cAttr", cAttr);
+// DEF-NEXT: odsState.addAttribute("cAttr", cAttr);
// DEF: void AOp::build(
// DEF: some-return-type aAttr, some-return-type bAttr, /*optional*/some-attr-kind cAttr
-// DEF: tblgen_state.addAttribute("aAttr", some-const-builder-call((*tblgen_builder), aAttr));
+// DEF: odsState.addAttribute("aAttr", some-const-builder-call((*odsBuilder), aAttr));
// DEF: void AOp::build(
// DEF: ArrayRef<NamedAttribute> attributes
-// DEF: tblgen_state.addAttributes(attributes);
+// DEF: odsState.addAttributes(attributes);
// Test verify method
// ---
@@ -218,7 +218,7 @@ def MixOperandsAndAttrs : NS_Op<"mix_operands_and_attrs", []> {
// DEF-LABEL: MixOperandsAndAttrs definitions
// DEF-DAG: Value MixOperandsAndAttrs::operand()
// DEF-DAG: Value MixOperandsAndAttrs::otherArg()
-// DEF-DAG: void MixOperandsAndAttrs::build(Builder *tblgen_builder, OperationState &tblgen_state, FloatAttr attr, Value operand, FloatAttr otherAttr, Value otherArg)
+// DEF-DAG: void MixOperandsAndAttrs::build(Builder *odsBuilder, OperationState &odsState, FloatAttr attr, Value operand, FloatAttr otherAttr, Value otherArg)
// DEF-DAG: APFloat MixOperandsAndAttrs::attr()
// DEF-DAG: APFloat MixOperandsAndAttrs::otherAttr()
@@ -233,4 +233,4 @@ def UnitAttrOp : NS_Op<"unit_attr_op", []> {
// DEF: bool UnitAttrOp::attr() {
// DEF: return {{.*}} != nullptr
-// DEF: build(Builder *tblgen_builder, OperationState &tblgen_state, /*optional*/UnitAttr attr)
+// DEF: build(Builder *odsBuilder, OperationState &odsState, /*optional*/UnitAttr attr)


@@ -70,9 +70,9 @@ def NS_AOp : NS_Op<"a_op", [NoSideEffect, NoSideEffect]> {
// CHECK: FloatAttr attr2Attr()
// CHECK: Optional< APFloat > attr2();
// CHECK: static void build(Value val);
-// CHECK: static void build(Builder *tblgen_builder, OperationState &tblgen_state, Type r, ArrayRef<Type> s, Value a, ValueRange b, IntegerAttr attr1, /*optional*/FloatAttr attr2)
-// CHECK: static void build(Builder *tblgen_builder, OperationState &tblgen_state, Type r, ArrayRef<Type> s, Value a, ValueRange b, APInt attr1, /*optional*/FloatAttr attr2)
-// CHECK: static void build(Builder *, OperationState &tblgen_state, ArrayRef<Type> resultTypes, ValueRange operands, ArrayRef<NamedAttribute> attributes)
+// CHECK: static void build(Builder *odsBuilder, OperationState &odsState, Type r, ArrayRef<Type> s, Value a, ValueRange b, IntegerAttr attr1, /*optional*/FloatAttr attr2)
+// CHECK: static void build(Builder *odsBuilder, OperationState &odsState, Type r, ArrayRef<Type> s, Value a, ValueRange b, APInt attr1, /*optional*/FloatAttr attr2)
+// CHECK: static void build(Builder *, OperationState &odsState, ArrayRef<Type> resultTypes, ValueRange operands, ArrayRef<NamedAttribute> attributes)
// CHECK: static ParseResult parse(OpAsmParser &parser, OperationState &result);
// CHECK: void print(OpAsmPrinter &p);
// CHECK: LogicalResult verify();


@@ -19,12 +19,12 @@ def OpA : NS_Op<"one_normal_operand_op", []> {
// CHECK: void OpA::build
// CHECK: Value input
-// CHECK: tblgen_state.addOperands(input);
+// CHECK: odsState.addOperands(input);
// CHECK: void OpA::build
// CHECK: ValueRange operands
// CHECK: assert(operands.size() == 1u && "mismatched number of parameters");
-// CHECK: tblgen_state.addOperands(operands);
+// CHECK: odsState.addOperands(operands);
def OpB : NS_Op<"one_variadic_operand_op", []> {
let arguments = (ins Variadic<I32>:$input);
@@ -33,7 +33,7 @@ def OpB : NS_Op<"one_variadic_operand_op", []> {
// CHECK-LABEL: OpB::build
// CHECK: ValueRange input
// CHECK-NOT: assert
-// CHECK: tblgen_state.addOperands(input);
+// CHECK: odsState.addOperands(input);
def OpD : NS_Op<"mix_variadic_and_normal_inputs_op", [SameVariadicOperandSize]> {
let arguments = (ins Variadic<AnyTensor>:$input1, AnyTensor:$input2, Variadic<AnyTensor>:$input3);
@@ -55,6 +55,6 @@ def OpD : NS_Op<"mix_variadic_and_normal_inputs_op", [SameVariadicOperandSize]>
// CHECK-NEXT: return *getODSOperands(1).begin();
// CHECK-LABEL: OpD::build
-// CHECK-NEXT: tblgen_state.addOperands(input1);
-// CHECK-NEXT: tblgen_state.addOperands(input2);
-// CHECK-NEXT: tblgen_state.addOperands(input3);
+// CHECK-NEXT: odsState.addOperands(input1);
+// CHECK-NEXT: odsState.addOperands(input2);
+// CHECK-NEXT: odsState.addOperands(input3);


@@ -15,7 +15,7 @@ def OpA : NS_Op<"one_normal_result_op", []> {
// CHECK-LABEL: void OpA::build
// CHECK: ArrayRef<Type> resultTypes, ValueRange operands
// CHECK: assert(resultTypes.size() == 1u && "mismatched number of return types");
-// CHECK-NEXT: tblgen_state.addTypes(resultTypes);
+// CHECK-NEXT: odsState.addTypes(resultTypes);
def OpB : NS_Op<"same_input_output_type_op", [SameOperandsAndResultType]> {
let arguments = (ins I32:$x);
@@ -23,20 +23,20 @@ def OpB : NS_Op<"same_input_output_type_op", [SameOperandsAndResultType]> {
}
// CHECK-LABEL: OpB definitions
-// CHECK: void OpB::build(Builder *tblgen_builder, OperationState &tblgen_state, Type y, Value x)
-// CHECK: tblgen_state.addTypes(y);
-// CHECK: void OpB::build(Builder *tblgen_builder, OperationState &tblgen_state, Value x)
-// CHECK: tblgen_state.addTypes({x.getType()});
+// CHECK: void OpB::build(Builder *odsBuilder, OperationState &odsState, Type y, Value x)
+// CHECK: odsState.addTypes(y);
+// CHECK: void OpB::build(Builder *odsBuilder, OperationState &odsState, Value x)
+// CHECK: odsState.addTypes({x.getType()});
def OpC : NS_Op<"three_normal_result_op", []> {
let results = (outs I32:$x, /*unnamed*/I32, I32:$z);
}
// CHECK-LABEL: OpC definitions
-// CHECK: void OpC::build(Builder *tblgen_builder, OperationState &tblgen_state, Type x, Type resultType1, Type z)
-// CHECK-NEXT: tblgen_state.addTypes(x)
-// CHECK-NEXT: tblgen_state.addTypes(resultType1)
-// CHECK-NEXT: tblgen_state.addTypes(z)
+// CHECK: void OpC::build(Builder *odsBuilder, OperationState &odsState, Type x, Type resultType1, Type z)
+// CHECK-NEXT: odsState.addTypes(x)
+// CHECK-NEXT: odsState.addTypes(resultType1)
+// CHECK-NEXT: odsState.addTypes(z)
def IntegerTypeAttr : TypeAttrBase<"IntegerType", "Integer type attribute">;
def OpD : NS_Op<"type_attr_as_result_type", [FirstAttrDerivedResultType]> {
@@ -45,8 +45,8 @@
}
// CHECK-LABEL: OpD definitions
// CHECK: void OpD::build(Builder *, OperationState &tblgen_state, ValueRange operands, ArrayRef<NamedAttribute> attributes)
// CHECK: tblgen_state.addTypes({attr.second.cast<TypeAttr>().getValue()});
// CHECK: void OpD::build(Builder *, OperationState &odsState, ValueRange operands, ArrayRef<NamedAttribute> attributes)
// CHECK: odsState.addTypes({attr.second.cast<TypeAttr>().getValue()});
def OpE : NS_Op<"value_attr_as_result_type", [FirstAttrDerivedResultType]> {
let arguments = (ins I32:$x, F32Attr:$attr);
@ -54,8 +54,8 @@ def OpE : NS_Op<"value_attr_as_result_type", [FirstAttrDerivedResultType]> {
}
// CHECK-LABEL: OpE definitions
// CHECK: void OpE::build(Builder *, OperationState &tblgen_state, ValueRange operands, ArrayRef<NamedAttribute> attributes)
// CHECK: tblgen_state.addTypes({attr.second.getType()});
// CHECK: void OpE::build(Builder *, OperationState &odsState, ValueRange operands, ArrayRef<NamedAttribute> attributes)
// CHECK: odsState.addTypes({attr.second.getType()});
def OpF : NS_Op<"one_variadic_result_op", []> {
let results = (outs Variadic<I32>:$x);
@ -64,7 +64,7 @@ def OpF : NS_Op<"one_variadic_result_op", []> {
// CHECK-LABEL: void OpF::build
// CHECK-SAME: ArrayRef<Type> x
// CHECK-NOT: assert
// CHECK: tblgen_state.addTypes(x);
// CHECK: odsState.addTypes(x);
def OpG : NS_Op<"one_normal_and_one_variadic_result_op", []> {
@ -73,14 +73,14 @@ def OpG : NS_Op<"one_normal_and_one_variadic_result_op", []> {
// CHECK-LABEL: OpG definitions
// CHECK: void OpG::build(Builder *tblgen_builder, OperationState &tblgen_state, Type x, ArrayRef<Type> y)
// CHECK-NEXT: tblgen_state.addTypes(x);
// CHECK-NEXT: tblgen_state.addTypes(y);
// CHECK: void OpG::build(Builder *odsBuilder, OperationState &odsState, Type x, ArrayRef<Type> y)
// CHECK-NEXT: odsState.addTypes(x);
// CHECK-NEXT: odsState.addTypes(y);
// CHECK: void OpG::build
// CHECK: ArrayRef<Type> resultTypes
// CHECK: assert(resultTypes.size() >= 1u && "mismatched number of return types");
// CHECK-NEXT: tblgen_state.addTypes(resultTypes);
// CHECK-NEXT: odsState.addTypes(resultTypes);
def OpI : NS_Op<"mix_variadic_and_normal_results_op", [SameVariadicResultSize]> {
let results = (outs Variadic<AnyTensor>:$output1, AnyTensor:$output2, Variadic<AnyTensor>:$output3);
@ -93,9 +93,9 @@ def OpI : NS_Op<"mix_variadic_and_normal_results_op", [SameVariadicResultSize]>
// CHECK-NEXT: return *getODSResults(1).begin();
// CHECK-LABEL: OpI::build
// CHECK-NEXT: tblgen_state.addTypes(output1);
// CHECK-NEXT: tblgen_state.addTypes(output2);
// CHECK-NEXT: tblgen_state.addTypes(output3);
// CHECK-NEXT: odsState.addTypes(output1);
// CHECK-NEXT: odsState.addTypes(output2);
// CHECK-NEXT: odsState.addTypes(output3);
// Test that if the only operand is variadic, we access the first value in the
// pack to set result type
@ -105,5 +105,5 @@ def OpK : NS_Op<"only_input_is_variadic_with_same_value_type_op", [SameOperandsA
let results = (outs AnyTensor:$result);
}
// CHECK-LABEL: OpK::build(Builder *tblgen_builder, OperationState &tblgen_state, ValueRange input)
// CHECK: tblgen_state.addTypes({input.front().getType()});
// CHECK-LABEL: OpK::build(Builder *odsBuilder, OperationState &odsState, ValueRange input)
// CHECK: odsState.addTypes({input.front().getType()});

View File

@ -1,12 +1,23 @@
// RUN: mlir-opt %s -test-return-type -split-input-file -verify-diagnostics | FileCheck %s --dump-input-on-failure
// CHECK-LABEL: testReturnTypeOpInterface
func @testReturnTypeOpInterface(%arg0 : tensor<10xf32>, %arg1 : tensor<20xi32>) {
%good = "test.op_with_infer_type_if"(%arg0, %arg0) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
// CHECK: test.op_with_infer_type_if
// CHECK-SAME: tensor<20xi32>
// CHECK: test.op_with_infer_type_if
// CHECK-SAME: tensor<10xf32>
// CHECK-LABEL: testCreateFunctions
// This function tests invoking the create method with different inference
// methods. The attributes of the ops inside are used to test creation.
func @testCreateFunctions(%arg0 : tensor<10xf32>, %arg1 : tensor<20xi32>) {
// CHECK: "test.no_attributes"
%good = "test.no_attributes"(%arg0, %arg0) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
// CHECK: "test.op_with_shaped_type_infer_type_if"
// CHECK-SAME: (tensor<10xf32>, tensor<10xf32>) -> tensor<10x10xi17>
// CHECK: "test.op_with_shaped_type_infer_type_if"
// CHECK-SAME: (tensor<10xf32>, tensor<20xi32>) -> tensor<10x20xi17>
// CHECK: "test.op_with_shaped_type_infer_type_if"
// CHECK-SAME: (tensor<20xi32>, tensor<10xf32>) -> tensor<20x10xi17>
// CHECK: "test.op_with_shaped_type_infer_type_if"
// CHECK-SAME: (tensor<20xi32>, tensor<20xi32>) -> tensor<20x20xi17>
// CHECK: "test.op_with_infer_type_if"
// CHECK-SAME: (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
// CHECK: "test.op_with_infer_type_if"
// CHECK-SAME: (tensor<20xi32>, tensor<20xi32>) -> tensor<20xi32>
return
}

View File

@ -58,8 +58,8 @@ ODSDialectHookRegistration::ODSDialectHookRegistration(
//===----------------------------------------------------------------------===//
static const char *const tblgenNamePrefix = "tblgen_";
static const char *const generatedArgName = "tblgen_arg";
static const char *const builderOpState = "tblgen_state";
static const char *const generatedArgName = "odsArg";
static const char *const builderOpState = "odsState";
// The logic to calculate the actual value range for a declared operand/result
// of an op with variadic operands/results. Note that this logic is not for
@ -627,8 +627,9 @@ void OpEmitter::genSeparateArgParamBuilder() {
// TODO(jpienaar): Expand to handle regions.
body << formatv(R"(
SmallVector<Type, 2> inferedReturnTypes;
if (succeeded({0}::inferReturnTypes({1}.location, {1}.operands,
{1}.attributes, /*regions=*/{{}, inferedReturnTypes)))
if (succeeded({0}::inferReturnTypes(odsBuilder->getContext(),
{1}.location, {1}.operands, {1}.attributes,
/*regions=*/{{}, inferedReturnTypes)))
{1}.addTypes(inferedReturnTypes);
else
llvm::report_fatal_error("Failed to infer result type(s).");)",
@ -702,7 +703,7 @@ void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
void OpEmitter::genInferedTypeCollectiveParamBuilder() {
// TODO(jpienaar): Expand to support regions.
const char *params =
"Builder *builder, OperationState &{0}, "
"Builder *odsBuilder, OperationState &{0}, "
"ValueRange operands, ArrayRef<NamedAttribute> attributes";
auto &m =
opClass.newMethod("void", "build", formatv(params, builderOpState).str(),
@ -710,9 +711,10 @@ void OpEmitter::genInferedTypeCollectiveParamBuilder() {
auto &body = m.body();
body << formatv(R"(
SmallVector<Type, 2> inferedReturnTypes;
if (succeeded({0}::inferReturnTypes({1}.location, operands, attributes,
if (succeeded({0}::inferReturnTypes(odsBuilder->getContext(),
{1}.location, operands, attributes,
/*regions=*/{{}, inferedReturnTypes)))
build(builder, tblgen_state, inferedReturnTypes, operands, attributes);
build(odsBuilder, odsState, inferedReturnTypes, operands, attributes);
else
llvm::report_fatal_error("Failed to infer result type(s).");)",
opClass.getClassName(), builderOpState);
@ -878,7 +880,7 @@ void OpEmitter::buildParamList(std::string &paramList,
auto numResults = op.getNumResults();
resultTypeNames.reserve(numResults);
paramList = "Builder *tblgen_builder, OperationState &";
paramList = "Builder *odsBuilder, OperationState &";
paramList.append(builderOpState);
switch (typeParamKind) {
@ -1000,7 +1002,7 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
// If this is a raw value, then we need to wrap it in an Attribute
// instance.
FmtContext fctx;
fctx.withBuilder("(*tblgen_builder)");
fctx.withBuilder("(*odsBuilder)");
std::string value =
tgfmt(attr.getConstBuilderTemplate(), &fctx, namedAttr.name);
body << formatv(" {0}.addAttribute(\"{1}\", {2});\n", builderOpState,