forked from OSchip/llvm-project
[mlir][spirv] Handle another form of folding comparsion into clamp
Reviewed By: hanchung Differential Revision: https://reviews.llvm.org/D121227
This commit is contained in:
parent
ffb410d3f9
commit
55a4df9c14
|
@ -40,8 +40,7 @@ def ConvertLogicalNotOfLogicalNotEqual : Pat<
|
|||
(SPV_LogicalEqualOp $lhs, $rhs)>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Re-write spv.Select + spv.<less_than_op> to a suitable variant of
|
||||
// spv.<glsl_clamp_op>
|
||||
// spv.Select -> spv.GLSL.*Clamp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def ValuesAreEqual : Constraint<CPred<"$0 == $1">>;
|
||||
|
@ -53,7 +52,9 @@ foreach CmpClampPair = [
|
|||
[SPV_SLessThanEqualOp, SPV_GLSLSClampOp],
|
||||
[SPV_ULessThanOp, SPV_GLSLUClampOp],
|
||||
[SPV_ULessThanEqualOp, SPV_GLSLUClampOp]] in {
|
||||
def ConvertComparisonIntoClamp#CmpClampPair[0] : Pat<
|
||||
|
||||
// Detect: $min < $input, $input < $max
|
||||
def ConvertComparisonIntoClamp1_#CmpClampPair[0] : Pat<
|
||||
(SPV_SelectOp
|
||||
(CmpClampPair[0]
|
||||
(SPV_SelectOp:$middle0
|
||||
|
@ -67,4 +68,16 @@ def ConvertComparisonIntoClamp#CmpClampPair[0] : Pat<
|
|||
$max),
|
||||
(CmpClampPair[1] $input, $min, $max),
|
||||
[(ValuesAreEqual $middle0, $middle1)]>;
|
||||
|
||||
// Detect: $input < $min, $max < $input
|
||||
def ConvertComparisonIntoClamp2_#CmpClampPair[0] : Pat<
|
||||
(SPV_SelectOp
|
||||
(CmpClampPair[0] $max, $input),
|
||||
$max,
|
||||
(SPV_SelectOp
|
||||
(CmpClampPair[0] $input, $min),
|
||||
$min,
|
||||
$input
|
||||
)),
|
||||
(CmpClampPair[1] $input, $min, $max)>;
|
||||
}
|
||||
|
|
|
@ -23,12 +23,18 @@ namespace {
|
|||
namespace mlir {
|
||||
namespace spirv {
|
||||
void populateSPIRVGLSLCanonicalizationPatterns(RewritePatternSet &results) {
|
||||
results.add<ConvertComparisonIntoClampSPV_FOrdLessThanOp,
|
||||
ConvertComparisonIntoClampSPV_FOrdLessThanEqualOp,
|
||||
ConvertComparisonIntoClampSPV_SLessThanOp,
|
||||
ConvertComparisonIntoClampSPV_SLessThanEqualOp,
|
||||
ConvertComparisonIntoClampSPV_ULessThanOp,
|
||||
ConvertComparisonIntoClampSPV_ULessThanEqualOp>(
|
||||
results.add<ConvertComparisonIntoClamp1_SPV_FOrdLessThanOp,
|
||||
ConvertComparisonIntoClamp1_SPV_FOrdLessThanEqualOp,
|
||||
ConvertComparisonIntoClamp1_SPV_SLessThanOp,
|
||||
ConvertComparisonIntoClamp1_SPV_SLessThanEqualOp,
|
||||
ConvertComparisonIntoClamp1_SPV_ULessThanOp,
|
||||
ConvertComparisonIntoClamp1_SPV_ULessThanEqualOp,
|
||||
ConvertComparisonIntoClamp2_SPV_FOrdLessThanOp,
|
||||
ConvertComparisonIntoClamp2_SPV_FOrdLessThanEqualOp,
|
||||
ConvertComparisonIntoClamp2_SPV_SLessThanOp,
|
||||
ConvertComparisonIntoClamp2_SPV_SLessThanEqualOp,
|
||||
ConvertComparisonIntoClamp2_SPV_ULessThanOp,
|
||||
ConvertComparisonIntoClamp2_SPV_ULessThanEqualOp>(
|
||||
results.getContext());
|
||||
}
|
||||
} // namespace spirv
|
||||
|
|
|
@ -1,12 +1,8 @@
|
|||
// RUN: mlir-opt -split-input-file -spirv-canonicalize-glsl %s | FileCheck %s
|
||||
|
||||
// CHECK: func @clamp_fordlessthan(%[[INPUT:.*]]: f32)
|
||||
func @clamp_fordlessthan(%input: f32) -> f32 {
|
||||
// CHECK: %[[MIN:.*]] = spv.Constant
|
||||
%min = spv.Constant 0.5 : f32
|
||||
// CHECK: %[[MAX:.*]] = spv.Constant
|
||||
%max = spv.Constant 1.0 : f32
|
||||
|
||||
// CHECK-LABEL: func @clamp_fordlessthan
|
||||
// CHECK-SAME: (%[[INPUT:.*]]: f32, %[[MIN:.*]]: f32, %[[MAX:.*]]: f32)
|
||||
func @clamp_fordlessthan(%input: f32, %min: f32, %max: f32) -> f32 {
|
||||
// CHECK: [[RES:%.*]] = spv.GLSL.FClamp %[[INPUT]], %[[MIN]], %[[MAX]]
|
||||
%0 = spv.FOrdLessThan %min, %input : f32
|
||||
%mid = spv.Select %0, %input, %min : i1, f32
|
||||
|
@ -19,13 +15,24 @@ func @clamp_fordlessthan(%input: f32) -> f32 {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK: func @clamp_fordlessthanequal(%[[INPUT:.*]]: f32)
|
||||
func @clamp_fordlessthanequal(%input: f32) -> f32 {
|
||||
// CHECK: %[[MIN:.*]] = spv.Constant
|
||||
%min = spv.Constant 0.5 : f32
|
||||
// CHECK: %[[MAX:.*]] = spv.Constant
|
||||
%max = spv.Constant 1.0 : f32
|
||||
// CHECK-LABEL: func @clamp_fordlessthan
|
||||
// CHECK-SAME: (%[[INPUT:.*]]: f32, %[[MIN:.*]]: f32, %[[MAX:.*]]: f32)
|
||||
func @clamp_fordlessthan(%input: f32, %min: f32, %max: f32) -> f32 {
|
||||
// CHECK: [[RES:%.*]] = spv.GLSL.FClamp %[[INPUT]], %[[MIN]], %[[MAX]]
|
||||
%0 = spv.FOrdLessThan %input, %min : f32
|
||||
%mid = spv.Select %0, %min, %input : i1, f32
|
||||
%1 = spv.FOrdLessThan %max, %input : f32
|
||||
%2 = spv.Select %1, %max, %mid : i1, f32
|
||||
|
||||
// CHECK-NEXT: spv.ReturnValue [[RES]]
|
||||
spv.ReturnValue %2 : f32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @clamp_fordlessthanequal
|
||||
// CHECK-SAME: (%[[INPUT:.*]]: f32, %[[MIN:.*]]: f32, %[[MAX:.*]]: f32)
|
||||
func @clamp_fordlessthanequal(%input: f32, %min: f32, %max: f32) -> f32 {
|
||||
// CHECK: [[RES:%.*]] = spv.GLSL.FClamp %[[INPUT]], %[[MIN]], %[[MAX]]
|
||||
%0 = spv.FOrdLessThanEqual %min, %input : f32
|
||||
%mid = spv.Select %0, %input, %min : i1, f32
|
||||
|
@ -38,13 +45,24 @@ func @clamp_fordlessthanequal(%input: f32) -> f32 {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK: func @clamp_slessthan(%[[INPUT:.*]]: si32)
|
||||
func @clamp_slessthan(%input: si32) -> si32 {
|
||||
// CHECK: %[[MIN:.*]] = spv.Constant
|
||||
%min = spv.Constant 0 : si32
|
||||
// CHECK: %[[MAX:.*]] = spv.Constant
|
||||
%max = spv.Constant 10 : si32
|
||||
// CHECK-LABEL: func @clamp_fordlessthanequal
|
||||
// CHECK-SAME: (%[[INPUT:.*]]: f32, %[[MIN:.*]]: f32, %[[MAX:.*]]: f32)
|
||||
func @clamp_fordlessthanequal(%input: f32, %min: f32, %max: f32) -> f32 {
|
||||
// CHECK: [[RES:%.*]] = spv.GLSL.FClamp %[[INPUT]], %[[MIN]], %[[MAX]]
|
||||
%0 = spv.FOrdLessThanEqual %input, %min : f32
|
||||
%mid = spv.Select %0, %min, %input : i1, f32
|
||||
%1 = spv.FOrdLessThanEqual %max, %input : f32
|
||||
%2 = spv.Select %1, %max, %mid : i1, f32
|
||||
|
||||
// CHECK-NEXT: spv.ReturnValue [[RES]]
|
||||
spv.ReturnValue %2 : f32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @clamp_slessthan
|
||||
// CHECK-SAME: (%[[INPUT:.*]]: si32, %[[MIN:.*]]: si32, %[[MAX:.*]]: si32)
|
||||
func @clamp_slessthan(%input: si32, %min: si32, %max: si32) -> si32 {
|
||||
// CHECK: [[RES:%.*]] = spv.GLSL.SClamp %[[INPUT]], %[[MIN]], %[[MAX]]
|
||||
%0 = spv.SLessThan %min, %input : si32
|
||||
%mid = spv.Select %0, %input, %min : i1, si32
|
||||
|
@ -57,13 +75,24 @@ func @clamp_slessthan(%input: si32) -> si32 {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK: func @clamp_slessthanequal(%[[INPUT:.*]]: si32)
|
||||
func @clamp_slessthanequal(%input: si32) -> si32 {
|
||||
// CHECK: %[[MIN:.*]] = spv.Constant
|
||||
%min = spv.Constant 0 : si32
|
||||
// CHECK: %[[MAX:.*]] = spv.Constant
|
||||
%max = spv.Constant 10 : si32
|
||||
// CHECK-LABEL: func @clamp_slessthan
|
||||
// CHECK-SAME: (%[[INPUT:.*]]: si32, %[[MIN:.*]]: si32, %[[MAX:.*]]: si32)
|
||||
func @clamp_slessthan(%input: si32, %min: si32, %max: si32) -> si32 {
|
||||
// CHECK: [[RES:%.*]] = spv.GLSL.SClamp %[[INPUT]], %[[MIN]], %[[MAX]]
|
||||
%0 = spv.SLessThan %input, %min : si32
|
||||
%mid = spv.Select %0, %min, %input : i1, si32
|
||||
%1 = spv.SLessThan %max, %input : si32
|
||||
%2 = spv.Select %1, %max, %mid : i1, si32
|
||||
|
||||
// CHECK-NEXT: spv.ReturnValue [[RES]]
|
||||
spv.ReturnValue %2 : si32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @clamp_slessthanequal
|
||||
// CHECK-SAME: (%[[INPUT:.*]]: si32, %[[MIN:.*]]: si32, %[[MAX:.*]]: si32)
|
||||
func @clamp_slessthanequal(%input: si32, %min: si32, %max: si32) -> si32 {
|
||||
// CHECK: [[RES:%.*]] = spv.GLSL.SClamp %[[INPUT]], %[[MIN]], %[[MAX]]
|
||||
%0 = spv.SLessThanEqual %min, %input : si32
|
||||
%mid = spv.Select %0, %input, %min : i1, si32
|
||||
|
@ -76,13 +105,24 @@ func @clamp_slessthanequal(%input: si32) -> si32 {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK: func @clamp_ulessthan(%[[INPUT:.*]]: i32)
|
||||
func @clamp_ulessthan(%input: i32) -> i32 {
|
||||
// CHECK: %[[MIN:.*]] = spv.Constant
|
||||
%min = spv.Constant 0 : i32
|
||||
// CHECK: %[[MAX:.*]] = spv.Constant
|
||||
%max = spv.Constant 10 : i32
|
||||
// CHECK-LABEL: func @clamp_slessthanequal
|
||||
// CHECK-SAME: (%[[INPUT:.*]]: si32, %[[MIN:.*]]: si32, %[[MAX:.*]]: si32)
|
||||
func @clamp_slessthanequal(%input: si32, %min: si32, %max: si32) -> si32 {
|
||||
// CHECK: [[RES:%.*]] = spv.GLSL.SClamp %[[INPUT]], %[[MIN]], %[[MAX]]
|
||||
%0 = spv.SLessThanEqual %input, %min : si32
|
||||
%mid = spv.Select %0, %min, %input : i1, si32
|
||||
%1 = spv.SLessThanEqual %max, %input : si32
|
||||
%2 = spv.Select %1, %max, %mid : i1, si32
|
||||
|
||||
// CHECK-NEXT: spv.ReturnValue [[RES]]
|
||||
spv.ReturnValue %2 : si32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @clamp_ulessthan
|
||||
// CHECK-SAME: (%[[INPUT:.*]]: i32, %[[MIN:.*]]: i32, %[[MAX:.*]]: i32)
|
||||
func @clamp_ulessthan(%input: i32, %min: i32, %max: i32) -> i32 {
|
||||
// CHECK: [[RES:%.*]] = spv.GLSL.UClamp %[[INPUT]], %[[MIN]], %[[MAX]]
|
||||
%0 = spv.ULessThan %min, %input : i32
|
||||
%mid = spv.Select %0, %input, %min : i1, i32
|
||||
|
@ -95,13 +135,24 @@ func @clamp_ulessthan(%input: i32) -> i32 {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK: func @clamp_ulessthanequal(%[[INPUT:.*]]: i32)
|
||||
func @clamp_ulessthanequal(%input: i32) -> i32 {
|
||||
// CHECK: %[[MIN:.*]] = spv.Constant
|
||||
%min = spv.Constant 0 : i32
|
||||
// CHECK: %[[MAX:.*]] = spv.Constant
|
||||
%max = spv.Constant 10 : i32
|
||||
// CHECK-LABEL: func @clamp_ulessthan
|
||||
// CHECK-SAME: (%[[INPUT:.*]]: i32, %[[MIN:.*]]: i32, %[[MAX:.*]]: i32)
|
||||
func @clamp_ulessthan(%input: i32, %min: i32, %max: i32) -> i32 {
|
||||
// CHECK: [[RES:%.*]] = spv.GLSL.UClamp %[[INPUT]], %[[MIN]], %[[MAX]]
|
||||
%0 = spv.ULessThan %input, %min : i32
|
||||
%mid = spv.Select %0, %min, %input : i1, i32
|
||||
%1 = spv.ULessThan %max, %input : i32
|
||||
%2 = spv.Select %1, %max, %mid : i1, i32
|
||||
|
||||
// CHECK-NEXT: spv.ReturnValue [[RES]]
|
||||
spv.ReturnValue %2 : i32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @clamp_ulessthanequal
|
||||
// CHECK-SAME: (%[[INPUT:.*]]: i32, %[[MIN:.*]]: i32, %[[MAX:.*]]: i32)
|
||||
func @clamp_ulessthanequal(%input: i32, %min: i32, %max: i32) -> i32 {
|
||||
// CHECK: [[RES:%.*]] = spv.GLSL.UClamp %[[INPUT]], %[[MIN]], %[[MAX]]
|
||||
%0 = spv.ULessThanEqual %min, %input : i32
|
||||
%mid = spv.Select %0, %input, %min : i1, i32
|
||||
|
@ -111,3 +162,18 @@ func @clamp_ulessthanequal(%input: i32) -> i32 {
|
|||
// CHECK-NEXT: spv.ReturnValue [[RES]]
|
||||
spv.ReturnValue %2 : i32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @clamp_ulessthanequal
|
||||
// CHECK-SAME: (%[[INPUT:.*]]: i32, %[[MIN:.*]]: i32, %[[MAX:.*]]: i32)
|
||||
func @clamp_ulessthanequal(%input: i32, %min: i32, %max: i32) -> i32 {
|
||||
// CHECK: [[RES:%.*]] = spv.GLSL.UClamp %[[INPUT]], %[[MIN]], %[[MAX]]
|
||||
%0 = spv.ULessThanEqual %input, %min : i32
|
||||
%mid = spv.Select %0, %min, %input : i1, i32
|
||||
%1 = spv.ULessThanEqual %max, %input : i32
|
||||
%2 = spv.Select %1, %max, %mid : i1, i32
|
||||
|
||||
// CHECK-NEXT: spv.ReturnValue [[RES]]
|
||||
spv.ReturnValue %2 : i32
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue