[spirv] Add a canonicalizer for BitcastOp.

Convert chained `spirv::BitcastOp` operations into
one `spirv::BitcastOp` operation.

Closes tensorflow/mlir#238

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/238 from denis0x0D:sandbox/canon_bitcast 4352ed4f81b959ec92f849c599e733b62a99c010
PiperOrigin-RevId: 281129234
This commit is contained in:
Denis Khalikov 2019-11-18 12:36:16 -08:00 committed by A. Unique TensorFlower
parent 563b5910a8
commit 6c77e59bfd
3 changed files with 63 additions and 4 deletions

View File

@ -98,6 +98,8 @@ def SPV_BitcastOp : SPV_Op<"Bitcast", [NoSideEffect]> {
let parser = [{ return mlir::impl::parseCastOp(parser, result); }];
let printer = [{ mlir::impl::printCastOp(this->getOperation(), p); }];
let hasCanonicalizer = 1;
}
// -----

View File

@ -652,8 +652,8 @@ static LogicalResult verify(spirv::AccessChainOp accessChainOp) {
namespace {
// Combine chained `spirv::AccessChainOp` operations into one
// `spirv::AccessChainOp` operation.
/// Combines chained `spirv::AccessChainOp` operations into one
/// `spirv::AccessChainOp` operation.
struct CombineChainedAccessChain
: public OpRewritePattern<spirv::AccessChainOp> {
using OpRewritePattern<spirv::AccessChainOp>::OpRewritePattern;
@ -678,7 +678,7 @@ struct CombineChainedAccessChain
return matchSuccess();
}
};
} // namespace
} // end anonymous namespace
void spirv::AccessChainOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
@ -771,6 +771,35 @@ static LogicalResult verify(spirv::BitcastOp bitcastOp) {
return success();
}
namespace {
/// Converts chained `spirv::BitcastOp` operations into one
/// `spirv::BitcastOp` operation.
struct ConvertChainedBitcast : public OpRewritePattern<spirv::BitcastOp> {
using OpRewritePattern<spirv::BitcastOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(spirv::BitcastOp bitcastOp,
PatternRewriter &rewriter) const override {
auto parentBitcastOp = dyn_cast_or_null<spirv::BitcastOp>(
bitcastOp.operand()->getDefiningOp());
if (!parentBitcastOp) {
return matchFailure();
}
rewriter.replaceOpWithNewOp<spirv::BitcastOp>(
/*valuesToRemoveIfDead=*/{parentBitcastOp.result()}, bitcastOp,
bitcastOp.result()->getType(), parentBitcastOp.operand());
return matchSuccess();
}
};
} // end anonymous namespace
void spirv::BitcastOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<ConvertChainedBitcast>(context);
}
//===----------------------------------------------------------------------===//
// spv.BitFieldInsert
//===----------------------------------------------------------------------===//
@ -2278,7 +2307,7 @@ PatternMatchResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
return matchSuccess();
}
} // namespace
} // end anonymous namespace
void spirv::SelectionOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {

View File

@ -134,6 +134,34 @@ func @deduplicate_composite_constant() -> (!spv.array<1 x vector<2xi32>>, !spv.a
// -----
//===----------------------------------------------------------------------===//
// spv.Bitcast
//===----------------------------------------------------------------------===//
func @convert_bitcast_full(%arg0 : vector<2xf32>) -> f64 {
// CHECK: %[[RESULT:.*]] = spv.Bitcast {{%.*}} : vector<2xf32> to f64
// CHECK-NEXT: spv.ReturnValue %[[RESULT]]
%0 = spv.Bitcast %arg0 : vector<2xf32> to vector<2xi32>
%1 = spv.Bitcast %0 : vector<2xi32> to i64
%2 = spv.Bitcast %1 : i64 to f64
spv.ReturnValue %2 : f64
}
// -----
func @convert_bitcast_multi_use(%arg0 : vector<2xf32>, %arg1 : !spv.ptr<i64, Uniform>) -> f64 {
// CHECK: %[[RESULT_0:.*]] = spv.Bitcast {{%.*}} : vector<2xf32> to i64
// CHECK-NEXT: %[[RESULT_1:.*]] = spv.Bitcast {{%.*}} : vector<2xf32> to f64
// CHECK-NEXT: spv.Store {{".*"}} {{%.*}}, %[[RESULT_0]]
// CHECK-NEXT: spv.ReturnValue %[[RESULT_1]]
%0 = spv.Bitcast %arg0 : vector<2xf32> to i64
%1 = spv.Bitcast %0 : i64 to f64
spv.Store "Uniform" %arg1, %0 : i64
spv.ReturnValue %1 : f64
}
// -----
//===----------------------------------------------------------------------===//
// spv.selection
//===----------------------------------------------------------------------===//