forked from OSchip/llvm-project
Move MemRefCastOp and TensorCastOp to the Op Definition Generation framework.
-- PiperOrigin-RevId: 247981385
This commit is contained in:
parent
17cc065da0
commit
5d7546470d
|
@ -886,34 +886,6 @@ ParseResult parseCastOp(OpAsmParser *parser, OperationState *result);
|
|||
void printCastOp(Operation *op, OpAsmPrinter *p);
|
||||
Value *foldCastOp(Operation *op);
|
||||
} // namespace impl
|
||||
|
||||
/// This template is used for operations that are cast operations, that have a
|
||||
/// single operand and single results, whose source and destination types are
|
||||
/// different.
|
||||
///
|
||||
/// From this structure, subclasses get a standard builder, parser and printer.
|
||||
///
|
||||
template <typename ConcreteType, template <typename T> class... Traits>
|
||||
class CastOp : public Op<ConcreteType, OpTrait::OneOperand, OpTrait::OneResult,
|
||||
OpTrait::HasNoSideEffect, Traits...> {
|
||||
public:
|
||||
using Op<ConcreteType, OpTrait::OneOperand, OpTrait::OneResult,
|
||||
OpTrait::HasNoSideEffect, Traits...>::Op;
|
||||
|
||||
static void build(Builder *builder, OperationState *result, Value *source,
|
||||
Type destType) {
|
||||
impl::buildCastOp(builder, result, source, destType);
|
||||
}
|
||||
static ParseResult parse(OpAsmParser *parser, OperationState *result) {
|
||||
return impl::parseCastOp(parser, result);
|
||||
}
|
||||
void print(OpAsmPrinter *p) {
|
||||
return impl::printCastOp(this->getOperation(), p);
|
||||
}
|
||||
|
||||
Value *fold() { return impl::foldCastOp(this->getOperation()); }
|
||||
};
|
||||
|
||||
} // end namespace mlir
|
||||
|
||||
#endif
|
||||
|
|
|
@ -32,12 +32,6 @@ namespace mlir {
|
|||
class AffineMap;
|
||||
class Builder;
|
||||
|
||||
namespace detail {
|
||||
/// A custom binary operation printer that omits the "std." prefix from the
|
||||
/// operation names.
|
||||
void printStandardBinaryOp(Operation *op, OpAsmPrinter *p);
|
||||
} // namespace detail
|
||||
|
||||
class StandardOpsDialect : public Dialect {
|
||||
public:
|
||||
StandardOpsDialect(MLIRContext *context);
|
||||
|
@ -579,38 +573,6 @@ public:
|
|||
MLIRContext *context);
|
||||
};
|
||||
|
||||
/// The "memref_cast" operation converts a memref from one type to an equivalent
|
||||
/// type with a compatible shape. The source and destination types are
|
||||
/// when both are memref types with the same element type, affine mappings,
|
||||
/// address space, and rank but where the individual dimensions may add or
|
||||
/// remove constant dimensions from the memref type.
|
||||
///
|
||||
/// If the cast converts any dimensions from an unknown to a known size, then it
|
||||
/// acts as an assertion that fails at runtime of the dynamic dimensions
|
||||
/// disagree with resultant destination size.
|
||||
///
|
||||
/// Assert that the input dynamic shape matches the destination static shape.
|
||||
/// %2 = memref_cast %1 : memref<?x?xf32> to memref<4x4xf32>
|
||||
/// Erase static shape information, replacing it with dynamic information.
|
||||
/// %3 = memref_cast %1 : memref<4xf32> to memref<?xf32>
|
||||
///
|
||||
class MemRefCastOp : public CastOp<MemRefCastOp> {
|
||||
public:
|
||||
using CastOp::CastOp;
|
||||
static StringRef getOperationName() { return "std.memref_cast"; }
|
||||
|
||||
/// Return true if `a` and `b` are valid operand and result pairs for
|
||||
/// the operation.
|
||||
static bool areCastCompatible(Type a, Type b);
|
||||
|
||||
/// The result of a memref_cast is always a memref.
|
||||
MemRefType getType() { return getResult()->getType().cast<MemRefType>(); }
|
||||
|
||||
void print(OpAsmPrinter *p);
|
||||
|
||||
LogicalResult verify();
|
||||
};
|
||||
|
||||
/// The "select" operation chooses one value based on a binary condition
|
||||
/// supplied as its first operand. If the value of the first operand is 1, the
|
||||
/// second operand is chosen, otherwise the third operand is chosen. The second
|
||||
|
@ -683,33 +645,6 @@ public:
|
|||
MLIRContext *context);
|
||||
};
|
||||
|
||||
/// The "tensor_cast" operation converts a tensor from one type to an equivalent
|
||||
/// type without changing any data elements. The source and destination types
|
||||
/// must both be tensor types with the same element type. If both are ranked
|
||||
/// then the rank should be the same and static dimensions should match. The
|
||||
/// operation is invalid if converting to a mismatching constant dimension.
|
||||
///
|
||||
/// Convert from unknown rank to rank 2 with unknown dimension sizes.
|
||||
/// %2 = tensor_cast %1 : tensor<??f32> to tensor<?x?xf32>
|
||||
///
|
||||
class TensorCastOp : public CastOp<TensorCastOp> {
|
||||
public:
|
||||
using CastOp::CastOp;
|
||||
|
||||
static StringRef getOperationName() { return "std.tensor_cast"; }
|
||||
|
||||
/// Return true if `a` and `b` are valid operand and result pairs for
|
||||
/// the operation.
|
||||
static bool areCastCompatible(Type a, Type b);
|
||||
|
||||
/// The result of a tensor_cast is always a tensor.
|
||||
TensorType getType() { return getResult()->getType().cast<TensorType>(); }
|
||||
|
||||
void print(OpAsmPrinter *p);
|
||||
|
||||
LogicalResult verify();
|
||||
};
|
||||
|
||||
/// Prints dimension and symbol list.
|
||||
void printDimAndSymbolList(Operation::operand_iterator begin,
|
||||
Operation::operand_iterator end, unsigned numDims,
|
||||
|
|
|
@ -32,6 +32,29 @@ def Standard_Dialect : Dialect {
|
|||
let name = "std";
|
||||
}
|
||||
|
||||
// Base class for standard cast operations. Requires single operand and result,
|
||||
// but does not constrain them to specific types.
|
||||
class CastOp<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<Standard_Dialect, mnemonic, !listconcat(traits, [NoSideEffect])> {
|
||||
|
||||
let results = (outs AnyType);
|
||||
|
||||
let builders = [OpBuilder<
|
||||
"Builder *builder, OperationState *result, Value *source, Type destType", [{
|
||||
impl::buildCastOp(builder, result, source, destType);
|
||||
}]>];
|
||||
|
||||
let parser = [{
|
||||
return impl::parseCastOp(parser, result);
|
||||
}];
|
||||
let printer = [{
|
||||
return printStandardCastOp(this->getOperation(), p);
|
||||
}];
|
||||
let verifier = [{ return ::verifyCastOp(*this); }];
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
// Base class for standard arithmetic operations. Requires operands and
|
||||
// results to be of the same type, but does not constrain them to specific
|
||||
// types. Individual classes will have `lhs` and `rhs` accessor to operands.
|
||||
|
@ -46,7 +69,7 @@ class ArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
|
|||
}];
|
||||
|
||||
let printer = [{
|
||||
return detail::printStandardBinaryOp(this->getOperation(), p);
|
||||
return printStandardBinaryOp(this->getOperation(), p);
|
||||
}];
|
||||
}
|
||||
|
||||
|
@ -383,6 +406,38 @@ def ExtractElementOp : Op<Standard_Dialect, "extract_element", [NoSideEffect]> {
|
|||
let hasConstantFolder = 0b1;
|
||||
}
|
||||
|
||||
def MemRefCastOp : CastOp<"memref_cast"> {
|
||||
let summary = "memref cast operation";
|
||||
let description = [{
|
||||
The "memref_cast" operation converts a memref from one type to an equivalent
|
||||
type with a compatible shape. The source and destination types are
|
||||
when both are memref types with the same element type, affine mappings,
|
||||
address space, and rank but where the individual dimensions may add or
|
||||
remove constant dimensions from the memref type.
|
||||
|
||||
If the cast converts any dimensions from an unknown to a known size, then it
|
||||
acts as an assertion that fails at runtime of the dynamic dimensions
|
||||
disagree with resultant destination size.
|
||||
|
||||
Assert that the input dynamic shape matches the destination static shape.
|
||||
%2 = memref_cast %1 : memref<?x?xf32> to memref<4x4xf32>
|
||||
Erase static shape information, replacing it with dynamic information.
|
||||
%3 = memref_cast %1 : memref<4xf32> to memref<?xf32>
|
||||
}];
|
||||
|
||||
let arguments = (ins MemRef<AnyType>);
|
||||
let results = (outs MemRef<AnyType>);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
/// Return true if `a` and `b` are valid operand and result pairs for
|
||||
/// the operation.
|
||||
static bool areCastCompatible(Type a, Type b);
|
||||
|
||||
/// The result of a memref_cast is always a memref.
|
||||
MemRefType getType() { return getResult()->getType().cast<MemRefType>(); }
|
||||
}];
|
||||
}
|
||||
|
||||
def MulFOp : FloatArithmeticOp<"mulf"> {
|
||||
let summary = "foating point multiplication operation";
|
||||
let hasConstantFolder = 0b1;
|
||||
|
@ -453,6 +508,32 @@ def SubIOp : IntArithmeticOp<"subi"> {
|
|||
let hasCanonicalizer = 0b1;
|
||||
}
|
||||
|
||||
def TensorCastOp : CastOp<"tensor_cast"> {
|
||||
let summary = "tensor cast operation";
|
||||
let description = [{
|
||||
The "tensor_cast" operation converts a tensor from one type to an equivalent
|
||||
type without changing any data elements. The source and destination types
|
||||
must both be tensor types with the same element type. If both are ranked
|
||||
then the rank should be the same and static dimensions should match. The
|
||||
operation is invalid if converting to a mismatching constant dimension.
|
||||
|
||||
Convert from unknown rank to rank 2 with unknown dimension sizes.
|
||||
%2 = tensor_cast %1 : tensor<??f32> to tensor<?x?xf32>
|
||||
}];
|
||||
|
||||
let arguments = (ins Tensor);
|
||||
let results = (outs Tensor);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
/// Return true if `a` and `b` are valid operand and result pairs for
|
||||
/// the operation.
|
||||
static bool areCastCompatible(Type a, Type b);
|
||||
|
||||
/// The result of a tensor_cast is always a tensor.
|
||||
TensorType getType() { return getResult()->getType().cast<TensorType>(); }
|
||||
}];
|
||||
}
|
||||
|
||||
def XOrOp : IntArithmeticOp<"xor", [Commutative]> {
|
||||
let summary = "integer binary xor";
|
||||
let hasConstantFolder = 0b1;
|
||||
|
|
|
@ -38,7 +38,7 @@ using namespace mlir;
|
|||
|
||||
/// A custom binary operation printer that omits the "std." prefix from the
|
||||
/// operation names.
|
||||
void detail::printStandardBinaryOp(Operation *op, OpAsmPrinter *p) {
|
||||
static void printStandardBinaryOp(Operation *op, OpAsmPrinter *p) {
|
||||
assert(op->getNumOperands() == 2 && "binary op should have two operands");
|
||||
assert(op->getNumResults() == 1 && "binary op should have one result");
|
||||
|
||||
|
@ -59,10 +59,29 @@ void detail::printStandardBinaryOp(Operation *op, OpAsmPrinter *p) {
|
|||
*p << " : " << op->getResult(0)->getType();
|
||||
}
|
||||
|
||||
/// A custom cast operation printer that omits the "std." prefix from the
|
||||
/// operation names.
|
||||
static void printStandardCastOp(Operation *op, OpAsmPrinter *p) {
|
||||
*p << op->getName().getStringRef().drop_front(strlen("std.")) << ' '
|
||||
<< *op->getOperand(0) << " : " << op->getOperand(0)->getType() << " to "
|
||||
<< op->getResult(0)->getType();
|
||||
}
|
||||
|
||||
/// A custom cast operation verifier.
|
||||
template <typename T> static LogicalResult verifyCastOp(T op) {
|
||||
auto opType = op.getOperand()->getType();
|
||||
auto resType = op.getType();
|
||||
if (!T::areCastCompatible(opType, resType))
|
||||
return op.emitError("operand type ") << opType << " and result type "
|
||||
<< resType << " are cast incompatible";
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
StandardOpsDialect::StandardOpsDialect(MLIRContext *context)
|
||||
: Dialect(/*name=*/"std", context) {
|
||||
addOperations<CmpFOp, CmpIOp, CondBranchOp, DmaStartOp, DmaWaitOp, LoadOp,
|
||||
MemRefCastOp, SelectOp, StoreOp, TensorCastOp,
|
||||
SelectOp, StoreOp,
|
||||
#define GET_OP_LIST
|
||||
#include "mlir/StandardOps/Ops.cpp.inc"
|
||||
>();
|
||||
|
@ -1783,21 +1802,7 @@ bool MemRefCastOp::areCastCompatible(Type a, Type b) {
|
|||
return true;
|
||||
}
|
||||
|
||||
void MemRefCastOp::print(OpAsmPrinter *p) {
|
||||
*p << "memref_cast " << *getOperand() << " : " << getOperand()->getType()
|
||||
<< " to " << getType();
|
||||
}
|
||||
|
||||
LogicalResult MemRefCastOp::verify() {
|
||||
auto opType = getOperand()->getType();
|
||||
auto resType = getType();
|
||||
if (!areCastCompatible(opType, resType))
|
||||
return emitError(llvm::formatv(
|
||||
"operand type {0} and result type {1} are cast incompatible", opType,
|
||||
resType));
|
||||
|
||||
return success();
|
||||
}
|
||||
Value *MemRefCastOp::fold() { return impl::foldCastOp(*this); }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MulFOp
|
||||
|
@ -2235,21 +2240,7 @@ bool TensorCastOp::areCastCompatible(Type a, Type b) {
|
|||
return true;
|
||||
}
|
||||
|
||||
void TensorCastOp::print(OpAsmPrinter *p) {
|
||||
*p << "tensor_cast " << *getOperand() << " : " << getOperand()->getType()
|
||||
<< " to " << getType();
|
||||
}
|
||||
|
||||
LogicalResult TensorCastOp::verify() {
|
||||
auto opType = getOperand()->getType();
|
||||
auto resType = getType();
|
||||
if (!areCastCompatible(opType, resType))
|
||||
return emitError(llvm::formatv(
|
||||
"operand type {0} and result type {1} are cast incompatible", opType,
|
||||
resType));
|
||||
|
||||
return success();
|
||||
}
|
||||
Value *TensorCastOp::fold() { return impl::foldCastOp(*this); }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TableGen'd op method definitions
|
||||
|
|
Loading…
Reference in New Issue