[mlir] Add fma operation to std dialect

Will remove `vector.fma` operation in the followup CLs.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D96801
This commit is contained in:
Eugene Zhulenev 2021-02-17 08:34:33 -08:00
parent fb19400d4e
commit 519f5917b4
4 changed files with 146 additions and 26 deletions

View File

@ -103,7 +103,7 @@ class FloatUnaryOp<string mnemonic, list<OpTrait> traits = []> :
// Base class for standard arithmetic operations. Requires operands and
// results to be of the same type, but does not constrain them to specific
// types. Individual classes will have `lhs` and `rhs` accessor to operands.
// types.
class ArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
Op<StandardOps_Dialect, mnemonic,
!listconcat(traits, [NoSideEffect,
@ -122,6 +122,32 @@ class ArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
}];
}
// Base class for standard binary arithmetic operations.
class ArithmeticBinaryOp<string mnemonic, list<OpTrait> traits = []> :
ArithmeticOp<mnemonic, traits> {
let parser = [{
return impl::parseOneResultSameOperandTypeOp(parser, result);
}];
let printer = [{
return printStandardBinaryOp(this->getOperation(), p);
}];
}
// Base class for standard ternary arithmetic operations.
class ArithmeticTernaryOp<string mnemonic, list<OpTrait> traits = []> :
ArithmeticOp<mnemonic, traits> {
let parser = [{
return impl::parseOneResultSameOperandTypeOp(parser, result);
}];
let printer = [{
return printStandardTernaryOp(this->getOperation(), p);
}];
}
// Base class for standard arithmetic operations on integers, vectors and
// tensors thereof. This operation takes two operands and returns one result,
// each of these is required to be of the same type. This type may be an
@ -130,8 +156,8 @@ class ArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
//
// <op>i %0, %1 : i32
//
class IntArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
ArithmeticOp<mnemonic,
class IntBinaryOp<string mnemonic, list<OpTrait> traits = []> :
ArithmeticBinaryOp<mnemonic,
!listconcat(traits,
[DeclareOpInterfaceMethods<VectorUnrollOpInterface>])>,
Arguments<(ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs)>;
@ -145,12 +171,27 @@ class IntArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
//
// <op>f %0, %1 : f32
//
class FloatArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
ArithmeticOp<mnemonic,
class FloatBinaryOp<string mnemonic, list<OpTrait> traits = []> :
ArithmeticBinaryOp<mnemonic,
!listconcat(traits,
[DeclareOpInterfaceMethods<VectorUnrollOpInterface>])>,
Arguments<(ins FloatLike:$lhs, FloatLike:$rhs)>;
// Base class for standard arithmetic ternary operations on floats, vectors and
// tensors thereof. This operation has three operands and returns one result,
// each of these is required to be of the same type. This type may be a
// floating point scalar type, a vector whose element type is a floating point
// type, or a floating point tensor. The custom assembly form of the operation
// is as follows
//
// <op> %0, %1, %2 : f32
//
class FloatTernaryOp<string mnemonic, list<OpTrait> traits = []> :
ArithmeticTernaryOp<mnemonic,
!listconcat(traits,
[DeclareOpInterfaceMethods<VectorUnrollOpInterface>])>,
Arguments<(ins FloatLike:$a, FloatLike:$b, FloatLike:$c)>;
// Base class for memref allocating ops: alloca and alloc.
//
// %0 = alloclike(%m)[%s] : memref<8x?xf32, (d0, d1)[s0] -> ((d0 + s0), d1)>
@ -257,7 +298,7 @@ def AbsFOp : FloatUnaryOp<"absf"> {
// AddFOp
//===----------------------------------------------------------------------===//
def AddFOp : FloatArithmeticOp<"addf"> {
def AddFOp : FloatBinaryOp<"addf"> {
let summary = "floating point addition operation";
let description = [{
Syntax:
@ -294,7 +335,7 @@ def AddFOp : FloatArithmeticOp<"addf"> {
// AddIOp
//===----------------------------------------------------------------------===//
def AddIOp : IntArithmeticOp<"addi", [Commutative]> {
def AddIOp : IntBinaryOp<"addi", [Commutative]> {
let summary = "integer addition operation";
let description = [{
Syntax:
@ -418,7 +459,7 @@ def AllocaOp : AllocLikeOp<"alloca", AutomaticAllocationScopeResource> {
// AndOp
//===----------------------------------------------------------------------===//
def AndOp : IntArithmeticOp<"and", [Commutative]> {
def AndOp : IntBinaryOp<"and", [Commutative]> {
let summary = "integer binary and";
let description = [{
Syntax:
@ -1269,7 +1310,7 @@ def ConstantOp : Std_Op<"constant",
// CopySignOp
//===----------------------------------------------------------------------===//
def CopySignOp : FloatArithmeticOp<"copysign"> {
def CopySignOp : FloatBinaryOp<"copysign"> {
let summary = "A copysign operation";
let description = [{
Syntax:
@ -1384,11 +1425,49 @@ def DimOp : Std_Op<"dim", [NoSideEffect]> {
// DivFOp
//===----------------------------------------------------------------------===//
def DivFOp : FloatArithmeticOp<"divf"> {
def DivFOp : FloatBinaryOp<"divf"> {
let summary = "floating point division operation";
let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
// FmaFOp
//===----------------------------------------------------------------------===//
def FmaFOp : FloatTernaryOp<"fmaf"> {
let summary = "floating point fused multipy-add operation";
let description = [{
Syntax:
```
operation ::= ssa-id `=` `std.fmaf` ssa-use `,` ssa-use `,` ssa-use `:` type
```
The `fmaf` operation takes three operands and returns one result, each of
these is required to be the same type. This type may be a floating point
scalar type, a vector whose element type is a floating point type, or a
floating point tensor.
Example:
```mlir
// Scalar fused multiply-add: d = a*b + c
%d = fmaf %a, %b, %c : f64
// SIMD vector fused multiply-add, e.g. for Intel SSE.
%i = fmaf %f, %g, %h : vector<4xf32>
// Tensor fused multiply-add.
%w = fmaf %x, %y, %z : tensor<4x?xbf16>
```
The semantics of the operation correspond to those of the `llvm.fma`
[intrinsic](https://llvm.org/docs/LangRef.html#llvm-fma-intrinsic). In the
particular case of lowering to LLVM, this is guaranteed to lower
to the `llvm.fma.*` intrinsic.
}];
}
//===----------------------------------------------------------------------===//
// FPExtOp
//===----------------------------------------------------------------------===//
@ -1854,7 +1933,7 @@ def MemRefReshapeOp: Std_Op<"memref_reshape", [
// MulFOp
//===----------------------------------------------------------------------===//
def MulFOp : FloatArithmeticOp<"mulf"> {
def MulFOp : FloatBinaryOp<"mulf"> {
let summary = "floating point multiplication operation";
let description = [{
Syntax:
@ -1891,7 +1970,7 @@ def MulFOp : FloatArithmeticOp<"mulf"> {
// MulIOp
//===----------------------------------------------------------------------===//
def MulIOp : IntArithmeticOp<"muli", [Commutative]> {
def MulIOp : IntBinaryOp<"muli", [Commutative]> {
let summary = "integer multiplication operation";
let hasFolder = 1;
}
@ -1933,7 +2012,7 @@ def NegFOp : FloatUnaryOp<"negf"> {
// OrOp
//===----------------------------------------------------------------------===//
def OrOp : IntArithmeticOp<"or", [Commutative]> {
def OrOp : IntBinaryOp<"or", [Commutative]> {
let summary = "integer binary or";
let description = [{
Syntax:
@ -2040,7 +2119,7 @@ def RankOp : Std_Op<"rank", [NoSideEffect]> {
// RemFOp
//===----------------------------------------------------------------------===//
def RemFOp : FloatArithmeticOp<"remf"> {
def RemFOp : FloatBinaryOp<"remf"> {
let summary = "floating point division remainder operation";
}
@ -2141,7 +2220,7 @@ def SelectOp : Std_Op<"select", [NoSideEffect,
// ShiftLeftOp
//===----------------------------------------------------------------------===//
def ShiftLeftOp : IntArithmeticOp<"shift_left"> {
def ShiftLeftOp : IntBinaryOp<"shift_left"> {
let summary = "integer left-shift";
let description = [{
The shift_left operation shifts an integer value to the left by a variable
@ -2161,7 +2240,7 @@ def ShiftLeftOp : IntArithmeticOp<"shift_left"> {
// SignedDivIOp
//===----------------------------------------------------------------------===//
def SignedDivIOp : IntArithmeticOp<"divi_signed"> {
def SignedDivIOp : IntBinaryOp<"divi_signed"> {
let summary = "signed integer division operation";
let description = [{
Syntax:
@ -2196,7 +2275,7 @@ def SignedDivIOp : IntArithmeticOp<"divi_signed"> {
// SignedFloorDivIOp
//===----------------------------------------------------------------------===//
def SignedFloorDivIOp : IntArithmeticOp<"floordivi_signed"> {
def SignedFloorDivIOp : IntBinaryOp<"floordivi_signed"> {
let summary = "signed floor integer division operation";
let description = [{
Syntax:
@ -2225,7 +2304,7 @@ def SignedFloorDivIOp : IntArithmeticOp<"floordivi_signed"> {
// SignedCeilDivIOp
//===----------------------------------------------------------------------===//
def SignedCeilDivIOp : IntArithmeticOp<"ceildivi_signed"> {
def SignedCeilDivIOp : IntBinaryOp<"ceildivi_signed"> {
let summary = "signed ceil integer division operation";
let description = [{
Syntax:
@ -2253,7 +2332,7 @@ def SignedCeilDivIOp : IntArithmeticOp<"ceildivi_signed"> {
// SignedRemIOp
//===----------------------------------------------------------------------===//
def SignedRemIOp : IntArithmeticOp<"remi_signed"> {
def SignedRemIOp : IntBinaryOp<"remi_signed"> {
let summary = "signed integer division remainder operation";
let description = [{
Syntax:
@ -2288,7 +2367,7 @@ def SignedRemIOp : IntArithmeticOp<"remi_signed"> {
// SignedShiftRightOp
//===----------------------------------------------------------------------===//
def SignedShiftRightOp : IntArithmeticOp<"shift_right_signed"> {
def SignedShiftRightOp : IntBinaryOp<"shift_right_signed"> {
let summary = "signed integer right-shift";
let description = [{
The shift_right_signed operation shifts an integer value to the right by
@ -2488,7 +2567,7 @@ def StoreOp : Std_Op<"store",
// SubFOp
//===----------------------------------------------------------------------===//
def SubFOp : FloatArithmeticOp<"subf"> {
def SubFOp : FloatBinaryOp<"subf"> {
let summary = "floating point subtraction operation";
let hasFolder = 1;
}
@ -2497,7 +2576,7 @@ def SubFOp : FloatArithmeticOp<"subf"> {
// SubIOp
//===----------------------------------------------------------------------===//
def SubIOp : IntArithmeticOp<"subi"> {
def SubIOp : IntBinaryOp<"subi"> {
let summary = "integer subtraction operation";
let hasFolder = 1;
}
@ -3173,7 +3252,7 @@ def UIToFPOp : ArithmeticCastOp<"uitofp">, Arguments<(ins AnyType:$in)> {
// UnsignedDivIOp
//===----------------------------------------------------------------------===//
def UnsignedDivIOp : IntArithmeticOp<"divi_unsigned"> {
def UnsignedDivIOp : IntBinaryOp<"divi_unsigned"> {
let summary = "unsigned integer division operation";
let description = [{
Syntax:
@ -3208,7 +3287,7 @@ def UnsignedDivIOp : IntArithmeticOp<"divi_unsigned"> {
// UnsignedRemIOp
//===----------------------------------------------------------------------===//
def UnsignedRemIOp : IntArithmeticOp<"remi_unsigned"> {
def UnsignedRemIOp : IntBinaryOp<"remi_unsigned"> {
let summary = "unsigned integer division remainder operation";
let description = [{
Syntax:
@ -3243,7 +3322,7 @@ def UnsignedRemIOp : IntArithmeticOp<"remi_unsigned"> {
// UnsignedShiftRightOp
//===----------------------------------------------------------------------===//
def UnsignedShiftRightOp : IntArithmeticOp<"shift_right_unsigned"> {
def UnsignedShiftRightOp : IntBinaryOp<"shift_right_unsigned"> {
let summary = "unsigned integer right-shift";
let description = [{
The shift_right_unsigned operation shifts an integer value to the right by
@ -3332,7 +3411,7 @@ def ViewOp : Std_Op<"view", [
// XOrOp
//===----------------------------------------------------------------------===//
def XOrOp : IntArithmeticOp<"xor", [Commutative]> {
def XOrOp : IntBinaryOp<"xor", [Commutative]> {
let summary = "integer binary xor";
let description = [{
The `xor` operation takes two operands and returns one result, each of these

View File

@ -1662,6 +1662,7 @@ using DivFOpLowering = VectorConvertToLLVMPattern<DivFOp, LLVM::FDivOp>;
using ExpOpLowering = VectorConvertToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
using Exp2OpLowering = VectorConvertToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
using FloorFOpLowering = VectorConvertToLLVMPattern<FloorFOp, LLVM::FFloorOp>;
using FmaFOpLowering = VectorConvertToLLVMPattern<FmaFOp, LLVM::FMAOp>;
using Log10OpLowering =
VectorConvertToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
using Log2OpLowering = VectorConvertToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
@ -3775,6 +3776,7 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns(
ExpOpLowering,
Exp2OpLowering,
FloorFOpLowering,
FmaFOpLowering,
GenericAtomicRMWOpLowering,
LogOpLowering,
Log10OpLowering,

View File

@ -158,6 +158,32 @@ static void printStandardBinaryOp(Operation *op, OpAsmPrinter &p) {
p << " : " << op->getResult(0).getType();
}
/// A custom ternary operation printer that omits the "std." prefix from the
/// operation names.
static void printStandardTernaryOp(Operation *op, OpAsmPrinter &p) {
assert(op->getNumOperands() == 3 && "ternary op should have three operands");
assert(op->getNumResults() == 1 && "ternary op should have one result");
// If not all the operand and result types are the same, just use the
// generic assembly form to avoid omitting information in printing.
auto resultType = op->getResult(0).getType();
if (op->getOperand(0).getType() != resultType ||
op->getOperand(1).getType() != resultType ||
op->getOperand(2).getType() != resultType) {
p.printGenericOp(op);
return;
}
int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
p << op->getName().getStringRef().drop_front(stdDotLen) << ' '
<< op->getOperand(0) << ", " << op->getOperand(1) << ", "
<< op->getOperand(2);
p.printOptionalAttrDict(op->getAttrs());
// Now we can output only one type for all operands and the result.
p << " : " << op->getResult(0).getType();
}
/// A custom cast operation printer that omits the "std." prefix from the
/// operation names.
static void printStandardCastOp(Operation *op, OpAsmPrinter &p) {

View File

@ -223,3 +223,16 @@ func @powf(%arg0 : f64) {
%0 = math.powf %arg0, %arg0 : f64
std.return
}
// -----
// CHECK-LABEL: func @fmaf(
// CHECK-SAME: %[[ARG0:.*]]: f32
// CHECK-SAME: %[[ARG1:.*]]: vector<4xf32>
func @fmaf(%arg0: f32, %arg1: vector<4xf32>) {
// CHECK: %[[S:.*]] = "llvm.intr.fma"(%[[ARG0]], %[[ARG0]], %[[ARG0]]) : (f32, f32, f32) -> f32
%0 = fmaf %arg0, %arg0, %arg0 : f32
// CHECK: %[[V:.*]] = "llvm.intr.fma"(%[[ARG1]], %[[ARG1]], %[[ARG1]]) : (vector<4xf32>, vector<4xf32>, vector<4xf32>) -> vector<4xf32>
%1 = fmaf %arg1, %arg1, %arg1 : vector<4xf32>
std.return
}