Add unary ops and ExpOp to Standard Dialect.

PiperOrigin-RevId: 274152154
This commit is contained in:
Alexander Belyaev 2019-10-11 05:13:18 -07:00 committed by A. Unique TensorFlower
parent 304e44a6b0
commit 00d2a37e32
7 changed files with 167 additions and 36 deletions

View File

@ -452,15 +452,38 @@ Example:
tensor_store %8, %10 : memref<4x?xf32, #layout, memspace0>
```
## Unary Operations
### 'exp' operation
Syntax:
``` {.ebnf}
operation ::= ssa-id `=` `exp` ssa-use `:` type
```
Examples:
```mlir {.mlir}
// Scalar natural exponential.
%a = exp %b : f64
// SIMD vector element-wise natural exponential.
%f = exp %g : vector<4xf32>
// Tensor element-wise natural exponential.
%x = exp %y : tensor<4x?xf8>
```
The `exp` operation takes one operand and returns one result of the same type.
This type may be a float scalar type, a vector whose element type is float, or a
tensor of floats. It has no standard attributes.
## Arithmetic Operations
Basic arithmetic in MLIR is specified by standard operations described in this
section.
TODO: "sub" etc. Let's not get excited about filling this out yet, we can define
these on demand. We should be highly informed by and learn from the operations
supported by HLO and LLVM.
### 'addi' operation
Syntax:
@ -478,7 +501,7 @@ Examples:
// SIMD vector element-wise addition, e.g. for Intel SSE.
%f = addi %g, %h : vector<4xi32>
// Tensor element-wise addition, analogous to HLO's add operation.
// Tensor element-wise addition.
%x = addi %y, %z : tensor<4x?xi8>
```
@ -504,7 +527,7 @@ Examples:
// SIMD vector addition, e.g. for Intel SSE.
%f = addf %g, %h : vector<4xf32>
// Tensor addition, analogous to HLO's add operation.
// Tensor addition.
%x = addf %y, %z : tensor<4x?xbf16>
```
@ -757,7 +780,7 @@ Examples:
// SIMD pointwise vector multiplication, e.g. for Intel SSE.
%f = mulf %g, %h : vector<4xf32>
// Tensor pointwise multiplication, analogous to HLO's pointwise multiply operation.
// Tensor pointwise multiplication.
%x = mulf %y, %z : tensor<4x?xbf16>
```

View File

@ -72,6 +72,27 @@ class CastOp<string mnemonic, list<OpTrait> traits = []> :
let hasFolder = 1;
}
// Base class for unary ops. Requires single operand and result. Individual
// classes will have `operand` accessor.
class UnaryOp<string mnemonic, list<OpTrait> traits = []> :
Op<Std_Dialect, mnemonic, !listconcat(traits, [NoSideEffect])> {
let results = (outs AnyType);
let printer = [{
return printStandardUnaryOp(this->getOperation(), p);
}];
}
class UnaryOpSameOperandAndResultType<string mnemonic, list<OpTrait> traits = []> :
UnaryOp<mnemonic, !listconcat(traits, [SameOperandsAndResultType])> {
let parser = [{
return impl::parseOneResultSameOperandTypeOp(parser, result);
}];
}
class FloatUnaryOp<string mnemonic, list<OpTrait> traits = []> :
UnaryOpSameOperandAndResultType<mnemonic, traits>,
Arguments<(ins FloatLike:$operand)>;
// Base class for standard arithmetic operations. Requires operands and
// results to be of the same type, but does not constrain them to specific
// types. Individual classes will have `lhs` and `rhs` accessor to operands.
@ -597,6 +618,10 @@ def DivIUOp : IntArithmeticOp<"diviu"> {
let hasFolder = 1;
}
def ExpOp : FloatUnaryOp<"exp"> {
let summary = "base-e exponential of the specified value";
}
def ExtractElementOp : Std_Op<"extract_element", [NoSideEffect]> {
let summary = "element extract operation";
let description = [{

View File

@ -1141,13 +1141,17 @@ private:
Concept *impl;
};
// These functions are out-of-line implementations of the methods in BinaryOp,
// which avoids them being template instantiated/duplicated.
// These functions are out-of-line implementations of the methods in UnaryOp and
// BinaryOp, which avoids them being template instantiated/duplicated.
namespace impl {
ParseResult parseOneResultOneOperandTypeOp(OpAsmParser &parser,
OperationState &result);
void buildBinaryOp(Builder *builder, OperationState &result, Value *lhs,
Value *rhs);
ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser,
OperationState &result);
// Prints the given binary `op` in custom assembly form if both the two operands
// and the result have the same time. Otherwise, prints the generic assembly
// form.

View File

@ -443,28 +443,43 @@ static SmallVector<int64_t, 4> getCoordinates(ArrayRef<int64_t> basis,
return res;
}
template <typename SourceOp, unsigned OpCount> struct OpCountValidator {
static_assert(
std::is_base_of<
typename OpTrait::NOperands<OpCount>::template Impl<SourceOp>,
SourceOp>::value,
"wrong operand count");
};
template <typename SourceOp> struct OpCountValidator<SourceOp, 1> {
static_assert(std::is_base_of<OpTrait::OneOperand<SourceOp>, SourceOp>::value,
"expected a single operand");
};
template <typename SourceOp, unsigned OpCount> void ValidateOpCount() {
OpCountValidator<SourceOp, OpCount>();
}
// Basic lowering implementation for rewriting from Standard Ops to LLVM Dialect
// Ops for binary ops with one result. This supports higher-dimensional vector
// Ops for N-ary ops with one result. This supports higher-dimensional vector
// types.
template <typename SourceOp, typename TargetOp>
struct BinaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
template <typename SourceOp, typename TargetOp, unsigned OpCount>
struct NaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
using LLVMLegalizationPattern<SourceOp>::LLVMLegalizationPattern;
using Super = BinaryOpLLVMOpLowering<SourceOp, TargetOp>;
using Super = NaryOpLLVMOpLowering<SourceOp, TargetOp, OpCount>;
// Convert the type of the result to an LLVM type, pass operands as is,
// preserve attributes.
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
static_assert(
std::is_base_of<OpTrait::NOperands<2>::Impl<SourceOp>, SourceOp>::value,
"expected binary op");
ValidateOpCount<SourceOp, OpCount>();
static_assert(
std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
"expected single result op");
static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
SourceOp>::value,
"expected single result op");
"expected same operands and result type");
// Cannot convert ops if their operands are not of LLVM type.
for (Value *operand : operands) {
@ -489,7 +504,7 @@ struct BinaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
arraySizes.push_back(llvmTy.getArrayNumElements());
llvmTy = llvmTy.getArrayElementType();
}
assert(llvmTy.isVectorTy() && "unexpected binary op over non-vector type");
assert(llvmTy.isVectorTy() && "unexpected n-ary op over non-vector type");
auto llvmVectorTy = llvmTy;
// Iteratively extract a position coordinates with basis `arraySize` from a
@ -511,13 +526,13 @@ struct BinaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
// For this unrolled `position` corresponding to the `linearIndex`^th
// element, extract operand vectors
Value *extractedLHS = rewriter.create<LLVM::ExtractValueOp>(
loc, llvmVectorTy, operands[0], position);
Value *extractedRHS = rewriter.create<LLVM::ExtractValueOp>(
loc, llvmVectorTy, operands[1], position);
SmallVector<Value *, OpCount> extractedOperands;
for (unsigned i = 0; i < OpCount; ++i) {
extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
loc, llvmVectorTy, operands[i], position));
}
Value *newVal = rewriter.create<TargetOp>(
loc, llvmVectorTy, ArrayRef<Value *>{extractedLHS, extractedRHS},
op->getAttrs());
loc, llvmVectorTy, extractedOperands, op->getAttrs());
desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayTy, desc,
newVal, position);
}
@ -526,8 +541,16 @@ struct BinaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
}
};
template <typename SourceOp, typename TargetOp>
using UnaryOpLLVMOpLowering = NaryOpLLVMOpLowering<SourceOp, TargetOp, 1>;
template <typename SourceOp, typename TargetOp>
using BinaryOpLLVMOpLowering = NaryOpLLVMOpLowering<SourceOp, TargetOp, 2>;
// Specific lowerings.
// FIXME: this should be tablegen'ed.
struct ExpOpLowering : public UnaryOpLLVMOpLowering<ExpOp, LLVM::exp> {
using Super::Super;
};
struct AddIOpLowering : public BinaryOpLLVMOpLowering<AddIOp, LLVM::AddOp> {
using Super::Super;
};
@ -1301,18 +1324,49 @@ void mlir::LLVM::ensureDistinctSuccessors(ModuleOp m) {
void mlir::populateStdToLLVMConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
// FIXME: this should be tablegen'ed
// clang-format off
patterns.insert<
AddFOpLowering, AddIOpLowering, AndOpLowering, AllocOpLowering,
BranchOpLowering, CallIndirectOpLowering, CallOpLowering, CmpIOpLowering,
CmpFOpLowering, CondBranchOpLowering, ConstLLVMOpLowering,
DeallocOpLowering, DimOpLowering, DivISOpLowering, DivIUOpLowering,
DivFOpLowering, FuncOpConversion, IndexCastOpLowering, LoadOpLowering,
MemRefCastOpLowering, MulFOpLowering, MulIOpLowering, OrOpLowering,
RemISOpLowering, RemIUOpLowering, RemFOpLowering, ReturnOpLowering,
SelectOpLowering, SIToFPLowering, FPExtLowering, FPTruncLowering,
SignExtendIOpLowering, SplatOpLowering, StoreOpLowering, SubFOpLowering,
SubIOpLowering, TruncateIOpLowering, XOrOpLowering,
AddFOpLowering,
AddIOpLowering,
AllocOpLowering,
AndOpLowering,
BranchOpLowering,
CallIndirectOpLowering,
CallOpLowering,
CmpFOpLowering,
CmpIOpLowering,
CondBranchOpLowering,
ConstLLVMOpLowering,
DeallocOpLowering,
DimOpLowering,
DivFOpLowering,
DivISOpLowering,
DivIUOpLowering,
ExpOpLowering,
FPExtLowering,
FPTruncLowering,
FuncOpConversion,
IndexCastOpLowering,
LoadOpLowering,
MemRefCastOpLowering,
MulFOpLowering,
MulIOpLowering,
OrOpLowering,
RemFOpLowering,
RemISOpLowering,
RemIUOpLowering,
ReturnOpLowering,
SIToFPLowering,
SelectOpLowering,
SignExtendIOpLowering,
SplatOpLowering,
StoreOpLowering,
SubFOpLowering,
SubIOpLowering,
TruncateIOpLowering,
XOrOpLowering,
ZeroExtendIOpLowering>(*converter.getDialect(), converter);
// clang-format on
}
// Convert types using the stored LLVM IR module.

View File

@ -124,6 +124,19 @@ struct StdInlinerInterface : public DialectInlinerInterface {
// StandardOpsDialect
//===----------------------------------------------------------------------===//
/// A custom unary operation printer that omits the "std." prefix from the
/// operation names.
static void printStandardUnaryOp(Operation *op, OpAsmPrinter &p) {
assert(op->getNumOperands() == 1 && "unary op should have one operand");
assert(op->getNumResults() == 1 && "unary op should have one result");
const int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
p << op->getName().getStringRef().drop_front(stdDotLen) << ' '
<< *op->getOperand(0);
p.printOptionalAttrDict(op->getAttrs());
p << " : " << op->getOperand(0)->getType();
}
/// A custom binary operation printer that omits the "std." prefix from the
/// operation names.
static void printStandardBinaryOp(Operation *op, OpAsmPrinter &p) {
@ -139,7 +152,8 @@ static void printStandardBinaryOp(Operation *op, OpAsmPrinter &p) {
return;
}
p << op->getName().getStringRef().drop_front(strlen("std.")) << ' '
const int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
p << op->getName().getStringRef().drop_front(stdDotLen) << ' '
<< *op->getOperand(0) << ", " << *op->getOperand(1);
p.printOptionalAttrDict(op->getAttrs());
@ -150,7 +164,8 @@ static void printStandardBinaryOp(Operation *op, OpAsmPrinter &p) {
/// A custom cast operation printer that omits the "std." prefix from the
/// operation names.
static void printStandardCastOp(Operation *op, OpAsmPrinter &p) {
p << op->getName().getStringRef().drop_front(strlen("std.")) << ' '
const int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
p << op->getName().getStringRef().drop_front(stdDotLen) << ' '
<< *op->getOperand(0) << " : " << op->getOperand(0)->getType() << " to "
<< op->getResult(0)->getType();
}

View File

@ -421,6 +421,8 @@ func @ops(f32, f32, i32, i32) -> (f32, i32) {
%12 = or %arg2, %arg3 : i32
// CHECK-NEXT: %12 = llvm.xor %arg2, %arg3 : !llvm.i32
%13 = xor %arg2, %arg3 : i32
// CHECK-NEXT: %13 = "llvm.intr.exp"(%arg0) : (!llvm.float) -> !llvm.float
%14 = std.exp %arg0 : f32
return %0, %4 : f32, i32
}

View File

@ -351,6 +351,14 @@ func @standard_instrs(tensor<4x4x?xf32>, f32, i32, index, i64, f16) {
// CHECK: = fptrunc {{.*}} : f32 to f16
%95 = fptrunc %f : f32 to f16
// CHECK: %{{[0-9]+}} = exp %arg1 : f32
%96 = "std.exp"(%f) : (f32) -> f32
// CHECK: %{{[0-9]+}} = exp %arg1 : f32
%97 = exp %f : f32
// CHECK: %{{[0-9]+}} = exp %arg0 : tensor<4x4x?xf32>
%98 = exp %t : tensor<4x4x?xf32>
return
}