forked from OSchip/llvm-project
Add unary ops and ExpOp to Standard Dialect.
PiperOrigin-RevId: 274152154
This commit is contained in:
parent
304e44a6b0
commit
00d2a37e32
|
@ -452,15 +452,38 @@ Example:
|
|||
tensor_store %8, %10 : memref<4x?xf32, #layout, memspace0>
|
||||
```
|
||||
|
||||
## Unary Operations
|
||||
|
||||
### 'exp' operation
|
||||
|
||||
Syntax:
|
||||
|
||||
``` {.ebnf}
|
||||
operation ::= ssa-id `=` `exp` ssa-use `:` type
|
||||
```
|
||||
|
||||
Examples:
|
||||
|
||||
```mlir {.mlir}
|
||||
// Scalar natural exponential.
|
||||
%a = exp %b : f64
|
||||
|
||||
// SIMD vector element-wise natural exponential.
|
||||
%f = exp %g : vector<4xf32>
|
||||
|
||||
// Tensor element-wise natural exponential.
|
||||
%x = exp %y : tensor<4x?xf8>
|
||||
```
|
||||
|
||||
The `exp` operation takes one operand and returns one result of the same type.
|
||||
This type may be a float scalar type, a vector whose element type is float, or a
|
||||
tensor of floats. It has no standard attributes.
|
||||
|
||||
## Arithmetic Operations
|
||||
|
||||
Basic arithmetic in MLIR is specified by standard operations described in this
|
||||
section.
|
||||
|
||||
TODO: "sub" etc. Let's not get excited about filling this out yet, we can define
|
||||
these on demand. We should be highly informed by and learn from the operations
|
||||
supported by HLO and LLVM.
|
||||
|
||||
### 'addi' operation
|
||||
|
||||
Syntax:
|
||||
|
@ -478,7 +501,7 @@ Examples:
|
|||
// SIMD vector element-wise addition, e.g. for Intel SSE.
|
||||
%f = addi %g, %h : vector<4xi32>
|
||||
|
||||
// Tensor element-wise addition, analogous to HLO's add operation.
|
||||
// Tensor element-wise addition.
|
||||
%x = addi %y, %z : tensor<4x?xi8>
|
||||
```
|
||||
|
||||
|
@ -504,7 +527,7 @@ Examples:
|
|||
// SIMD vector addition, e.g. for Intel SSE.
|
||||
%f = addf %g, %h : vector<4xf32>
|
||||
|
||||
// Tensor addition, analogous to HLO's add operation.
|
||||
// Tensor addition.
|
||||
%x = addf %y, %z : tensor<4x?xbf16>
|
||||
```
|
||||
|
||||
|
@ -757,7 +780,7 @@ Examples:
|
|||
// SIMD pointwise vector multiplication, e.g. for Intel SSE.
|
||||
%f = mulf %g, %h : vector<4xf32>
|
||||
|
||||
// Tensor pointwise multiplication, analogous to HLO's pointwise multiply operation.
|
||||
// Tensor pointwise multiplication.
|
||||
%x = mulf %y, %z : tensor<4x?xbf16>
|
||||
```
|
||||
|
||||
|
|
|
@ -72,6 +72,27 @@ class CastOp<string mnemonic, list<OpTrait> traits = []> :
|
|||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
// Base class for unary ops. Requires single operand and result. Individual
|
||||
// classes will have `operand` accessor.
|
||||
class UnaryOp<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<Std_Dialect, mnemonic, !listconcat(traits, [NoSideEffect])> {
|
||||
let results = (outs AnyType);
|
||||
let printer = [{
|
||||
return printStandardUnaryOp(this->getOperation(), p);
|
||||
}];
|
||||
}
|
||||
|
||||
class UnaryOpSameOperandAndResultType<string mnemonic, list<OpTrait> traits = []> :
|
||||
UnaryOp<mnemonic, !listconcat(traits, [SameOperandsAndResultType])> {
|
||||
let parser = [{
|
||||
return impl::parseOneResultSameOperandTypeOp(parser, result);
|
||||
}];
|
||||
}
|
||||
|
||||
class FloatUnaryOp<string mnemonic, list<OpTrait> traits = []> :
|
||||
UnaryOpSameOperandAndResultType<mnemonic, traits>,
|
||||
Arguments<(ins FloatLike:$operand)>;
|
||||
|
||||
// Base class for standard arithmetic operations. Requires operands and
|
||||
// results to be of the same type, but does not constrain them to specific
|
||||
// types. Individual classes will have `lhs` and `rhs` accessor to operands.
|
||||
|
@ -597,6 +618,10 @@ def DivIUOp : IntArithmeticOp<"diviu"> {
|
|||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def ExpOp : FloatUnaryOp<"exp"> {
|
||||
let summary = "base-e exponential of the specified value";
|
||||
}
|
||||
|
||||
def ExtractElementOp : Std_Op<"extract_element", [NoSideEffect]> {
|
||||
let summary = "element extract operation";
|
||||
let description = [{
|
||||
|
|
|
@ -1141,13 +1141,17 @@ private:
|
|||
Concept *impl;
|
||||
};
|
||||
|
||||
// These functions are out-of-line implementations of the methods in BinaryOp,
|
||||
// which avoids them being template instantiated/duplicated.
|
||||
// These functions are out-of-line implementations of the methods in UnaryOp and
|
||||
// BinaryOp, which avoids them being template instantiated/duplicated.
|
||||
namespace impl {
|
||||
ParseResult parseOneResultOneOperandTypeOp(OpAsmParser &parser,
|
||||
OperationState &result);
|
||||
|
||||
void buildBinaryOp(Builder *builder, OperationState &result, Value *lhs,
|
||||
Value *rhs);
|
||||
ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser,
|
||||
OperationState &result);
|
||||
|
||||
// Prints the given binary `op` in custom assembly form if both the two operands
|
||||
// and the result have the same time. Otherwise, prints the generic assembly
|
||||
// form.
|
||||
|
|
|
@ -443,28 +443,43 @@ static SmallVector<int64_t, 4> getCoordinates(ArrayRef<int64_t> basis,
|
|||
return res;
|
||||
}
|
||||
|
||||
template <typename SourceOp, unsigned OpCount> struct OpCountValidator {
|
||||
static_assert(
|
||||
std::is_base_of<
|
||||
typename OpTrait::NOperands<OpCount>::template Impl<SourceOp>,
|
||||
SourceOp>::value,
|
||||
"wrong operand count");
|
||||
};
|
||||
|
||||
template <typename SourceOp> struct OpCountValidator<SourceOp, 1> {
|
||||
static_assert(std::is_base_of<OpTrait::OneOperand<SourceOp>, SourceOp>::value,
|
||||
"expected a single operand");
|
||||
};
|
||||
|
||||
template <typename SourceOp, unsigned OpCount> void ValidateOpCount() {
|
||||
OpCountValidator<SourceOp, OpCount>();
|
||||
}
|
||||
|
||||
// Basic lowering implementation for rewriting from Standard Ops to LLVM Dialect
|
||||
// Ops for binary ops with one result. This supports higher-dimensional vector
|
||||
// Ops for N-ary ops with one result. This supports higher-dimensional vector
|
||||
// types.
|
||||
template <typename SourceOp, typename TargetOp>
|
||||
struct BinaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
|
||||
template <typename SourceOp, typename TargetOp, unsigned OpCount>
|
||||
struct NaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
|
||||
using LLVMLegalizationPattern<SourceOp>::LLVMLegalizationPattern;
|
||||
using Super = BinaryOpLLVMOpLowering<SourceOp, TargetOp>;
|
||||
using Super = NaryOpLLVMOpLowering<SourceOp, TargetOp, OpCount>;
|
||||
|
||||
// Convert the type of the result to an LLVM type, pass operands as is,
|
||||
// preserve attributes.
|
||||
PatternMatchResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
static_assert(
|
||||
std::is_base_of<OpTrait::NOperands<2>::Impl<SourceOp>, SourceOp>::value,
|
||||
"expected binary op");
|
||||
ValidateOpCount<SourceOp, OpCount>();
|
||||
static_assert(
|
||||
std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
|
||||
"expected single result op");
|
||||
static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
|
||||
SourceOp>::value,
|
||||
"expected single result op");
|
||||
"expected same operands and result type");
|
||||
|
||||
// Cannot convert ops if their operands are not of LLVM type.
|
||||
for (Value *operand : operands) {
|
||||
|
@ -489,7 +504,7 @@ struct BinaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
|
|||
arraySizes.push_back(llvmTy.getArrayNumElements());
|
||||
llvmTy = llvmTy.getArrayElementType();
|
||||
}
|
||||
assert(llvmTy.isVectorTy() && "unexpected binary op over non-vector type");
|
||||
assert(llvmTy.isVectorTy() && "unexpected n-ary op over non-vector type");
|
||||
auto llvmVectorTy = llvmTy;
|
||||
|
||||
// Iteratively extract a position coordinates with basis `arraySize` from a
|
||||
|
@ -511,13 +526,13 @@ struct BinaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
|
|||
|
||||
// For this unrolled `position` corresponding to the `linearIndex`^th
|
||||
// element, extract operand vectors
|
||||
Value *extractedLHS = rewriter.create<LLVM::ExtractValueOp>(
|
||||
loc, llvmVectorTy, operands[0], position);
|
||||
Value *extractedRHS = rewriter.create<LLVM::ExtractValueOp>(
|
||||
loc, llvmVectorTy, operands[1], position);
|
||||
SmallVector<Value *, OpCount> extractedOperands;
|
||||
for (unsigned i = 0; i < OpCount; ++i) {
|
||||
extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
|
||||
loc, llvmVectorTy, operands[i], position));
|
||||
}
|
||||
Value *newVal = rewriter.create<TargetOp>(
|
||||
loc, llvmVectorTy, ArrayRef<Value *>{extractedLHS, extractedRHS},
|
||||
op->getAttrs());
|
||||
loc, llvmVectorTy, extractedOperands, op->getAttrs());
|
||||
desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayTy, desc,
|
||||
newVal, position);
|
||||
}
|
||||
|
@ -526,8 +541,16 @@ struct BinaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
|
|||
}
|
||||
};
|
||||
|
||||
template <typename SourceOp, typename TargetOp>
|
||||
using UnaryOpLLVMOpLowering = NaryOpLLVMOpLowering<SourceOp, TargetOp, 1>;
|
||||
template <typename SourceOp, typename TargetOp>
|
||||
using BinaryOpLLVMOpLowering = NaryOpLLVMOpLowering<SourceOp, TargetOp, 2>;
|
||||
|
||||
// Specific lowerings.
|
||||
// FIXME: this should be tablegen'ed.
|
||||
struct ExpOpLowering : public UnaryOpLLVMOpLowering<ExpOp, LLVM::exp> {
|
||||
using Super::Super;
|
||||
};
|
||||
struct AddIOpLowering : public BinaryOpLLVMOpLowering<AddIOp, LLVM::AddOp> {
|
||||
using Super::Super;
|
||||
};
|
||||
|
@ -1301,18 +1324,49 @@ void mlir::LLVM::ensureDistinctSuccessors(ModuleOp m) {
|
|||
void mlir::populateStdToLLVMConversionPatterns(
|
||||
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
|
||||
// FIXME: this should be tablegen'ed
|
||||
// clang-format off
|
||||
patterns.insert<
|
||||
AddFOpLowering, AddIOpLowering, AndOpLowering, AllocOpLowering,
|
||||
BranchOpLowering, CallIndirectOpLowering, CallOpLowering, CmpIOpLowering,
|
||||
CmpFOpLowering, CondBranchOpLowering, ConstLLVMOpLowering,
|
||||
DeallocOpLowering, DimOpLowering, DivISOpLowering, DivIUOpLowering,
|
||||
DivFOpLowering, FuncOpConversion, IndexCastOpLowering, LoadOpLowering,
|
||||
MemRefCastOpLowering, MulFOpLowering, MulIOpLowering, OrOpLowering,
|
||||
RemISOpLowering, RemIUOpLowering, RemFOpLowering, ReturnOpLowering,
|
||||
SelectOpLowering, SIToFPLowering, FPExtLowering, FPTruncLowering,
|
||||
SignExtendIOpLowering, SplatOpLowering, StoreOpLowering, SubFOpLowering,
|
||||
SubIOpLowering, TruncateIOpLowering, XOrOpLowering,
|
||||
AddFOpLowering,
|
||||
AddIOpLowering,
|
||||
AllocOpLowering,
|
||||
AndOpLowering,
|
||||
BranchOpLowering,
|
||||
CallIndirectOpLowering,
|
||||
CallOpLowering,
|
||||
CmpFOpLowering,
|
||||
CmpIOpLowering,
|
||||
CondBranchOpLowering,
|
||||
ConstLLVMOpLowering,
|
||||
DeallocOpLowering,
|
||||
DimOpLowering,
|
||||
DivFOpLowering,
|
||||
DivISOpLowering,
|
||||
DivIUOpLowering,
|
||||
ExpOpLowering,
|
||||
FPExtLowering,
|
||||
FPTruncLowering,
|
||||
FuncOpConversion,
|
||||
IndexCastOpLowering,
|
||||
LoadOpLowering,
|
||||
MemRefCastOpLowering,
|
||||
MulFOpLowering,
|
||||
MulIOpLowering,
|
||||
OrOpLowering,
|
||||
RemFOpLowering,
|
||||
RemISOpLowering,
|
||||
RemIUOpLowering,
|
||||
ReturnOpLowering,
|
||||
SIToFPLowering,
|
||||
SelectOpLowering,
|
||||
SignExtendIOpLowering,
|
||||
SplatOpLowering,
|
||||
StoreOpLowering,
|
||||
SubFOpLowering,
|
||||
SubIOpLowering,
|
||||
TruncateIOpLowering,
|
||||
XOrOpLowering,
|
||||
ZeroExtendIOpLowering>(*converter.getDialect(), converter);
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
// Convert types using the stored LLVM IR module.
|
||||
|
|
|
@ -124,6 +124,19 @@ struct StdInlinerInterface : public DialectInlinerInterface {
|
|||
// StandardOpsDialect
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// A custom unary operation printer that omits the "std." prefix from the
|
||||
/// operation names.
|
||||
static void printStandardUnaryOp(Operation *op, OpAsmPrinter &p) {
|
||||
assert(op->getNumOperands() == 1 && "unary op should have one operand");
|
||||
assert(op->getNumResults() == 1 && "unary op should have one result");
|
||||
|
||||
const int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
|
||||
p << op->getName().getStringRef().drop_front(stdDotLen) << ' '
|
||||
<< *op->getOperand(0);
|
||||
p.printOptionalAttrDict(op->getAttrs());
|
||||
p << " : " << op->getOperand(0)->getType();
|
||||
}
|
||||
|
||||
/// A custom binary operation printer that omits the "std." prefix from the
|
||||
/// operation names.
|
||||
static void printStandardBinaryOp(Operation *op, OpAsmPrinter &p) {
|
||||
|
@ -139,7 +152,8 @@ static void printStandardBinaryOp(Operation *op, OpAsmPrinter &p) {
|
|||
return;
|
||||
}
|
||||
|
||||
p << op->getName().getStringRef().drop_front(strlen("std.")) << ' '
|
||||
const int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
|
||||
p << op->getName().getStringRef().drop_front(stdDotLen) << ' '
|
||||
<< *op->getOperand(0) << ", " << *op->getOperand(1);
|
||||
p.printOptionalAttrDict(op->getAttrs());
|
||||
|
||||
|
@ -150,7 +164,8 @@ static void printStandardBinaryOp(Operation *op, OpAsmPrinter &p) {
|
|||
/// A custom cast operation printer that omits the "std." prefix from the
|
||||
/// operation names.
|
||||
static void printStandardCastOp(Operation *op, OpAsmPrinter &p) {
|
||||
p << op->getName().getStringRef().drop_front(strlen("std.")) << ' '
|
||||
const int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
|
||||
p << op->getName().getStringRef().drop_front(stdDotLen) << ' '
|
||||
<< *op->getOperand(0) << " : " << op->getOperand(0)->getType() << " to "
|
||||
<< op->getResult(0)->getType();
|
||||
}
|
||||
|
|
|
@ -421,6 +421,8 @@ func @ops(f32, f32, i32, i32) -> (f32, i32) {
|
|||
%12 = or %arg2, %arg3 : i32
|
||||
// CHECK-NEXT: %12 = llvm.xor %arg2, %arg3 : !llvm.i32
|
||||
%13 = xor %arg2, %arg3 : i32
|
||||
// CHECK-NEXT: %13 = "llvm.intr.exp"(%arg0) : (!llvm.float) -> !llvm.float
|
||||
%14 = std.exp %arg0 : f32
|
||||
|
||||
return %0, %4 : f32, i32
|
||||
}
|
||||
|
|
|
@ -351,6 +351,14 @@ func @standard_instrs(tensor<4x4x?xf32>, f32, i32, index, i64, f16) {
|
|||
// CHECK: = fptrunc {{.*}} : f32 to f16
|
||||
%95 = fptrunc %f : f32 to f16
|
||||
|
||||
// CHECK: %{{[0-9]+}} = exp %arg1 : f32
|
||||
%96 = "std.exp"(%f) : (f32) -> f32
|
||||
|
||||
// CHECK: %{{[0-9]+}} = exp %arg1 : f32
|
||||
%97 = exp %f : f32
|
||||
|
||||
// CHECK: %{{[0-9]+}} = exp %arg0 : tensor<4x4x?xf32>
|
||||
%98 = exp %t : tensor<4x4x?xf32>
|
||||
return
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue