forked from mindspore-Ecosystem/mindspore
add ut test cases for parallel-if by check isomorphic funcgraphs
This commit is contained in:
parent
ccdcaae2a2
commit
6afda60609
|
@ -0,0 +1,109 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include "common/common_test.h"
|
||||
#include "common/py_func_graph_fetcher.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "pipeline/jit/parse/parse.h"
|
||||
#include "include/common/debug/draw.h"
|
||||
|
||||
#include "frontend/optimizer/optimizer.h"
|
||||
#include "frontend/optimizer/opt.h"
|
||||
#include "frontend/optimizer/irpass.h"
|
||||
#include "pipeline/jit/action.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace parse {
|
||||
class TestParallelIf : public UT::Common {
|
||||
public:
|
||||
TestParallelIf() : getPyFun("gtest_input.pipeline.parse.parallel_if") {}
|
||||
virtual void SetUp();
|
||||
virtual void TearDown();
|
||||
py::function GetPythonFunction(std::string function);
|
||||
|
||||
bool CheckIsomorphic(FuncGraphPtr basic, FuncGraphPtr manual, std::vector<opt::SubstitutionPtr> opts = {}) {
|
||||
opt::SubstitutionList transform(opts);
|
||||
FuncGraphPairMapEquiv equiv_graph;
|
||||
NodeMapEquiv equiv_node;
|
||||
|
||||
opt::OptimizerPtr optimizer = std::make_shared<opt::Optimizer>("ut_test", std::make_shared<pipeline::Resource>());
|
||||
FuncGraphPtr basic_clone = BasicClone(basic);
|
||||
transform(basic_clone, optimizer);
|
||||
FuncGraphPtr manual_clone = BasicClone(manual);
|
||||
transform(manual_clone, optimizer);
|
||||
|
||||
return Isomorphic(basic_clone, manual_clone, &equiv_graph, &equiv_node);
|
||||
}
|
||||
|
||||
void CheckParallelIfTransform(const std::string &test_case) {
|
||||
FuncGraphPtr basic_graph = getPyFun.CallAndParseRet(test_case, "basic");
|
||||
ASSERT_TRUE(basic_graph != nullptr);
|
||||
FuncGraphPtr manual_graph = getPyFun.CallAndParseRet(test_case, "manual");
|
||||
ASSERT_TRUE(manual_graph != nullptr);
|
||||
|
||||
pipeline::ResourcePtr res1 = std::make_shared<pipeline::Resource>();
|
||||
|
||||
tensor::TensorPtr x_tensor = std::make_shared<tensor::Tensor>(kFloat32->type_id(), std::vector<int64_t>{1});
|
||||
tensor::TensorPtr y_tensor = std::make_shared<tensor::Tensor>(kFloat32->type_id(), std::vector<int64_t>{1});
|
||||
|
||||
AbstractBasePtr abstract_x = abstract::FromValue(x_tensor, true);
|
||||
AbstractBasePtr abstract_y = abstract::FromValue(y_tensor, true);
|
||||
abstract::AbstractBasePtrList args_spec_list{abstract_x, abstract_y};
|
||||
|
||||
abstract::AnalysisResult result = pipeline::AbstractAnalyze(res1, basic_graph, args_spec_list);
|
||||
auto new_basic_graph = pipeline::ProgramSpecialize(res1, basic_graph, result.context);
|
||||
|
||||
pipeline::ResourcePtr res2 = std::make_shared<pipeline::Resource>();
|
||||
result = pipeline::AbstractAnalyze(res2, manual_graph, args_spec_list);
|
||||
auto new_manual_graph = pipeline::ProgramSpecialize(res2, manual_graph, result.context);
|
||||
|
||||
auto patterns = std::vector<opt::SubstitutionPtr>({irpass_lib_.inline_, irpass_lib_.switch_simplify_});
|
||||
ASSERT_TRUE(CheckIsomorphic(new_basic_graph, new_manual_graph, patterns));
|
||||
|
||||
abstract::AnalysisResultCacheMgr::GetInstance().Clear();
|
||||
abstract::AnalysisContext::ClearContext();
|
||||
}
|
||||
public:
|
||||
UT::PyFuncGraphFetcher getPyFun;
|
||||
opt::irpass::OptimizeIRPassLib irpass_lib_;
|
||||
};
|
||||
|
||||
void TestParallelIf::SetUp() { UT::InitPythonPath(); }
|
||||
|
||||
void TestParallelIf::TearDown() {}
|
||||
|
||||
// Feature: Parallel if transformation
|
||||
// Description: Check parallel if transformatin for test code with single if/else.
|
||||
// Expectation: The funcgraph after transformation should be isomorphic with the funcgraph manually constructed.
|
||||
TEST_F(TestParallelIf, SimpleIf) { CheckParallelIfTransform("test_simple_if"); }
|
||||
|
||||
// Feature: Parallel if transformation
|
||||
// Description: Check parallel if transformatin for test code with if-by-if.
|
||||
// Expectation: The funcgraph after transformation should be isomorphic with the funcgraph manually constructed.
|
||||
TEST_F(TestParallelIf, IfByIf) { CheckParallelIfTransform("test_if_by_if"); }
|
||||
|
||||
// Feature: Parallel if transformation
|
||||
// Description: Check parallel if transformatin for test code with if-in-if.
|
||||
// Expectation: The funcgraph after transformation should be isomorphic with the funcgraph manually constructed.
|
||||
TEST_F(TestParallelIf, IfInIf) { CheckParallelIfTransform("test_if_in_if"); }
|
||||
|
||||
// Feature: Parallel if transformation
|
||||
// Description: Check parallel if transformatin for test code with if-elif-else.
|
||||
// Expectation: The funcgraph after transformation should be isomorphic with the funcgraph manually constructed.
|
||||
TEST_F(TestParallelIf, IfElifElse) { CheckParallelIfTransform("test_if_elif_else"); }
|
||||
} // namespace parse
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,222 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
file: parallel_if.py
|
||||
"""
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore._extends.parse import standard_method
|
||||
|
||||
|
||||
class FnDict:
|
||||
def __init__(self):
|
||||
self.fn_dict = {}
|
||||
|
||||
def __call__(self, fn):
|
||||
self.fn_dict[fn.__name__] = fn
|
||||
|
||||
def __getitem__(self, name):
|
||||
return self.fn_dict.get(name)
|
||||
|
||||
|
||||
# pylint: disable=unused-variable
|
||||
# disable pylint unused variable for basic/manual which are used by decorator.
|
||||
def test_simple_if(tag):
|
||||
"""
|
||||
Feature: Parallel if transformation
|
||||
Description: simple if with single if/else
|
||||
Expectation: funcgraph parsed and manual constructed should be isomorphic.
|
||||
"""
|
||||
fns = FnDict()
|
||||
@fns
|
||||
def basic(x, y):
|
||||
if x > y:
|
||||
x = x + y
|
||||
else:
|
||||
x = x - y
|
||||
return x + x
|
||||
|
||||
@fns
|
||||
def manual(x, y):
|
||||
def after(a_x):
|
||||
return a_x + a_x
|
||||
|
||||
def true_branch():
|
||||
return x + y
|
||||
|
||||
def false_branch():
|
||||
return x - y
|
||||
|
||||
cond = standard_method.bool_(x > y)
|
||||
|
||||
switch_node = F.switch(cond, true_branch, false_branch)
|
||||
result = switch_node()
|
||||
return after(result)
|
||||
|
||||
return fns[tag]
|
||||
|
||||
|
||||
def test_if_by_if(tag):
|
||||
"""
|
||||
Feature: Parallel if transformation
|
||||
Description: if/else after if/else
|
||||
Expectation: funcgraph parsed and manual constructed should be isomorphic.
|
||||
"""
|
||||
fns = FnDict()
|
||||
@fns
|
||||
def basic(x, y):
|
||||
if x > y:
|
||||
x = x + y
|
||||
else:
|
||||
x = x - y
|
||||
if x < y:
|
||||
y = x * y
|
||||
else:
|
||||
y = x + y
|
||||
return x + y
|
||||
|
||||
@fns
|
||||
def manual(x, y):
|
||||
# first if
|
||||
def true_branch1():
|
||||
return x + y
|
||||
|
||||
def false_branch1():
|
||||
return x - y
|
||||
|
||||
cond1 = standard_method.bool_(x > y)
|
||||
switch_node = F.switch(cond1, true_branch1, false_branch1)
|
||||
result1 = switch_node()
|
||||
|
||||
cond2 = standard_method.bool_(result1 < y)
|
||||
|
||||
# second if
|
||||
def true_branch2():
|
||||
return result1 * y
|
||||
|
||||
def false_branch2():
|
||||
return result1 + y
|
||||
|
||||
def after2(a_x, a_y):
|
||||
return a_x + a_y
|
||||
|
||||
def after1():
|
||||
switch_node = F.switch(cond2, true_branch2, false_branch2)
|
||||
result2 = switch_node()
|
||||
return after2(result1, result2)
|
||||
|
||||
return after1()
|
||||
|
||||
return fns[tag]
|
||||
|
||||
|
||||
def test_if_in_if(tag):
|
||||
"""
|
||||
Feature: Parallel if transformation
|
||||
Description: if/else in if
|
||||
Expectation: funcgraph parsed and manual constructed should be isomorphic.
|
||||
"""
|
||||
fns = FnDict()
|
||||
@fns
|
||||
def basic(x, y):
|
||||
if x >= y:
|
||||
if x > y:
|
||||
x = x + y
|
||||
else:
|
||||
x = x - y
|
||||
else:
|
||||
x = x * y
|
||||
return x + y
|
||||
|
||||
@fns
|
||||
def manual(x, y):
|
||||
# inner if/else
|
||||
def true_branch2():
|
||||
return x + y
|
||||
|
||||
def false_branch2():
|
||||
return x - y
|
||||
|
||||
def after2(a_x):
|
||||
return a_x
|
||||
|
||||
# outer if/else
|
||||
def after1(a_x):
|
||||
return a_x + y
|
||||
|
||||
def true_branch1():
|
||||
cond2 = standard_method.bool_(x > y)
|
||||
switch_node = F.switch(cond2, true_branch2, false_branch2)
|
||||
result2 = switch_node()
|
||||
return after2(result2)
|
||||
|
||||
def false_branch1():
|
||||
return x * y
|
||||
|
||||
cond1 = standard_method.bool_(x >= y)
|
||||
switch_node = F.switch(cond1, true_branch1, false_branch1)
|
||||
result1 = switch_node()
|
||||
return after1(result1)
|
||||
|
||||
return fns[tag]
|
||||
|
||||
|
||||
def test_if_elif_else(tag):
|
||||
"""
|
||||
Feature: Parallel if transformation
|
||||
Description: if/elif/else which can be treated as if/else{if/else}.
|
||||
Expectation: funcgraph parsed and manual constructed should be isomorphic.
|
||||
"""
|
||||
fns = FnDict()
|
||||
@fns
|
||||
def basic(x, y):
|
||||
if x > y:
|
||||
out = x + y
|
||||
elif x == y:
|
||||
out = x - y
|
||||
else:
|
||||
out = x * y
|
||||
return out + out
|
||||
|
||||
@fns
|
||||
def manual(x, y):
|
||||
# elif/else part
|
||||
def true_branch2():
|
||||
return x - y
|
||||
|
||||
def false_branch2():
|
||||
return x * y
|
||||
|
||||
def after2(out):
|
||||
return out
|
||||
|
||||
# if part
|
||||
def after1(out):
|
||||
return out + out
|
||||
|
||||
def true_branch1():
|
||||
return x + y
|
||||
|
||||
def false_branch1():
|
||||
cond2 = standard_method.bool_(x == y)
|
||||
switch_node = F.switch(cond2, true_branch2, false_branch2)
|
||||
result2 = switch_node()
|
||||
return after2(result2)
|
||||
|
||||
cond1 = standard_method.bool_(x > y)
|
||||
switch_node = F.switch(cond1, true_branch1, false_branch1)
|
||||
result1 = switch_node()
|
||||
return after1(result1)
|
||||
|
||||
return fns[tag]
|
Loading…
Reference in New Issue