Add operations needed to support lowering of AffineExpr to SPIR-V.

Lowering of CmpIOp, DivISOp, RemISOp, SubIOp and SelectOp to SPIR-V
dialect enables the lowering of operations generated by AffineExpr ->
StandardOps conversion into the SPIR-V dialect.

PiperOrigin-RevId: 280039204
This commit is contained in:
Mahesh Ravishankar 2019-11-12 13:19:33 -08:00 committed by A. Unique TensorFlower
parent 8082e3a687
commit 2be53603e9
2 changed files with 116 additions and 15 deletions

View File

@ -314,8 +314,9 @@ public:
return matchFailure();
}
// Use the bitwidth set in the value attribute to decide the result type of
// the SPIR-V constant operation since SPIR-V does not support index types.
// Use the bitwidth set in the value attribute to decide the result type
// of the SPIR-V constant operation since SPIR-V does not support index
// types.
auto constVal = constAttr.getValue();
auto constValType = constAttr.getType().dyn_cast<IndexType>();
if (!constValType) {
@ -331,11 +332,47 @@ public:
}
};
/// Convert integer binary operations to SPIR-V operations. Cannot use tablegen
/// for this. If the integer operation is on variables of IndexType, the type of
/// the return value of the replacement operation differs from that of the
/// replaced operation. This is not handled in tablegen-based pattern
/// specification.
/// Convert compare operation to SPIR-V dialect.
class CmpIOpConversion final : public ConversionPattern {
public:
CmpIOpConversion(MLIRContext *context)
: ConversionPattern(CmpIOp::getOperationName(), 1, context) {}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
auto cmpIOp = cast<CmpIOp>(op);
CmpIOpOperandAdaptor cmpIOpOperands(operands);
switch (cmpIOp.getPredicate()) {
#define DISPATCH(cmpPredicate, spirvOp) \
case cmpPredicate: \
rewriter.replaceOpWithNewOp<spirvOp>(op, op->getResult(0)->getType(), \
cmpIOpOperands.lhs(), \
cmpIOpOperands.rhs()); \
return matchSuccess();
DISPATCH(CmpIPredicate::EQ, spirv::IEqualOp);
DISPATCH(CmpIPredicate::NE, spirv::INotEqualOp);
DISPATCH(CmpIPredicate::SLT, spirv::SLessThanOp);
DISPATCH(CmpIPredicate::SLE, spirv::SLessThanEqualOp);
DISPATCH(CmpIPredicate::SGT, spirv::SGreaterThanOp);
DISPATCH(CmpIPredicate::SGE, spirv::SGreaterThanEqualOp);
#undef DISPATCH
default:
break;
}
return matchFailure();
}
};
/// Convert integer binary operations to SPIR-V operations. Cannot use
/// tablegen for this. If the integer operation is on variables of IndexType,
/// the type of the return value of the replacement operation differs from
/// that of the replaced operation. This is not handled in tablegen-based
/// pattern specification.
template <typename StdOp, typename SPIRVOp>
class IntegerOpConversion final : public ConversionPattern {
public:
@ -396,9 +433,25 @@ public:
}
};
/// Convert store -> spv.StoreOp. The operands of the replaced operation are of
/// IndexType while that of the replacement operation are of type i32. This is
/// not supported in tablegen based pattern specification.
/// Convert select -> spv.Select
class SelectOpConversion : public ConversionPattern {
public:
SelectOpConversion(MLIRContext *context)
: ConversionPattern(SelectOp::getOperationName(), 1, context) {}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
SelectOpOperandAdaptor selectOperands(operands);
rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, selectOperands.condition(),
selectOperands.true_value(),
selectOperands.false_value());
return matchSuccess();
}
};
/// Convert store -> spv.StoreOp. The operands of the replaced operation are
/// of IndexType while that of the replacement operation are of type i32. This
/// is not supported in tablegen based pattern specification.
// TODO(ravishankarm) : These could potentially be templated on the operation
// being converted, since the same logic should work for linalg.store.
class StoreOpConversion final : public ConversionPattern {
@ -437,9 +490,14 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
OwningRewritePatternList &patterns) {
populateWithGenerated(context, &patterns);
// Add the return op conversion.
patterns.insert<ConstantIndexOpConversion,
IntegerOpConversion<AddIOp, spirv::IAddOp>,
IntegerOpConversion<MulIOp, spirv::IMulOp>, LoadOpConversion,
ReturnToSPIRVConversion, StoreOpConversion>(context);
patterns
.insert<ConstantIndexOpConversion, CmpIOpConversion,
IntegerOpConversion<AddIOp, spirv::IAddOp>,
IntegerOpConversion<MulIOp, spirv::IMulOp>,
IntegerOpConversion<DivISOp, spirv::SDivOp>,
IntegerOpConversion<RemISOp, spirv::SModOp>,
IntegerOpConversion<SubIOp, spirv::ISubOp>, LoadOpConversion,
ReturnToSPIRVConversion, SelectOpConversion, StoreOpConversion>(
context);
}
} // namespace mlir

View File

@ -58,3 +58,46 @@ func @constval() {
%4 = constant 1 : index
return
}
// CHECK-LABEL: @cmpiop
func @cmpiop(%arg0 : i32, %arg1 : i32) {
// CHECK: spv.IEqual
%0 = cmpi "eq", %arg0, %arg1 : i32
// CHECK: spv.INotEqual
%1 = cmpi "ne", %arg0, %arg1 : i32
// CHECK: spv.SLessThan
%2 = cmpi "slt", %arg0, %arg1 : i32
// CHECK: spv.SLessThanEqual
%3 = cmpi "sle", %arg0, %arg1 : i32
// CHECK: spv.SGreaterThan
%4 = cmpi "sgt", %arg0, %arg1 : i32
// CHECK: spv.SGreaterThanEqual
%5 = cmpi "sge", %arg0, %arg1 : i32
return
}
// CHECK-LABEL: @select
func @selectOp(%arg0 : i32, %arg1 : i32) {
%0 = cmpi "sle", %arg0, %arg1 : i32
// CHECK: spv.Select
%1 = select %0, %arg0, %arg1 : i32
return
}
// CHECK-LABEL: @div_rem
func @div_rem(%arg0 : i32, %arg1 : i32) {
// CHECK: spv.SDiv
%0 = divis %arg0, %arg1 : i32
// CHECK: spv.SMod
%1 = remis %arg0, %arg1 : i32
return
}
// CHECK-LABEL: @add_sub
func @add_sub(%arg0 : i32, %arg1 : i32) {
// CHECK: spv.IAdd
%0 = addi %arg0, %arg1 : i32
// CHECK: spv.ISub
%1 = subi %arg0, %arg1 : i32
return
}