add ut test cases for parallel-if by check isomorphic funcgraphs

This commit is contained in:
zhousiyi 2022-06-08 08:09:01 +00:00
parent ccdcaae2a2
commit 6afda60609
2 changed files with 331 additions and 0 deletions

View File

@ -0,0 +1,109 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include <string>
#include "common/common_test.h"
#include "common/py_func_graph_fetcher.h"
#include "utils/log_adapter.h"
#include "pipeline/jit/parse/parse.h"
#include "include/common/debug/draw.h"
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/opt.h"
#include "frontend/optimizer/irpass.h"
#include "pipeline/jit/action.h"
namespace mindspore {
namespace parse {
class TestParallelIf : public UT::Common {
public:
TestParallelIf() : getPyFun("gtest_input.pipeline.parse.parallel_if") {}
virtual void SetUp();
virtual void TearDown();
py::function GetPythonFunction(std::string function);
bool CheckIsomorphic(FuncGraphPtr basic, FuncGraphPtr manual, std::vector<opt::SubstitutionPtr> opts = {}) {
opt::SubstitutionList transform(opts);
FuncGraphPairMapEquiv equiv_graph;
NodeMapEquiv equiv_node;
opt::OptimizerPtr optimizer = std::make_shared<opt::Optimizer>("ut_test", std::make_shared<pipeline::Resource>());
FuncGraphPtr basic_clone = BasicClone(basic);
transform(basic_clone, optimizer);
FuncGraphPtr manual_clone = BasicClone(manual);
transform(manual_clone, optimizer);
return Isomorphic(basic_clone, manual_clone, &equiv_graph, &equiv_node);
}
void CheckParallelIfTransform(const std::string &test_case) {
FuncGraphPtr basic_graph = getPyFun.CallAndParseRet(test_case, "basic");
ASSERT_TRUE(basic_graph != nullptr);
FuncGraphPtr manual_graph = getPyFun.CallAndParseRet(test_case, "manual");
ASSERT_TRUE(manual_graph != nullptr);
pipeline::ResourcePtr res1 = std::make_shared<pipeline::Resource>();
tensor::TensorPtr x_tensor = std::make_shared<tensor::Tensor>(kFloat32->type_id(), std::vector<int64_t>{1});
tensor::TensorPtr y_tensor = std::make_shared<tensor::Tensor>(kFloat32->type_id(), std::vector<int64_t>{1});
AbstractBasePtr abstract_x = abstract::FromValue(x_tensor, true);
AbstractBasePtr abstract_y = abstract::FromValue(y_tensor, true);
abstract::AbstractBasePtrList args_spec_list{abstract_x, abstract_y};
abstract::AnalysisResult result = pipeline::AbstractAnalyze(res1, basic_graph, args_spec_list);
auto new_basic_graph = pipeline::ProgramSpecialize(res1, basic_graph, result.context);
pipeline::ResourcePtr res2 = std::make_shared<pipeline::Resource>();
result = pipeline::AbstractAnalyze(res2, manual_graph, args_spec_list);
auto new_manual_graph = pipeline::ProgramSpecialize(res2, manual_graph, result.context);
auto patterns = std::vector<opt::SubstitutionPtr>({irpass_lib_.inline_, irpass_lib_.switch_simplify_});
ASSERT_TRUE(CheckIsomorphic(new_basic_graph, new_manual_graph, patterns));
abstract::AnalysisResultCacheMgr::GetInstance().Clear();
abstract::AnalysisContext::ClearContext();
}
public:
UT::PyFuncGraphFetcher getPyFun;
opt::irpass::OptimizeIRPassLib irpass_lib_;
};
void TestParallelIf::SetUp() { UT::InitPythonPath(); }
void TestParallelIf::TearDown() {}
// Feature: Parallel if transformation
// Description: Check parallel if transformatin for test code with single if/else.
// Expectation: The funcgraph after transformation should be isomorphic with the funcgraph manually constructed.
TEST_F(TestParallelIf, SimpleIf) { CheckParallelIfTransform("test_simple_if"); }
// Feature: Parallel if transformation
// Description: Check parallel if transformatin for test code with if-by-if.
// Expectation: The funcgraph after transformation should be isomorphic with the funcgraph manually constructed.
TEST_F(TestParallelIf, IfByIf) { CheckParallelIfTransform("test_if_by_if"); }
// Feature: Parallel if transformation
// Description: Check parallel if transformatin for test code with if-in-if.
// Expectation: The funcgraph after transformation should be isomorphic with the funcgraph manually constructed.
TEST_F(TestParallelIf, IfInIf) { CheckParallelIfTransform("test_if_in_if"); }
// Feature: Parallel if transformation
// Description: Check parallel if transformatin for test code with if-elif-else.
// Expectation: The funcgraph after transformation should be isomorphic with the funcgraph manually constructed.
TEST_F(TestParallelIf, IfElifElse) { CheckParallelIfTransform("test_if_elif_else"); }
} // namespace parse
} // namespace mindspore

View File

@ -0,0 +1,222 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
file: parallel_if.py
"""
from mindspore.ops import functional as F
from mindspore._extends.parse import standard_method
class FnDict:
def __init__(self):
self.fn_dict = {}
def __call__(self, fn):
self.fn_dict[fn.__name__] = fn
def __getitem__(self, name):
return self.fn_dict.get(name)
# pylint: disable=unused-variable
# disable pylint unused variable for basic/manual which are used by decorator.
def test_simple_if(tag):
"""
Feature: Parallel if transformation
Description: simple if with single if/else
Expectation: funcgraph parsed and manual constructed should be isomorphic.
"""
fns = FnDict()
@fns
def basic(x, y):
if x > y:
x = x + y
else:
x = x - y
return x + x
@fns
def manual(x, y):
def after(a_x):
return a_x + a_x
def true_branch():
return x + y
def false_branch():
return x - y
cond = standard_method.bool_(x > y)
switch_node = F.switch(cond, true_branch, false_branch)
result = switch_node()
return after(result)
return fns[tag]
def test_if_by_if(tag):
"""
Feature: Parallel if transformation
Description: if/else after if/else
Expectation: funcgraph parsed and manual constructed should be isomorphic.
"""
fns = FnDict()
@fns
def basic(x, y):
if x > y:
x = x + y
else:
x = x - y
if x < y:
y = x * y
else:
y = x + y
return x + y
@fns
def manual(x, y):
# first if
def true_branch1():
return x + y
def false_branch1():
return x - y
cond1 = standard_method.bool_(x > y)
switch_node = F.switch(cond1, true_branch1, false_branch1)
result1 = switch_node()
cond2 = standard_method.bool_(result1 < y)
# second if
def true_branch2():
return result1 * y
def false_branch2():
return result1 + y
def after2(a_x, a_y):
return a_x + a_y
def after1():
switch_node = F.switch(cond2, true_branch2, false_branch2)
result2 = switch_node()
return after2(result1, result2)
return after1()
return fns[tag]
def test_if_in_if(tag):
"""
Feature: Parallel if transformation
Description: if/else in if
Expectation: funcgraph parsed and manual constructed should be isomorphic.
"""
fns = FnDict()
@fns
def basic(x, y):
if x >= y:
if x > y:
x = x + y
else:
x = x - y
else:
x = x * y
return x + y
@fns
def manual(x, y):
# inner if/else
def true_branch2():
return x + y
def false_branch2():
return x - y
def after2(a_x):
return a_x
# outer if/else
def after1(a_x):
return a_x + y
def true_branch1():
cond2 = standard_method.bool_(x > y)
switch_node = F.switch(cond2, true_branch2, false_branch2)
result2 = switch_node()
return after2(result2)
def false_branch1():
return x * y
cond1 = standard_method.bool_(x >= y)
switch_node = F.switch(cond1, true_branch1, false_branch1)
result1 = switch_node()
return after1(result1)
return fns[tag]
def test_if_elif_else(tag):
"""
Feature: Parallel if transformation
Description: if/elif/else which can be treated as if/else{if/else}.
Expectation: funcgraph parsed and manual constructed should be isomorphic.
"""
fns = FnDict()
@fns
def basic(x, y):
if x > y:
out = x + y
elif x == y:
out = x - y
else:
out = x * y
return out + out
@fns
def manual(x, y):
# elif/else part
def true_branch2():
return x - y
def false_branch2():
return x * y
def after2(out):
return out
# if part
def after1(out):
return out + out
def true_branch1():
return x + y
def false_branch1():
cond2 = standard_method.bool_(x == y)
switch_node = F.switch(cond2, true_branch2, false_branch2)
result2 = switch_node()
return after2(result2)
cond1 = standard_method.bool_(x > y)
switch_node = F.switch(cond1, true_branch1, false_branch1)
result1 = switch_node()
return after1(result1)
return fns[tag]