Add unary ops and ExpOp to Standard Dialect.

PiperOrigin-RevId: 274152154
Authored by Alexander Belyaev on 2019-10-11 05:13:18 -07:00; committed by A. Unique TensorFlower
parent 304e44a6b0
commit 00d2a37e32
7 changed files with 167 additions and 36 deletions

View File

@@ -452,15 +452,38 @@ Example:
tensor_store %8, %10 : memref<4x?xf32, #layout, memspace0>
```
## Unary Operations
### 'exp' operation
Syntax:
``` {.ebnf}
operation ::= ssa-id `=` `exp` ssa-use `:` type
```
Examples:
```mlir {.mlir}
// Scalar natural exponential.
%a = exp %b : f64
// SIMD vector element-wise natural exponential.
%f = exp %g : vector<4xf32>
// Tensor element-wise natural exponential.
%x = exp %y : tensor<4x?xf32>
```
The `exp` operation takes one operand and returns one result of the same type.
This type may be a float scalar type, a vector whose element type is float, or a
tensor of floats. It has no standard attributes.
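For intuition, the vector and tensor forms simply apply the scalar natural exponential to every element independently. A plain C++ analogue of the `vector<4xf32>` case, shown for illustration only (ordinary C++, not MLIR, and not part of this commit):

```cpp
// Element-wise natural exponential over four floats; std::exp is base-e,
// matching the semantics described for `exp` above.
#include <array>
#include <cmath>
#include <cstdio>

int main() {
  std::array<float, 4> g = {0.0f, 1.0f, 2.0f, 3.0f};
  std::array<float, 4> f;
  for (int i = 0; i < 4; ++i)
    f[i] = std::exp(g[i]); // each result element depends only on one input element
  for (float v : f)
    std::printf("%f\n", v);
  return 0;
}
```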
## Arithmetic Operations
Basic arithmetic in MLIR is specified by standard operations described in this
section.
TODO: "sub" etc. Let's not get excited about filling this out yet, we can define
these on demand. We should be highly informed by and learn from the operations
supported by HLO and LLVM.
### 'addi' operation
Syntax:
@@ -478,7 +501,7 @@ Examples:
// SIMD vector element-wise addition, e.g. for Intel SSE.
%f = addi %g, %h : vector<4xi32>
// Tensor element-wise addition.
%x = addi %y, %z : tensor<4x?xi8>
```
@@ -504,7 +527,7 @@ Examples:
// SIMD vector addition, e.g. for Intel SSE.
%f = addf %g, %h : vector<4xf32>
// Tensor addition.
%x = addf %y, %z : tensor<4x?xbf16>
```
@@ -757,7 +780,7 @@ Examples:
// SIMD pointwise vector multiplication, e.g. for Intel SSE.
%f = mulf %g, %h : vector<4xf32>
// Tensor pointwise multiplication.
%x = mulf %y, %z : tensor<4x?xbf16>
```

View File

@@ -72,6 +72,27 @@ class CastOp<string mnemonic, list<OpTrait> traits = []> :
  let hasFolder = 1;
}
// Base class for unary ops. Requires a single operand and a single result.
// Individual classes will have an `operand` accessor.
class UnaryOp<string mnemonic, list<OpTrait> traits = []> :
    Op<Std_Dialect, mnemonic, !listconcat(traits, [NoSideEffect])> {
  let results = (outs AnyType);
  let printer = [{
    return printStandardUnaryOp(this->getOperation(), p);
  }];
}

class UnaryOpSameOperandAndResultType<string mnemonic, list<OpTrait> traits = []> :
    UnaryOp<mnemonic, !listconcat(traits, [SameOperandsAndResultType])> {
  let parser = [{
    return impl::parseOneResultSameOperandTypeOp(parser, result);
  }];
}

class FloatUnaryOp<string mnemonic, list<OpTrait> traits = []> :
    UnaryOpSameOperandAndResultType<mnemonic, traits>,
    Arguments<(ins FloatLike:$operand)>;
// Base class for standard arithmetic operations. Requires operands and
// results to be of the same type, but does not constrain them to specific
// types. Individual classes will have `lhs` and `rhs` accessors to operands.
@@ -597,6 +618,10 @@ def DivIUOp : IntArithmeticOp<"diviu"> {
  let hasFolder = 1;
}
def ExpOp : FloatUnaryOp<"exp"> {
  let summary = "base-e exponential of the specified value";
}
def ExtractElementOp : Std_Op<"extract_element", [NoSideEffect]> {
  let summary = "element extract operation";
  let description = [{

View File

@@ -1141,13 +1141,17 @@ private:
  Concept *impl;
};
// These functions are out-of-line implementations of the methods in UnaryOp and
// BinaryOp, which avoids them being template instantiated/duplicated.
namespace impl {
ParseResult parseOneResultOneOperandTypeOp(OpAsmParser &parser,
                                           OperationState &result);
void buildBinaryOp(Builder *builder, OperationState &result, Value *lhs,
                   Value *rhs);
ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser,
                                            OperationState &result);
// Prints the given binary `op` in custom assembly form if both the operands
// and the result have the same type. Otherwise, prints the generic assembly
// form.

View File

@@ -443,28 +443,43 @@ static SmallVector<int64_t, 4> getCoordinates(ArrayRef<int64_t> basis,
  return res;
}
template <typename SourceOp, unsigned OpCount> struct OpCountValidator {
  static_assert(
      std::is_base_of<
          typename OpTrait::NOperands<OpCount>::template Impl<SourceOp>,
          SourceOp>::value,
      "wrong operand count");
};

template <typename SourceOp> struct OpCountValidator<SourceOp, 1> {
  static_assert(std::is_base_of<OpTrait::OneOperand<SourceOp>, SourceOp>::value,
                "expected a single operand");
};

template <typename SourceOp, unsigned OpCount> void ValidateOpCount() {
  OpCountValidator<SourceOp, OpCount>();
}
// Basic lowering implementation for rewriting from Standard Ops to LLVM Dialect
// Ops for N-ary ops with one result. This supports higher-dimensional vector
// types.
template <typename SourceOp, typename TargetOp, unsigned OpCount>
struct NaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
  using LLVMLegalizationPattern<SourceOp>::LLVMLegalizationPattern;
  using Super = NaryOpLLVMOpLowering<SourceOp, TargetOp, OpCount>;

  // Convert the type of the result to an LLVM type, pass operands as is,
  // preserve attributes.
  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    ValidateOpCount<SourceOp, OpCount>();
    static_assert(
        std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
        "expected single result op");
    static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
                                  SourceOp>::value,
                  "expected same operands and result type");
    // Cannot convert ops if their operands are not of LLVM type.
    for (Value *operand : operands) {
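The `OpCountValidator` machinery above is a standard static_assert-plus-partial-specialization idiom: the primary template checks the generic `NOperands<OpCount>` trait, while the `OpCount == 1` specialization checks the dedicated single-operand trait and reports a clearer error. Below is a self-contained sketch of the same idiom with stand-in trait classes; the names are illustrative and are not MLIR APIs.

```cpp
// Compile-time arity check via partial specialization; OneOperandTrait and
// NOperandsTrait are stand-ins for MLIR's OpTrait classes.
#include <type_traits>

template <typename ConcreteOp> struct OneOperandTrait {};
template <unsigned N> struct NOperandsTrait {
  template <typename ConcreteOp> struct Impl {};
};

// Primary template: require the generic N-operand trait.
template <typename SourceOp, unsigned OpCount> struct CountValidator {
  static_assert(
      std::is_base_of<
          typename NOperandsTrait<OpCount>::template Impl<SourceOp>,
          SourceOp>::value,
      "wrong operand count");
};

// Partial specialization: unary ops are checked against the dedicated
// single-operand trait, giving a more precise error message.
template <typename SourceOp> struct CountValidator<SourceOp, 1> {
  static_assert(std::is_base_of<OneOperandTrait<SourceOp>, SourceOp>::value,
                "expected a single operand");
};

// Toy "ops" that model the traits; constructing a validator instantiates the
// matching static_assert and generates no run-time code.
struct ToyUnaryOp : OneOperandTrait<ToyUnaryOp> {};
struct ToyBinaryOp : NOperandsTrait<2>::Impl<ToyBinaryOp> {};

int main() {
  (void)CountValidator<ToyUnaryOp, 1>{};  // OK: models the single-operand trait.
  (void)CountValidator<ToyBinaryOp, 2>{}; // OK: models NOperandsTrait<2>::Impl.
  // CountValidator<ToyUnaryOp, 2>{} would fail to compile: "wrong operand count".
  return 0;
}
```

Because the validator classes are empty apart from the static_asserts, instantiating them costs nothing at run time; a mismatched pairing of op and arity simply fails the build.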
@@ -489,7 +504,7 @@ struct BinaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
      arraySizes.push_back(llvmTy.getArrayNumElements());
      llvmTy = llvmTy.getArrayElementType();
    }
    assert(llvmTy.isVectorTy() && "unexpected n-ary op over non-vector type");
    auto llvmVectorTy = llvmTy;

    // Iteratively extract position coordinates with basis `arraySize` from a
@@ -511,13 +526,13 @@ struct BinaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
      // For this unrolled `position` corresponding to the `linearIndex`^th
      // element, extract operand vectors
      SmallVector<Value *, OpCount> extractedOperands;
      for (unsigned i = 0; i < OpCount; ++i) {
        extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
            loc, llvmVectorTy, operands[i], position));
      }
      Value *newVal = rewriter.create<TargetOp>(
          loc, llvmVectorTy, extractedOperands, op->getAttrs());
      desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayTy, desc,
                                                  newVal, position);
    }
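The `position` vector fed to `LLVM::ExtractValueOp`/`LLVM::InsertValueOp` above is obtained by delinearizing the running `linearIndex` over the nested array shape (the role `getCoordinates` plays in this file). A standalone sketch of that index arithmetic, assuming row-major ordering; the helper below is illustrative and is not the MLIR implementation:

```cpp
// Delinearize a flat element index into per-dimension coordinates for a given
// shape ("basis"), assuming row-major order.
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

static std::vector<int64_t> coordinatesFor(const std::vector<int64_t> &basis,
                                           int64_t linearIndex) {
  std::vector<int64_t> coords(basis.size());
  // Peel off one coordinate per dimension, innermost dimension first.
  for (int i = static_cast<int>(basis.size()) - 1; i >= 0; --i) {
    coords[i] = linearIndex % basis[i];
    linearIndex /= basis[i];
  }
  assert(linearIndex == 0 && "linear index out of range for this basis");
  return coords;
}

int main() {
  // For a 2x3 outer array of vectors, flat element 4 lives at position [1, 1].
  for (int64_t c : coordinatesFor({2, 3}, 4))
    std::cout << c << ' ';
  std::cout << '\n';
  return 0;
}
```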
@@ -526,8 +541,16 @@ struct BinaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
  }
};

template <typename SourceOp, typename TargetOp>
using UnaryOpLLVMOpLowering = NaryOpLLVMOpLowering<SourceOp, TargetOp, 1>;

template <typename SourceOp, typename TargetOp>
using BinaryOpLLVMOpLowering = NaryOpLLVMOpLowering<SourceOp, TargetOp, 2>;

// Specific lowerings.
// FIXME: this should be tablegen'ed.
struct ExpOpLowering : public UnaryOpLLVMOpLowering<ExpOp, LLVM::exp> {
  using Super::Super;
};

struct AddIOpLowering : public BinaryOpLLVMOpLowering<AddIOp, LLVM::AddOp> {
  using Super::Super;
};
@@ -1301,18 +1324,49 @@ void mlir::LLVM::ensureDistinctSuccessors(ModuleOp m) {
void mlir::populateStdToLLVMConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
  // FIXME: this should be tablegen'ed
  // clang-format off
  patterns.insert<
      AddFOpLowering,
      AddIOpLowering,
      AllocOpLowering,
      AndOpLowering,
      BranchOpLowering,
      CallIndirectOpLowering,
      CallOpLowering,
      CmpFOpLowering,
      CmpIOpLowering,
      CondBranchOpLowering,
      ConstLLVMOpLowering,
      DeallocOpLowering,
      DimOpLowering,
      DivFOpLowering,
      DivISOpLowering,
      DivIUOpLowering,
      ExpOpLowering,
      FPExtLowering,
      FPTruncLowering,
      FuncOpConversion,
      IndexCastOpLowering,
      LoadOpLowering,
      MemRefCastOpLowering,
      MulFOpLowering,
      MulIOpLowering,
      OrOpLowering,
      RemFOpLowering,
      RemISOpLowering,
      RemIUOpLowering,
      ReturnOpLowering,
      SIToFPLowering,
      SelectOpLowering,
      SignExtendIOpLowering,
      SplatOpLowering,
      StoreOpLowering,
      SubFOpLowering,
      SubIOpLowering,
      TruncateIOpLowering,
      XOrOpLowering,
      ZeroExtendIOpLowering>(*converter.getDialect(), converter);
  // clang-format on
}
// Convert types using the stored LLVM IR module.
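The `patterns.insert<...>` call above relies on a variadic member template that constructs one pattern per listed type from the same arguments. A self-contained sketch of that registration idiom, with stand-in types (none of the names below are MLIR APIs):

```cpp
// Variadic registration: one pattern object per listed type, all built from
// the same arguments and collected into a single list (C++17).
#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct Pattern {
  virtual ~Pattern() = default;
  virtual std::string name() const = 0;
};

struct PatternList {
  template <typename... Ts, typename... Args>
  void insert(Args... args) {
    // Fold over the type pack: push_back(make_unique<T>(args...)) for each T.
    (patterns.push_back(std::make_unique<Ts>(args...)), ...);
  }
  std::vector<std::unique_ptr<Pattern>> patterns;
};

struct ExpLoweringStub : Pattern {
  explicit ExpLoweringStub(int benefit) : benefit(benefit) {}
  std::string name() const override { return "exp-lowering"; }
  int benefit;
};

struct AddLoweringStub : Pattern {
  explicit AddLoweringStub(int benefit) : benefit(benefit) {}
  std::string name() const override { return "add-lowering"; }
  int benefit;
};

int main() {
  PatternList list;
  list.insert<ExpLoweringStub, AddLoweringStub>(/*benefit=*/1);
  for (const auto &p : list.patterns)
    std::cout << p->name() << '\n'; // exp-lowering, add-lowering
  return 0;
}
```

Keeping the real list alphabetized, one entry per line, with clang-format disabled around it makes additions such as `ExpOpLowering` a one-line, conflict-free change.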

View File

@@ -124,6 +124,19 @@ struct StdInlinerInterface : public DialectInlinerInterface {
// StandardOpsDialect
//===----------------------------------------------------------------------===//
/// A custom unary operation printer that omits the "std." prefix from the
/// operation names.
static void printStandardUnaryOp(Operation *op, OpAsmPrinter &p) {
  assert(op->getNumOperands() == 1 && "unary op should have one operand");
  assert(op->getNumResults() == 1 && "unary op should have one result");

  const int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
  p << op->getName().getStringRef().drop_front(stdDotLen) << ' '
    << *op->getOperand(0);
  p.printOptionalAttrDict(op->getAttrs());
  p << " : " << op->getOperand(0)->getType();
}
/// A custom binary operation printer that omits the "std." prefix from the
/// operation names.
static void printStandardBinaryOp(Operation *op, OpAsmPrinter &p) {
@@ -139,7 +152,8 @@ static void printStandardBinaryOp(Operation *op, OpAsmPrinter &p) {
    return;
  }
  const int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
  p << op->getName().getStringRef().drop_front(stdDotLen) << ' '
    << *op->getOperand(0) << ", " << *op->getOperand(1);
  p.printOptionalAttrDict(op->getAttrs());
@@ -150,7 +164,8 @@ static void printStandardBinaryOp(Operation *op, OpAsmPrinter &p) {
/// A custom cast operation printer that omits the "std." prefix from the
/// operation names.
static void printStandardCastOp(Operation *op, OpAsmPrinter &p) {
  const int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
  p << op->getName().getStringRef().drop_front(stdDotLen) << ' '
    << *op->getOperand(0) << " : " << op->getOperand(0)->getType() << " to "
    << op->getResult(0)->getType();
}
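All three printers now compute the prefix length from the dialect namespace instead of hard-coding `strlen("std.")`. The arithmetic is simple; a minimal standalone illustration using only `llvm::StringRef` (assumption: built inside an LLVM/MLIR tree; the op name is just a sample string, not a real `Operation`):

```cpp
// Strip the "<namespace>." prefix from an op name, as the printers above do.
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/raw_ostream.h"

int main() {
  llvm::StringRef dialectNamespace = "std";
  llvm::StringRef opName = "std.exp";
  // "+ 1" accounts for the '.' separating the namespace from the op name.
  const int stdDotLen = dialectNamespace.size() + 1;
  llvm::outs() << opName.drop_front(stdDotLen) << "\n"; // prints "exp"
  return 0;
}
```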

View File

@@ -421,6 +421,8 @@ func @ops(f32, f32, i32, i32) -> (f32, i32) {
%12 = or %arg2, %arg3 : i32
// CHECK-NEXT: %12 = llvm.xor %arg2, %arg3 : !llvm.i32
%13 = xor %arg2, %arg3 : i32
// CHECK-NEXT: %13 = "llvm.intr.exp"(%arg0) : (!llvm.float) -> !llvm.float
%14 = std.exp %arg0 : f32
return %0, %4 : f32, i32
}

View File

@@ -351,6 +351,14 @@ func @standard_instrs(tensor<4x4x?xf32>, f32, i32, index, i64, f16) {
// CHECK: = fptrunc {{.*}} : f32 to f16
%95 = fptrunc %f : f32 to f16
// CHECK: %{{[0-9]+}} = exp %arg1 : f32
%96 = "std.exp"(%f) : (f32) -> f32
// CHECK: %{{[0-9]+}} = exp %arg1 : f32
%97 = exp %f : f32
// CHECK: %{{[0-9]+}} = exp %arg0 : tensor<4x4x?xf32>
%98 = exp %t : tensor<4x4x?xf32>
return
}