forked from OSchip/llvm-project
[mlir] Add fma operation to std dialect
Will remove `vector.fma` operation in the followup CLs. Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D96801
This commit is contained in:
parent
fb19400d4e
commit
519f5917b4
|
@ -103,7 +103,7 @@ class FloatUnaryOp<string mnemonic, list<OpTrait> traits = []> :
|
|||
|
||||
// Base class for standard arithmetic operations. Requires operands and
|
||||
// results to be of the same type, but does not constrain them to specific
|
||||
// types. Individual classes will have `lhs` and `rhs` accessor to operands.
|
||||
// types.
|
||||
class ArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<StandardOps_Dialect, mnemonic,
|
||||
!listconcat(traits, [NoSideEffect,
|
||||
|
@ -122,6 +122,32 @@ class ArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
|
|||
}];
|
||||
}
|
||||
|
||||
// Base class for standard binary arithmetic operations.
|
||||
class ArithmeticBinaryOp<string mnemonic, list<OpTrait> traits = []> :
|
||||
ArithmeticOp<mnemonic, traits> {
|
||||
|
||||
let parser = [{
|
||||
return impl::parseOneResultSameOperandTypeOp(parser, result);
|
||||
}];
|
||||
|
||||
let printer = [{
|
||||
return printStandardBinaryOp(this->getOperation(), p);
|
||||
}];
|
||||
}
|
||||
|
||||
// Base class for standard ternary arithmetic operations.
|
||||
class ArithmeticTernaryOp<string mnemonic, list<OpTrait> traits = []> :
|
||||
ArithmeticOp<mnemonic, traits> {
|
||||
|
||||
let parser = [{
|
||||
return impl::parseOneResultSameOperandTypeOp(parser, result);
|
||||
}];
|
||||
|
||||
let printer = [{
|
||||
return printStandardTernaryOp(this->getOperation(), p);
|
||||
}];
|
||||
}
|
||||
|
||||
// Base class for standard arithmetic operations on integers, vectors and
|
||||
// tensors thereof. This operation takes two operands and returns one result,
|
||||
// each of these is required to be of the same type. This type may be an
|
||||
|
@ -130,8 +156,8 @@ class ArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
|
|||
//
|
||||
// <op>i %0, %1 : i32
|
||||
//
|
||||
class IntArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
|
||||
ArithmeticOp<mnemonic,
|
||||
class IntBinaryOp<string mnemonic, list<OpTrait> traits = []> :
|
||||
ArithmeticBinaryOp<mnemonic,
|
||||
!listconcat(traits,
|
||||
[DeclareOpInterfaceMethods<VectorUnrollOpInterface>])>,
|
||||
Arguments<(ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs)>;
|
||||
|
@ -145,12 +171,27 @@ class IntArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
|
|||
//
|
||||
// <op>f %0, %1 : f32
|
||||
//
|
||||
class FloatArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
|
||||
ArithmeticOp<mnemonic,
|
||||
class FloatBinaryOp<string mnemonic, list<OpTrait> traits = []> :
|
||||
ArithmeticBinaryOp<mnemonic,
|
||||
!listconcat(traits,
|
||||
[DeclareOpInterfaceMethods<VectorUnrollOpInterface>])>,
|
||||
Arguments<(ins FloatLike:$lhs, FloatLike:$rhs)>;
|
||||
|
||||
// Base class for standard arithmetic ternary operations on floats, vectors and
|
||||
// tensors thereof. This operation has three operands and returns one result,
|
||||
// each of these is required to be of the same type. This type may be a
|
||||
// floating point scalar type, a vector whose element type is a floating point
|
||||
// type, or a floating point tensor. The custom assembly form of the operation
|
||||
// is as follows
|
||||
//
|
||||
// <op> %0, %1, %2 : f32
|
||||
//
|
||||
class FloatTernaryOp<string mnemonic, list<OpTrait> traits = []> :
|
||||
ArithmeticTernaryOp<mnemonic,
|
||||
!listconcat(traits,
|
||||
[DeclareOpInterfaceMethods<VectorUnrollOpInterface>])>,
|
||||
Arguments<(ins FloatLike:$a, FloatLike:$b, FloatLike:$c)>;
|
||||
|
||||
// Base class for memref allocating ops: alloca and alloc.
|
||||
//
|
||||
// %0 = alloclike(%m)[%s] : memref<8x?xf32, (d0, d1)[s0] -> ((d0 + s0), d1)>
|
||||
|
@ -257,7 +298,7 @@ def AbsFOp : FloatUnaryOp<"absf"> {
|
|||
// AddFOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def AddFOp : FloatArithmeticOp<"addf"> {
|
||||
def AddFOp : FloatBinaryOp<"addf"> {
|
||||
let summary = "floating point addition operation";
|
||||
let description = [{
|
||||
Syntax:
|
||||
|
@ -294,7 +335,7 @@ def AddFOp : FloatArithmeticOp<"addf"> {
|
|||
// AddIOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def AddIOp : IntArithmeticOp<"addi", [Commutative]> {
|
||||
def AddIOp : IntBinaryOp<"addi", [Commutative]> {
|
||||
let summary = "integer addition operation";
|
||||
let description = [{
|
||||
Syntax:
|
||||
|
@ -418,7 +459,7 @@ def AllocaOp : AllocLikeOp<"alloca", AutomaticAllocationScopeResource> {
|
|||
// AndOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def AndOp : IntArithmeticOp<"and", [Commutative]> {
|
||||
def AndOp : IntBinaryOp<"and", [Commutative]> {
|
||||
let summary = "integer binary and";
|
||||
let description = [{
|
||||
Syntax:
|
||||
|
@ -1269,7 +1310,7 @@ def ConstantOp : Std_Op<"constant",
|
|||
// CopySignOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def CopySignOp : FloatArithmeticOp<"copysign"> {
|
||||
def CopySignOp : FloatBinaryOp<"copysign"> {
|
||||
let summary = "A copysign operation";
|
||||
let description = [{
|
||||
Syntax:
|
||||
|
@ -1384,11 +1425,49 @@ def DimOp : Std_Op<"dim", [NoSideEffect]> {
|
|||
// DivFOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def DivFOp : FloatArithmeticOp<"divf"> {
|
||||
def DivFOp : FloatBinaryOp<"divf"> {
|
||||
let summary = "floating point division operation";
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// FmaFOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def FmaFOp : FloatTernaryOp<"fmaf"> {
|
||||
let summary = "floating point fused multipy-add operation";
|
||||
let description = [{
|
||||
Syntax:
|
||||
|
||||
```
|
||||
operation ::= ssa-id `=` `std.fmaf` ssa-use `,` ssa-use `,` ssa-use `:` type
|
||||
```
|
||||
|
||||
The `fmaf` operation takes three operands and returns one result, each of
|
||||
these is required to be the same type. This type may be a floating point
|
||||
scalar type, a vector whose element type is a floating point type, or a
|
||||
floating point tensor.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
// Scalar fused multiply-add: d = a*b + c
|
||||
%d = fmaf %a, %b, %c : f64
|
||||
|
||||
// SIMD vector fused multiply-add, e.g. for Intel SSE.
|
||||
%i = fmaf %f, %g, %h : vector<4xf32>
|
||||
|
||||
// Tensor fused multiply-add.
|
||||
%w = fmaf %x, %y, %z : tensor<4x?xbf16>
|
||||
```
|
||||
|
||||
The semantics of the operation correspond to those of the `llvm.fma`
|
||||
[intrinsic](https://llvm.org/docs/LangRef.html#llvm-fma-intrinsic). In the
|
||||
particular case of lowering to LLVM, this is guaranteed to lower
|
||||
to the `llvm.fma.*` intrinsic.
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// FPExtOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1854,7 +1933,7 @@ def MemRefReshapeOp: Std_Op<"memref_reshape", [
|
|||
// MulFOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def MulFOp : FloatArithmeticOp<"mulf"> {
|
||||
def MulFOp : FloatBinaryOp<"mulf"> {
|
||||
let summary = "floating point multiplication operation";
|
||||
let description = [{
|
||||
Syntax:
|
||||
|
@ -1891,7 +1970,7 @@ def MulFOp : FloatArithmeticOp<"mulf"> {
|
|||
// MulIOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def MulIOp : IntArithmeticOp<"muli", [Commutative]> {
|
||||
def MulIOp : IntBinaryOp<"muli", [Commutative]> {
|
||||
let summary = "integer multiplication operation";
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
@ -1933,7 +2012,7 @@ def NegFOp : FloatUnaryOp<"negf"> {
|
|||
// OrOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def OrOp : IntArithmeticOp<"or", [Commutative]> {
|
||||
def OrOp : IntBinaryOp<"or", [Commutative]> {
|
||||
let summary = "integer binary or";
|
||||
let description = [{
|
||||
Syntax:
|
||||
|
@ -2040,7 +2119,7 @@ def RankOp : Std_Op<"rank", [NoSideEffect]> {
|
|||
// RemFOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def RemFOp : FloatArithmeticOp<"remf"> {
|
||||
def RemFOp : FloatBinaryOp<"remf"> {
|
||||
let summary = "floating point division remainder operation";
|
||||
}
|
||||
|
||||
|
@ -2141,7 +2220,7 @@ def SelectOp : Std_Op<"select", [NoSideEffect,
|
|||
// ShiftLeftOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def ShiftLeftOp : IntArithmeticOp<"shift_left"> {
|
||||
def ShiftLeftOp : IntBinaryOp<"shift_left"> {
|
||||
let summary = "integer left-shift";
|
||||
let description = [{
|
||||
The shift_left operation shifts an integer value to the left by a variable
|
||||
|
@ -2161,7 +2240,7 @@ def ShiftLeftOp : IntArithmeticOp<"shift_left"> {
|
|||
// SignedDivIOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def SignedDivIOp : IntArithmeticOp<"divi_signed"> {
|
||||
def SignedDivIOp : IntBinaryOp<"divi_signed"> {
|
||||
let summary = "signed integer division operation";
|
||||
let description = [{
|
||||
Syntax:
|
||||
|
@ -2196,7 +2275,7 @@ def SignedDivIOp : IntArithmeticOp<"divi_signed"> {
|
|||
// SignedFloorDivIOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def SignedFloorDivIOp : IntArithmeticOp<"floordivi_signed"> {
|
||||
def SignedFloorDivIOp : IntBinaryOp<"floordivi_signed"> {
|
||||
let summary = "signed floor integer division operation";
|
||||
let description = [{
|
||||
Syntax:
|
||||
|
@ -2225,7 +2304,7 @@ def SignedFloorDivIOp : IntArithmeticOp<"floordivi_signed"> {
|
|||
// SignedCeilDivIOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def SignedCeilDivIOp : IntArithmeticOp<"ceildivi_signed"> {
|
||||
def SignedCeilDivIOp : IntBinaryOp<"ceildivi_signed"> {
|
||||
let summary = "signed ceil integer division operation";
|
||||
let description = [{
|
||||
Syntax:
|
||||
|
@ -2253,7 +2332,7 @@ def SignedCeilDivIOp : IntArithmeticOp<"ceildivi_signed"> {
|
|||
// SignedRemIOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def SignedRemIOp : IntArithmeticOp<"remi_signed"> {
|
||||
def SignedRemIOp : IntBinaryOp<"remi_signed"> {
|
||||
let summary = "signed integer division remainder operation";
|
||||
let description = [{
|
||||
Syntax:
|
||||
|
@ -2288,7 +2367,7 @@ def SignedRemIOp : IntArithmeticOp<"remi_signed"> {
|
|||
// SignedShiftRightOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def SignedShiftRightOp : IntArithmeticOp<"shift_right_signed"> {
|
||||
def SignedShiftRightOp : IntBinaryOp<"shift_right_signed"> {
|
||||
let summary = "signed integer right-shift";
|
||||
let description = [{
|
||||
The shift_right_signed operation shifts an integer value to the right by
|
||||
|
@ -2488,7 +2567,7 @@ def StoreOp : Std_Op<"store",
|
|||
// SubFOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def SubFOp : FloatArithmeticOp<"subf"> {
|
||||
def SubFOp : FloatBinaryOp<"subf"> {
|
||||
let summary = "floating point subtraction operation";
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
@ -2497,7 +2576,7 @@ def SubFOp : FloatArithmeticOp<"subf"> {
|
|||
// SubIOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def SubIOp : IntArithmeticOp<"subi"> {
|
||||
def SubIOp : IntBinaryOp<"subi"> {
|
||||
let summary = "integer subtraction operation";
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
@ -3173,7 +3252,7 @@ def UIToFPOp : ArithmeticCastOp<"uitofp">, Arguments<(ins AnyType:$in)> {
|
|||
// UnsignedDivIOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def UnsignedDivIOp : IntArithmeticOp<"divi_unsigned"> {
|
||||
def UnsignedDivIOp : IntBinaryOp<"divi_unsigned"> {
|
||||
let summary = "unsigned integer division operation";
|
||||
let description = [{
|
||||
Syntax:
|
||||
|
@ -3208,7 +3287,7 @@ def UnsignedDivIOp : IntArithmeticOp<"divi_unsigned"> {
|
|||
// UnsignedRemIOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def UnsignedRemIOp : IntArithmeticOp<"remi_unsigned"> {
|
||||
def UnsignedRemIOp : IntBinaryOp<"remi_unsigned"> {
|
||||
let summary = "unsigned integer division remainder operation";
|
||||
let description = [{
|
||||
Syntax:
|
||||
|
@ -3243,7 +3322,7 @@ def UnsignedRemIOp : IntArithmeticOp<"remi_unsigned"> {
|
|||
// UnsignedShiftRightOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def UnsignedShiftRightOp : IntArithmeticOp<"shift_right_unsigned"> {
|
||||
def UnsignedShiftRightOp : IntBinaryOp<"shift_right_unsigned"> {
|
||||
let summary = "unsigned integer right-shift";
|
||||
let description = [{
|
||||
The shift_right_unsigned operation shifts an integer value to the right by
|
||||
|
@ -3332,7 +3411,7 @@ def ViewOp : Std_Op<"view", [
|
|||
// XOrOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def XOrOp : IntArithmeticOp<"xor", [Commutative]> {
|
||||
def XOrOp : IntBinaryOp<"xor", [Commutative]> {
|
||||
let summary = "integer binary xor";
|
||||
let description = [{
|
||||
The `xor` operation takes two operands and returns one result, each of these
|
||||
|
|
|
@ -1662,6 +1662,7 @@ using DivFOpLowering = VectorConvertToLLVMPattern<DivFOp, LLVM::FDivOp>;
|
|||
using ExpOpLowering = VectorConvertToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
|
||||
using Exp2OpLowering = VectorConvertToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
|
||||
using FloorFOpLowering = VectorConvertToLLVMPattern<FloorFOp, LLVM::FFloorOp>;
|
||||
using FmaFOpLowering = VectorConvertToLLVMPattern<FmaFOp, LLVM::FMAOp>;
|
||||
using Log10OpLowering =
|
||||
VectorConvertToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
|
||||
using Log2OpLowering = VectorConvertToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
|
||||
|
@ -3775,6 +3776,7 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns(
|
|||
ExpOpLowering,
|
||||
Exp2OpLowering,
|
||||
FloorFOpLowering,
|
||||
FmaFOpLowering,
|
||||
GenericAtomicRMWOpLowering,
|
||||
LogOpLowering,
|
||||
Log10OpLowering,
|
||||
|
|
|
@ -158,6 +158,32 @@ static void printStandardBinaryOp(Operation *op, OpAsmPrinter &p) {
|
|||
p << " : " << op->getResult(0).getType();
|
||||
}
|
||||
|
||||
/// A custom ternary operation printer that omits the "std." prefix from the
|
||||
/// operation names.
|
||||
static void printStandardTernaryOp(Operation *op, OpAsmPrinter &p) {
|
||||
assert(op->getNumOperands() == 3 && "ternary op should have three operands");
|
||||
assert(op->getNumResults() == 1 && "ternary op should have one result");
|
||||
|
||||
// If not all the operand and result types are the same, just use the
|
||||
// generic assembly form to avoid omitting information in printing.
|
||||
auto resultType = op->getResult(0).getType();
|
||||
if (op->getOperand(0).getType() != resultType ||
|
||||
op->getOperand(1).getType() != resultType ||
|
||||
op->getOperand(2).getType() != resultType) {
|
||||
p.printGenericOp(op);
|
||||
return;
|
||||
}
|
||||
|
||||
int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
|
||||
p << op->getName().getStringRef().drop_front(stdDotLen) << ' '
|
||||
<< op->getOperand(0) << ", " << op->getOperand(1) << ", "
|
||||
<< op->getOperand(2);
|
||||
p.printOptionalAttrDict(op->getAttrs());
|
||||
|
||||
// Now we can output only one type for all operands and the result.
|
||||
p << " : " << op->getResult(0).getType();
|
||||
}
|
||||
|
||||
/// A custom cast operation printer that omits the "std." prefix from the
|
||||
/// operation names.
|
||||
static void printStandardCastOp(Operation *op, OpAsmPrinter &p) {
|
||||
|
|
|
@ -223,3 +223,16 @@ func @powf(%arg0 : f64) {
|
|||
%0 = math.powf %arg0, %arg0 : f64
|
||||
std.return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @fmaf(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: f32
|
||||
// CHECK-SAME: %[[ARG1:.*]]: vector<4xf32>
|
||||
func @fmaf(%arg0: f32, %arg1: vector<4xf32>) {
|
||||
// CHECK: %[[S:.*]] = "llvm.intr.fma"(%[[ARG0]], %[[ARG0]], %[[ARG0]]) : (f32, f32, f32) -> f32
|
||||
%0 = fmaf %arg0, %arg0, %arg0 : f32
|
||||
// CHECK: %[[V:.*]] = "llvm.intr.fma"(%[[ARG1]], %[[ARG1]], %[[ARG1]]) : (vector<4xf32>, vector<4xf32>, vector<4xf32>) -> vector<4xf32>
|
||||
%1 = fmaf %arg1, %arg1, %arg1 : vector<4xf32>
|
||||
std.return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue