forked from OSchip/llvm-project
[spirv] Add a canonicalizer for BitcastOp.
Convert chained `spirv::BitcastOp` operations into one `spirv::BitcastOp` operation. Closes tensorflow/mlir#238 COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/238 from denis0x0D:sandbox/canon_bitcast 4352ed4f81b959ec92f849c599e733b62a99c010 PiperOrigin-RevId: 281129234
This commit is contained in:
parent
563b5910a8
commit
6c77e59bfd
|
@ -98,6 +98,8 @@ def SPV_BitcastOp : SPV_Op<"Bitcast", [NoSideEffect]> {
|
|||
|
||||
let parser = [{ return mlir::impl::parseCastOp(parser, result); }];
|
||||
let printer = [{ mlir::impl::printCastOp(this->getOperation(), p); }];
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
// -----
|
||||
|
|
|
@ -652,8 +652,8 @@ static LogicalResult verify(spirv::AccessChainOp accessChainOp) {
|
|||
|
||||
namespace {
|
||||
|
||||
// Combine chained `spirv::AccessChainOp` operations into one
|
||||
// `spirv::AccessChainOp` operation.
|
||||
/// Combines chained `spirv::AccessChainOp` operations into one
|
||||
/// `spirv::AccessChainOp` operation.
|
||||
struct CombineChainedAccessChain
|
||||
: public OpRewritePattern<spirv::AccessChainOp> {
|
||||
using OpRewritePattern<spirv::AccessChainOp>::OpRewritePattern;
|
||||
|
@ -678,7 +678,7 @@ struct CombineChainedAccessChain
|
|||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
} // end anonymous namespace
|
||||
|
||||
void spirv::AccessChainOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList &results, MLIRContext *context) {
|
||||
|
@ -771,6 +771,35 @@ static LogicalResult verify(spirv::BitcastOp bitcastOp) {
|
|||
return success();
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
/// Converts chained `spirv::BitcastOp` operations into one
|
||||
/// `spirv::BitcastOp` operation.
|
||||
struct ConvertChainedBitcast : public OpRewritePattern<spirv::BitcastOp> {
|
||||
using OpRewritePattern<spirv::BitcastOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(spirv::BitcastOp bitcastOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto parentBitcastOp = dyn_cast_or_null<spirv::BitcastOp>(
|
||||
bitcastOp.operand()->getDefiningOp());
|
||||
|
||||
if (!parentBitcastOp) {
|
||||
return matchFailure();
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<spirv::BitcastOp>(
|
||||
/*valuesToRemoveIfDead=*/{parentBitcastOp.result()}, bitcastOp,
|
||||
bitcastOp.result()->getType(), parentBitcastOp.operand());
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
void spirv::BitcastOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList &results, MLIRContext *context) {
|
||||
results.insert<ConvertChainedBitcast>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.BitFieldInsert
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -2278,7 +2307,7 @@ PatternMatchResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
|
|||
|
||||
return matchSuccess();
|
||||
}
|
||||
} // namespace
|
||||
} // end anonymous namespace
|
||||
|
||||
void spirv::SelectionOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList &results, MLIRContext *context) {
|
||||
|
|
|
@ -134,6 +134,34 @@ func @deduplicate_composite_constant() -> (!spv.array<1 x vector<2xi32>>, !spv.a
|
|||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.Bitcast
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
func @convert_bitcast_full(%arg0 : vector<2xf32>) -> f64 {
|
||||
// CHECK: %[[RESULT:.*]] = spv.Bitcast {{%.*}} : vector<2xf32> to f64
|
||||
// CHECK-NEXT: spv.ReturnValue %[[RESULT]]
|
||||
%0 = spv.Bitcast %arg0 : vector<2xf32> to vector<2xi32>
|
||||
%1 = spv.Bitcast %0 : vector<2xi32> to i64
|
||||
%2 = spv.Bitcast %1 : i64 to f64
|
||||
spv.ReturnValue %2 : f64
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @convert_bitcast_multi_use(%arg0 : vector<2xf32>, %arg1 : !spv.ptr<i64, Uniform>) -> f64 {
|
||||
// CHECK: %[[RESULT_0:.*]] = spv.Bitcast {{%.*}} : vector<2xf32> to i64
|
||||
// CHECK-NEXT: %[[RESULT_1:.*]] = spv.Bitcast {{%.*}} : vector<2xf32> to f64
|
||||
// CHECK-NEXT: spv.Store {{".*"}} {{%.*}}, %[[RESULT_0]]
|
||||
// CHECK-NEXT: spv.ReturnValue %[[RESULT_1]]
|
||||
%0 = spv.Bitcast %arg0 : vector<2xf32> to i64
|
||||
%1 = spv.Bitcast %0 : i64 to f64
|
||||
spv.Store "Uniform" %arg1, %0 : i64
|
||||
spv.ReturnValue %1 : f64
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.selection
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
Loading…
Reference in New Issue