From 6afda606094cf228997b64bc9d32b1305a81c024 Mon Sep 17 00:00:00 2001 From: zhousiyi Date: Wed, 8 Jun 2022 08:09:01 +0000 Subject: [PATCH] add ut test cases for parallel-if by check isomorphic funcgraphs --- tests/ut/cpp/pipeline/parse/parallel_if.cc | 109 +++++++++ .../gtest_input/pipeline/parse/parallel_if.py | 222 ++++++++++++++++++ 2 files changed, 331 insertions(+) create mode 100644 tests/ut/cpp/pipeline/parse/parallel_if.cc create mode 100644 tests/ut/cpp/python_input/gtest_input/pipeline/parse/parallel_if.py diff --git a/tests/ut/cpp/pipeline/parse/parallel_if.cc b/tests/ut/cpp/pipeline/parse/parallel_if.cc new file mode 100644 index 00000000000..d682195407d --- /dev/null +++ b/tests/ut/cpp/pipeline/parse/parallel_if.cc @@ -0,0 +1,109 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include "common/common_test.h" +#include "common/py_func_graph_fetcher.h" +#include "utils/log_adapter.h" +#include "pipeline/jit/parse/parse.h" +#include "include/common/debug/draw.h" + +#include "frontend/optimizer/optimizer.h" +#include "frontend/optimizer/opt.h" +#include "frontend/optimizer/irpass.h" +#include "pipeline/jit/action.h" + +namespace mindspore { +namespace parse { +class TestParallelIf : public UT::Common { + public: + TestParallelIf() : getPyFun("gtest_input.pipeline.parse.parallel_if") {} + virtual void SetUp(); + virtual void TearDown(); + py::function GetPythonFunction(std::string function); + + bool CheckIsomorphic(FuncGraphPtr basic, FuncGraphPtr manual, std::vector opts = {}) { + opt::SubstitutionList transform(opts); + FuncGraphPairMapEquiv equiv_graph; + NodeMapEquiv equiv_node; + + opt::OptimizerPtr optimizer = std::make_shared("ut_test", std::make_shared()); + FuncGraphPtr basic_clone = BasicClone(basic); + transform(basic_clone, optimizer); + FuncGraphPtr manual_clone = BasicClone(manual); + transform(manual_clone, optimizer); + + return Isomorphic(basic_clone, manual_clone, &equiv_graph, &equiv_node); + } + + void CheckParallelIfTransform(const std::string &test_case) { + FuncGraphPtr basic_graph = getPyFun.CallAndParseRet(test_case, "basic"); + ASSERT_TRUE(basic_graph != nullptr); + FuncGraphPtr manual_graph = getPyFun.CallAndParseRet(test_case, "manual"); + ASSERT_TRUE(manual_graph != nullptr); + + pipeline::ResourcePtr res1 = std::make_shared(); + + tensor::TensorPtr x_tensor = std::make_shared(kFloat32->type_id(), std::vector{1}); + tensor::TensorPtr y_tensor = std::make_shared(kFloat32->type_id(), std::vector{1}); + + AbstractBasePtr abstract_x = abstract::FromValue(x_tensor, true); + AbstractBasePtr abstract_y = abstract::FromValue(y_tensor, true); + abstract::AbstractBasePtrList args_spec_list{abstract_x, abstract_y}; + + abstract::AnalysisResult result = pipeline::AbstractAnalyze(res1, basic_graph, args_spec_list); + auto new_basic_graph = pipeline::ProgramSpecialize(res1, basic_graph, result.context); + + pipeline::ResourcePtr res2 = std::make_shared(); + result = pipeline::AbstractAnalyze(res2, manual_graph, args_spec_list); + auto new_manual_graph = pipeline::ProgramSpecialize(res2, manual_graph, result.context); + + auto patterns = std::vector({irpass_lib_.inline_, irpass_lib_.switch_simplify_}); + ASSERT_TRUE(CheckIsomorphic(new_basic_graph, new_manual_graph, patterns)); + + abstract::AnalysisResultCacheMgr::GetInstance().Clear(); + abstract::AnalysisContext::ClearContext(); + } + public: + UT::PyFuncGraphFetcher getPyFun; + opt::irpass::OptimizeIRPassLib irpass_lib_; +}; + +void TestParallelIf::SetUp() { UT::InitPythonPath(); } + +void TestParallelIf::TearDown() {} + +// Feature: Parallel if transformation +// Description: Check parallel if transformatin for test code with single if/else. +// Expectation: The funcgraph after transformation should be isomorphic with the funcgraph manually constructed. +TEST_F(TestParallelIf, SimpleIf) { CheckParallelIfTransform("test_simple_if"); } + +// Feature: Parallel if transformation +// Description: Check parallel if transformatin for test code with if-by-if. +// Expectation: The funcgraph after transformation should be isomorphic with the funcgraph manually constructed. +TEST_F(TestParallelIf, IfByIf) { CheckParallelIfTransform("test_if_by_if"); } + +// Feature: Parallel if transformation +// Description: Check parallel if transformatin for test code with if-in-if. +// Expectation: The funcgraph after transformation should be isomorphic with the funcgraph manually constructed. +TEST_F(TestParallelIf, IfInIf) { CheckParallelIfTransform("test_if_in_if"); } + +// Feature: Parallel if transformation +// Description: Check parallel if transformatin for test code with if-elif-else. +// Expectation: The funcgraph after transformation should be isomorphic with the funcgraph manually constructed. +TEST_F(TestParallelIf, IfElifElse) { CheckParallelIfTransform("test_if_elif_else"); } +} // namespace parse +} // namespace mindspore diff --git a/tests/ut/cpp/python_input/gtest_input/pipeline/parse/parallel_if.py b/tests/ut/cpp/python_input/gtest_input/pipeline/parse/parallel_if.py new file mode 100644 index 00000000000..1cbba8e0136 --- /dev/null +++ b/tests/ut/cpp/python_input/gtest_input/pipeline/parse/parallel_if.py @@ -0,0 +1,222 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +file: parallel_if.py +""" +from mindspore.ops import functional as F +from mindspore._extends.parse import standard_method + + +class FnDict: + def __init__(self): + self.fn_dict = {} + + def __call__(self, fn): + self.fn_dict[fn.__name__] = fn + + def __getitem__(self, name): + return self.fn_dict.get(name) + + +# pylint: disable=unused-variable +# disable pylint unused variable for basic/manual which are used by decorator. +def test_simple_if(tag): + """ + Feature: Parallel if transformation + Description: simple if with single if/else + Expectation: funcgraph parsed and manual constructed should be isomorphic. + """ + fns = FnDict() + @fns + def basic(x, y): + if x > y: + x = x + y + else: + x = x - y + return x + x + + @fns + def manual(x, y): + def after(a_x): + return a_x + a_x + + def true_branch(): + return x + y + + def false_branch(): + return x - y + + cond = standard_method.bool_(x > y) + + switch_node = F.switch(cond, true_branch, false_branch) + result = switch_node() + return after(result) + + return fns[tag] + + +def test_if_by_if(tag): + """ + Feature: Parallel if transformation + Description: if/else after if/else + Expectation: funcgraph parsed and manual constructed should be isomorphic. + """ + fns = FnDict() + @fns + def basic(x, y): + if x > y: + x = x + y + else: + x = x - y + if x < y: + y = x * y + else: + y = x + y + return x + y + + @fns + def manual(x, y): + # first if + def true_branch1(): + return x + y + + def false_branch1(): + return x - y + + cond1 = standard_method.bool_(x > y) + switch_node = F.switch(cond1, true_branch1, false_branch1) + result1 = switch_node() + + cond2 = standard_method.bool_(result1 < y) + + # second if + def true_branch2(): + return result1 * y + + def false_branch2(): + return result1 + y + + def after2(a_x, a_y): + return a_x + a_y + + def after1(): + switch_node = F.switch(cond2, true_branch2, false_branch2) + result2 = switch_node() + return after2(result1, result2) + + return after1() + + return fns[tag] + + +def test_if_in_if(tag): + """ + Feature: Parallel if transformation + Description: if/else in if + Expectation: funcgraph parsed and manual constructed should be isomorphic. + """ + fns = FnDict() + @fns + def basic(x, y): + if x >= y: + if x > y: + x = x + y + else: + x = x - y + else: + x = x * y + return x + y + + @fns + def manual(x, y): + # inner if/else + def true_branch2(): + return x + y + + def false_branch2(): + return x - y + + def after2(a_x): + return a_x + + # outer if/else + def after1(a_x): + return a_x + y + + def true_branch1(): + cond2 = standard_method.bool_(x > y) + switch_node = F.switch(cond2, true_branch2, false_branch2) + result2 = switch_node() + return after2(result2) + + def false_branch1(): + return x * y + + cond1 = standard_method.bool_(x >= y) + switch_node = F.switch(cond1, true_branch1, false_branch1) + result1 = switch_node() + return after1(result1) + + return fns[tag] + + +def test_if_elif_else(tag): + """ + Feature: Parallel if transformation + Description: if/elif/else which can be treated as if/else{if/else}. + Expectation: funcgraph parsed and manual constructed should be isomorphic. + """ + fns = FnDict() + @fns + def basic(x, y): + if x > y: + out = x + y + elif x == y: + out = x - y + else: + out = x * y + return out + out + + @fns + def manual(x, y): + # elif/else part + def true_branch2(): + return x - y + + def false_branch2(): + return x * y + + def after2(out): + return out + + # if part + def after1(out): + return out + out + + def true_branch1(): + return x + y + + def false_branch1(): + cond2 = standard_method.bool_(x == y) + switch_node = F.switch(cond2, true_branch2, false_branch2) + result2 = switch_node() + return after2(result2) + + cond1 = standard_method.bool_(x > y) + switch_node = F.switch(cond1, true_branch1, false_branch1) + result1 = switch_node() + return after1(result1) + + return fns[tag]