From 5c56a45cced1b4825c9928210d466a6ff5116cc8 Mon Sep 17 00:00:00 2001 From: Margaret_wangrui Date: Tue, 3 Jan 2023 14:21:07 +0800 Subject: [PATCH] Fixed the bug of cast eliminate. --- .../optimizer/irpass/cast_eliminate.cc | 68 ++++++++++++++++++- .../optimizer/irpass/cast_eliminate.h | 21 +++++- tests/st/ops/gpu/test_cast_op.py | 26 ++++++- tests/ut/cpp/optimizer/lib_test.cc | 8 --- .../gtest_input/optimizer/opt_test.py | 16 ----- 5 files changed, 109 insertions(+), 30 deletions(-) diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/cast_eliminate.cc b/mindspore/ccsrc/frontend/optimizer/irpass/cast_eliminate.cc index 501fe3e8bc4..e6ea4d96283 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/cast_eliminate.cc +++ b/mindspore/ccsrc/frontend/optimizer/irpass/cast_eliminate.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2023 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -86,12 +86,71 @@ void CastSameTypeEliminater::Visit(const AnfNodePtr &node) { } } +bool TwoCastEliminater::CheckTwoTypes(const std::map<TypeId, int> &type_map, TypeId type1, TypeId type2) { + auto type1_iter = type_map.find(type1); + auto type2_iter = type_map.find(type2); + if (type1_iter != type_map.end() && type2_iter != type_map.end()) { + return type1_iter->second <= type2_iter->second; + } + return false; +} + +bool TwoCastEliminater::CheckThreeTypes(const std::map<TypeId, int> &type_map, TypeId type1, TypeId type2, + TypeId type3) { + auto type1_iter = type_map.find(type1); + auto type2_iter = type_map.find(type2); + auto type3_iter = type_map.find(type3); + if (type1_iter != type_map.end() && type2_iter != type_map.end() && type3_iter != type_map.end()) { + return type1_iter->second <= type2_iter->second && type2_iter->second <= type3_iter->second; + } + return false; +} + +// {prim::kPrimCast, {prim::kPrimCast, X, Y}, T} -> {prim::kPrimCast, X, T} +// x_type <= y_type <= t_type or x_type >= y_type >= t_type +bool TwoCastEliminater::CheckTypesIsIncrementalOrDecreasing() { + auto x_type = x_->Type(); + if (x_type->isa<TensorType>()) { + x_type = x_type->cast<TensorTypePtr>()->element(); + } + + auto y_type = GetValueNode<TypePtr>(y_); + MS_EXCEPTION_IF_NULL(y_type); + if (y_type->isa<TensorType>()) { + y_type = y_type->cast<TensorTypePtr>()->element(); + } + + auto t_type = GetValueNode<TypePtr>(t_); + MS_EXCEPTION_IF_NULL(t_type); + if (t_type->isa<TensorType>()) { + t_type = t_type->cast<TensorTypePtr>()->element(); + } + auto x_type_id = x_type->type_id(); + auto y_type_id = y_type->type_id(); + auto t_type_id = t_type->type_id(); + if (y_type_id == t_type_id) { + return true; + } + // If the precision is incremental or decreasing, the cast can be eliminated. 
+ // x_type <= y_type + bool incremental = CheckTwoTypes(int_map_, x_type_id, y_type_id) || CheckTwoTypes(uint_map_, x_type_id, y_type_id) || + CheckTwoTypes(float_map_, x_type_id, y_type_id); + // x_type >= y_type >= t_type + bool decreasing = CheckThreeTypes(int_map_, t_type_id, y_type_id, x_type_id) || + CheckThreeTypes(uint_map_, t_type_id, y_type_id, x_type_id) || + CheckThreeTypes(float_map_, t_type_id, y_type_id, x_type_id); + return incremental || decreasing; +} + // {prim::kPrimCast, {prim::kPrimCast, X, Y}, T} AnfNodePtr TwoCastEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) { Reset(); AnfVisitor::Match(prim::kPrimCast, {IsCNode, IsNode})(node); - if (x_ != nullptr && t_ != nullptr) { + if (x_ == nullptr || t_ == nullptr || y_ == nullptr) { + return nullptr; + } + if (CheckTypesIsIncrementalOrDecreasing()) { auto cast_op = python_adapter::GetPyFn("mindspore.ops.operations", "Cast")(); ValuePtr cast = parse::data_converter::PyDataToValue(cast_op); auto cnode = NewCNode({NewValueNode(cast), x_, t_}, node->func_graph()); @@ -106,10 +165,13 @@ void TwoCastEliminater::Visit(const AnfNodePtr &node) { auto cnode = node->cast<CNodePtr>(); // {prim::kPrimCast, X, Y} constexpr size_t cast_size = 3; + constexpr size_t cast_data_index = 1; + constexpr size_t cast_type_index = 2; if (cnode->size() != cast_size) { return; } - x_ = cnode->input(1); + x_ = cnode->input(cast_data_index); + y_ = cnode->input(cast_type_index); } else { t_ = node; } diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/cast_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/cast_eliminate.h index 46c86ab7167..1703d321f7e 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/cast_eliminate.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/cast_eliminate.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2023 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except 
in compliance with the License. @@ -17,6 +17,7 @@ #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_ #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_ +#include <map> #include "frontend/optimizer/anf_visitor.h" #include "frontend/optimizer/irpass.h" #include "frontend/optimizer/optimizer.h" @@ -46,10 +47,26 @@ class TwoCastEliminater : public AnfVisitor { void Reset() { x_ = nullptr; t_ = nullptr; + y_ = nullptr; } private: - AnfNodePtr x_{nullptr}, t_{nullptr}; + bool CheckTypesIsIncrementalOrDecreasing(); + bool CheckTwoTypes(const std::map<TypeId, int> &type_map, TypeId type1, TypeId type2); + bool CheckThreeTypes(const std::map<TypeId, int> &type_map, TypeId type1, TypeId type2, TypeId type3); + std::map<TypeId, int> int_map_ = { + {kNumberTypeInt, 0}, {kNumberTypeInt8, 1}, {kNumberTypeInt16, 2}, {kNumberTypeInt32, 3}, {kNumberTypeInt64, 4}}; + std::map<TypeId, int> uint_map_ = {{kNumberTypeUInt, 0}, + {kNumberTypeUInt8, 1}, + {kNumberTypeUInt16, 2}, + {kNumberTypeUInt32, 3}, + {kNumberTypeUInt64, 4}}; + std::map<TypeId, int> float_map_ = {{kNumberTypeFloat, 0}, + {kNumberTypeFloat16, 1}, + {kNumberTypeFloat32, 2}, + {kNumberTypeFloat64, 3}, + {kNumberTypeDouble, 4}}; + AnfNodePtr x_{nullptr}, t_{nullptr}, y_{nullptr}; }; class CastEliminater : public OptimizerCaller { diff --git a/tests/st/ops/gpu/test_cast_op.py b/tests/st/ops/gpu/test_cast_op.py index 576fd0735b5..28fe88d3ed3 100644 --- a/tests/st/ops/gpu/test_cast_op.py +++ b/tests/st/ops/gpu/test_cast_op.py @@ -1,4 +1,4 @@ -# Copyright 2019 Huawei Technologies Co., Ltd +# Copyright 2019-2023 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -22,6 +22,8 @@ from mindspore.common.tensor import Tensor from mindspore.nn import Cell from mindspore.ops import operations as P from mindspore.ops.operations import _inner_ops as inner +from mindspore import nn +import mindspore as ms class Net(Cell): @@ -619,3 +621,25 @@ def test_cast32(): assert (output[0].asnumpy() == expected).all() type1 = output[1].asnumpy().dtype assert type1 == 'float64' + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_two_cast(): + """ + Feature: test cast eliminate. + Description: test two cast eliminater. + Expectation: no exception. + """ + class TwoCastNet(nn.Cell): + def construct(self, x): + return ms.ops.Cast()(ms.ops.Cast()(x, ms.int32), ms.float32) + + ms.set_context(mode=context.GRAPH_MODE) + input_x = ms.Tensor([1.1], ms.float32) + res = TwoCastNet()(input_x) + expected = [1.] + assert (res.asnumpy() == expected).all() + type1 = res.asnumpy().dtype + assert type1 == 'float32' diff --git a/tests/ut/cpp/optimizer/lib_test.cc b/tests/ut/cpp/optimizer/lib_test.cc index 3a21f81a97e..6ca933bf248 100644 --- a/tests/ut/cpp/optimizer/lib_test.cc +++ b/tests/ut/cpp/optimizer/lib_test.cc @@ -238,14 +238,6 @@ TEST_F(TestOptLib, elim_two_reshape) { ASSERT_TRUE(CheckOpt(before, after, patterns)); } -TEST_F(TestOptLib, elim_two_cast) { - FuncGraphPtr before = getPyFun.CallAndParseRet("elim_two_cast", "before"); - FuncGraphPtr after = getPyFun.CallAndParseRet("elim_two_cast", "after"); - - auto patterns = std::vector<SubstitutionPtr>({irpass.cast_eliminate_}); - ASSERT_TRUE(CheckOpt(before, after, patterns)); -} - TEST_F(TestOptLib, test_elim_transpose) { FuncGraphPtr before = getPyFun.CallAndParseRet("test_elim_transpose", "before"); FuncGraphPtr after = getPyFun.CallAndParseRet("test_elim_transpose", "after"); diff --git a/tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py b/tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py index 59f788a52da..3b5feb82100 100644 --- 
a/tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py +++ b/tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py @@ -488,22 +488,6 @@ def elim_two_reshape(tag): return fns[tag] -def elim_two_cast(tag): - """ elim_two_cast """ - fns = FnDict() - cast = P.Cast() - - @fns - def before(x, a, b): - return cast(cast(x, a), b) - - @fns - def after(x, a, b): - return cast(x, b) - - return fns[tag] - - def test_elim_transpose(tag): """ test_elim_transpose """ fns = FnDict()