Move MemRefCastOp and TensorCastOp to the Op Definition Generation framework.

--

PiperOrigin-RevId: 247981385
This commit is contained in:
River Riddle 2019-05-13 11:56:21 -07:00 committed by Mehdi Amini
parent 17cc065da0
commit 5d7546470d
4 changed files with 105 additions and 126 deletions

View File

@ -886,34 +886,6 @@ ParseResult parseCastOp(OpAsmParser *parser, OperationState *result);
void printCastOp(Operation *op, OpAsmPrinter *p);
Value *foldCastOp(Operation *op);
} // namespace impl
/// This template is used for operations that are cast operations, that have a
/// single operand and single results, whose source and destination types are
/// different.
///
/// From this structure, subclasses get a standard builder, parser and printer.
///
template <typename ConcreteType, template <typename T> class... Traits>
class CastOp : public Op<ConcreteType, OpTrait::OneOperand, OpTrait::OneResult,
OpTrait::HasNoSideEffect, Traits...> {
public:
using Op<ConcreteType, OpTrait::OneOperand, OpTrait::OneResult,
OpTrait::HasNoSideEffect, Traits...>::Op;
static void build(Builder *builder, OperationState *result, Value *source,
Type destType) {
impl::buildCastOp(builder, result, source, destType);
}
static ParseResult parse(OpAsmParser *parser, OperationState *result) {
return impl::parseCastOp(parser, result);
}
void print(OpAsmPrinter *p) {
return impl::printCastOp(this->getOperation(), p);
}
Value *fold() { return impl::foldCastOp(this->getOperation()); }
};
} // end namespace mlir
#endif

View File

@ -32,12 +32,6 @@ namespace mlir {
class AffineMap;
class Builder;
namespace detail {
/// A custom binary operation printer that omits the "std." prefix from the
/// operation names.
void printStandardBinaryOp(Operation *op, OpAsmPrinter *p);
} // namespace detail
class StandardOpsDialect : public Dialect {
public:
StandardOpsDialect(MLIRContext *context);
@ -579,38 +573,6 @@ public:
MLIRContext *context);
};
/// The "memref_cast" operation converts a memref from one type to an equivalent
/// type with a compatible shape. The source and destination types are
/// when both are memref types with the same element type, affine mappings,
/// address space, and rank but where the individual dimensions may add or
/// remove constant dimensions from the memref type.
///
/// If the cast converts any dimensions from an unknown to a known size, then it
/// acts as an assertion that fails at runtime of the dynamic dimensions
/// disagree with resultant destination size.
///
/// Assert that the input dynamic shape matches the destination static shape.
/// %2 = memref_cast %1 : memref<?x?xf32> to memref<4x4xf32>
/// Erase static shape information, replacing it with dynamic information.
/// %3 = memref_cast %1 : memref<4xf32> to memref<?xf32>
///
class MemRefCastOp : public CastOp<MemRefCastOp> {
public:
using CastOp::CastOp;
static StringRef getOperationName() { return "std.memref_cast"; }
/// Return true if `a` and `b` are valid operand and result pairs for
/// the operation.
static bool areCastCompatible(Type a, Type b);
/// The result of a memref_cast is always a memref.
MemRefType getType() { return getResult()->getType().cast<MemRefType>(); }
void print(OpAsmPrinter *p);
LogicalResult verify();
};
/// The "select" operation chooses one value based on a binary condition
/// supplied as its first operand. If the value of the first operand is 1, the
/// second operand is chosen, otherwise the third operand is chosen. The second
@ -683,33 +645,6 @@ public:
MLIRContext *context);
};
/// The "tensor_cast" operation converts a tensor from one type to an equivalent
/// type without changing any data elements. The source and destination types
/// must both be tensor types with the same element type. If both are ranked
/// then the rank should be the same and static dimensions should match. The
/// operation is invalid if converting to a mismatching constant dimension.
///
/// Convert from unknown rank to rank 2 with unknown dimension sizes.
/// %2 = tensor_cast %1 : tensor<??f32> to tensor<?x?xf32>
///
class TensorCastOp : public CastOp<TensorCastOp> {
public:
using CastOp::CastOp;
static StringRef getOperationName() { return "std.tensor_cast"; }
/// Return true if `a` and `b` are valid operand and result pairs for
/// the operation.
static bool areCastCompatible(Type a, Type b);
/// The result of a tensor_cast is always a tensor.
TensorType getType() { return getResult()->getType().cast<TensorType>(); }
void print(OpAsmPrinter *p);
LogicalResult verify();
};
/// Prints dimension and symbol list.
void printDimAndSymbolList(Operation::operand_iterator begin,
Operation::operand_iterator end, unsigned numDims,

View File

@ -32,6 +32,29 @@ def Standard_Dialect : Dialect {
let name = "std";
}
// Base class for standard cast operations. Requires single operand and result,
// but does not constrain them to specific types.
class CastOp<string mnemonic, list<OpTrait> traits = []> :
Op<Standard_Dialect, mnemonic, !listconcat(traits, [NoSideEffect])> {
let results = (outs AnyType);
let builders = [OpBuilder<
"Builder *builder, OperationState *result, Value *source, Type destType", [{
impl::buildCastOp(builder, result, source, destType);
}]>];
let parser = [{
return impl::parseCastOp(parser, result);
}];
let printer = [{
return printStandardCastOp(this->getOperation(), p);
}];
let verifier = [{ return ::verifyCastOp(*this); }];
let hasFolder = 1;
}
// Base class for standard arithmetic operations. Requires operands and
// results to be of the same type, but does not constrain them to specific
// types. Individual classes will have `lhs` and `rhs` accessor to operands.
@ -46,7 +69,7 @@ class ArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
}];
let printer = [{
return detail::printStandardBinaryOp(this->getOperation(), p);
return printStandardBinaryOp(this->getOperation(), p);
}];
}
@ -383,6 +406,38 @@ def ExtractElementOp : Op<Standard_Dialect, "extract_element", [NoSideEffect]> {
let hasConstantFolder = 0b1;
}
def MemRefCastOp : CastOp<"memref_cast"> {
let summary = "memref cast operation";
let description = [{
The "memref_cast" operation converts a memref from one type to an equivalent
type with a compatible shape. The source and destination types are
when both are memref types with the same element type, affine mappings,
address space, and rank but where the individual dimensions may add or
remove constant dimensions from the memref type.
If the cast converts any dimensions from an unknown to a known size, then it
acts as an assertion that fails at runtime of the dynamic dimensions
disagree with resultant destination size.
Assert that the input dynamic shape matches the destination static shape.
%2 = memref_cast %1 : memref<?x?xf32> to memref<4x4xf32>
Erase static shape information, replacing it with dynamic information.
%3 = memref_cast %1 : memref<4xf32> to memref<?xf32>
}];
let arguments = (ins MemRef<AnyType>);
let results = (outs MemRef<AnyType>);
let extraClassDeclaration = [{
/// Return true if `a` and `b` are valid operand and result pairs for
/// the operation.
static bool areCastCompatible(Type a, Type b);
/// The result of a memref_cast is always a memref.
MemRefType getType() { return getResult()->getType().cast<MemRefType>(); }
}];
}
def MulFOp : FloatArithmeticOp<"mulf"> {
let summary = "foating point multiplication operation";
let hasConstantFolder = 0b1;
@ -453,6 +508,32 @@ def SubIOp : IntArithmeticOp<"subi"> {
let hasCanonicalizer = 0b1;
}
def TensorCastOp : CastOp<"tensor_cast"> {
let summary = "tensor cast operation";
let description = [{
The "tensor_cast" operation converts a tensor from one type to an equivalent
type without changing any data elements. The source and destination types
must both be tensor types with the same element type. If both are ranked
then the rank should be the same and static dimensions should match. The
operation is invalid if converting to a mismatching constant dimension.
Convert from unknown rank to rank 2 with unknown dimension sizes.
%2 = tensor_cast %1 : tensor<??f32> to tensor<?x?xf32>
}];
let arguments = (ins Tensor);
let results = (outs Tensor);
let extraClassDeclaration = [{
/// Return true if `a` and `b` are valid operand and result pairs for
/// the operation.
static bool areCastCompatible(Type a, Type b);
/// The result of a tensor_cast is always a tensor.
TensorType getType() { return getResult()->getType().cast<TensorType>(); }
}];
}
def XOrOp : IntArithmeticOp<"xor", [Commutative]> {
let summary = "integer binary xor";
let hasConstantFolder = 0b1;

View File

@ -38,7 +38,7 @@ using namespace mlir;
/// A custom binary operation printer that omits the "std." prefix from the
/// operation names.
void detail::printStandardBinaryOp(Operation *op, OpAsmPrinter *p) {
static void printStandardBinaryOp(Operation *op, OpAsmPrinter *p) {
assert(op->getNumOperands() == 2 && "binary op should have two operands");
assert(op->getNumResults() == 1 && "binary op should have one result");
@ -59,10 +59,29 @@ void detail::printStandardBinaryOp(Operation *op, OpAsmPrinter *p) {
*p << " : " << op->getResult(0)->getType();
}
/// A custom cast operation printer that omits the "std." prefix from the
/// operation names.
static void printStandardCastOp(Operation *op, OpAsmPrinter *p) {
*p << op->getName().getStringRef().drop_front(strlen("std.")) << ' '
<< *op->getOperand(0) << " : " << op->getOperand(0)->getType() << " to "
<< op->getResult(0)->getType();
}
/// A custom cast operation verifier.
template <typename T> static LogicalResult verifyCastOp(T op) {
auto opType = op.getOperand()->getType();
auto resType = op.getType();
if (!T::areCastCompatible(opType, resType))
return op.emitError("operand type ") << opType << " and result type "
<< resType << " are cast incompatible";
return success();
}
StandardOpsDialect::StandardOpsDialect(MLIRContext *context)
: Dialect(/*name=*/"std", context) {
addOperations<CmpFOp, CmpIOp, CondBranchOp, DmaStartOp, DmaWaitOp, LoadOp,
MemRefCastOp, SelectOp, StoreOp, TensorCastOp,
SelectOp, StoreOp,
#define GET_OP_LIST
#include "mlir/StandardOps/Ops.cpp.inc"
>();
@ -1783,21 +1802,7 @@ bool MemRefCastOp::areCastCompatible(Type a, Type b) {
return true;
}
void MemRefCastOp::print(OpAsmPrinter *p) {
*p << "memref_cast " << *getOperand() << " : " << getOperand()->getType()
<< " to " << getType();
}
LogicalResult MemRefCastOp::verify() {
auto opType = getOperand()->getType();
auto resType = getType();
if (!areCastCompatible(opType, resType))
return emitError(llvm::formatv(
"operand type {0} and result type {1} are cast incompatible", opType,
resType));
return success();
}
Value *MemRefCastOp::fold() { return impl::foldCastOp(*this); }
//===----------------------------------------------------------------------===//
// MulFOp
@ -2235,21 +2240,7 @@ bool TensorCastOp::areCastCompatible(Type a, Type b) {
return true;
}
void TensorCastOp::print(OpAsmPrinter *p) {
*p << "tensor_cast " << *getOperand() << " : " << getOperand()->getType()
<< " to " << getType();
}
LogicalResult TensorCastOp::verify() {
auto opType = getOperand()->getType();
auto resType = getType();
if (!areCastCompatible(opType, resType))
return emitError(llvm::formatv(
"operand type {0} and result type {1} are cast incompatible", opType,
resType));
return success();
}
Value *TensorCastOp::fold() { return impl::foldCastOp(*this); }
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions