forked from mindspore-Ecosystem/mindspore
add ut for tuple trans
This commit is contained in:
parent
433d7335d8
commit
3872e29c3b
|
@ -683,6 +683,7 @@ void DumpSubgraph(const OrderedMap<FuncGraphPtr, std::shared_ptr<SubGraphIRInfo>
|
|||
}
|
||||
|
||||
void SetDumpConfigByString(const std::string &str, DumpConfig *dump_config) {
|
||||
MS_LOG(INFO) << "Set dump config:" << str;
|
||||
static mindspore::HashMap<std::string, enum LocDumpMode> dump_level_map = {
|
||||
{kDumpConfigLineLevel0, kOff}, {kDumpConfigLineLevel1, kTopStack}, {kDumpConfigLineLevel2, kWholeStack}};
|
||||
auto it = dump_level_map.find(str);
|
||||
|
@ -700,11 +701,64 @@ void SetDumpConfigByString(const std::string &str, DumpConfig *dump_config) {
|
|||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<OrderedSet<std::string>> GetAllConfigStrings(const std::string &config_full_string) {
|
||||
size_t start_pos = 0;
|
||||
auto config_strings = std::make_shared<OrderedSet<std::string>>();
|
||||
// if '#' is the last char of str, the str is legal, so we use '<=' but not '<'.
|
||||
while (start_pos <= config_full_string.size()) {
|
||||
auto pos = config_full_string.find('#', start_pos);
|
||||
if (pos == std::string::npos) {
|
||||
pos = config_full_string.size();
|
||||
}
|
||||
auto substr = config_full_string.substr(start_pos, pos - start_pos);
|
||||
// Skip the '#'
|
||||
start_pos = pos + 1;
|
||||
if (substr.empty()) {
|
||||
continue;
|
||||
}
|
||||
(void)config_strings->insert(substr);
|
||||
}
|
||||
return config_strings;
|
||||
}
|
||||
|
||||
bool ConfigsAreLegal(const std::shared_ptr<OrderedSet<std::string>> &config_strings) {
|
||||
// Value 'int' is used to mark config group id
|
||||
HashMap<std::string, int> config_white_list = {{kDumpConfigLineLevel0, 0},
|
||||
{kDumpConfigLineLevel1, 0},
|
||||
{kDumpConfigLineLevel2, 0},
|
||||
{kDumpConfigDisableBackend, 1},
|
||||
{kDumpConfigEnablePassIR, 2}};
|
||||
// Key 'int' is config group id, value is the config.
|
||||
HashMap<int, std::string> config_groups;
|
||||
for (const auto &config_string : *config_strings) {
|
||||
auto config_white_list_it = config_white_list.find(config_string);
|
||||
if (config_white_list_it == config_white_list.end()) {
|
||||
std::ostringstream buffer;
|
||||
buffer << "Support configs:\n"
|
||||
<< "[0]: " << kDumpConfigLineLevel0 << "\n"
|
||||
<< "[1]: " << kDumpConfigLineLevel1 << "\n"
|
||||
<< "[2]: " << kDumpConfigLineLevel2 << "\n"
|
||||
<< "[3]: " << kDumpConfigDisableBackend << "\n"
|
||||
<< "[4]: " << kDumpConfigEnablePassIR;
|
||||
MS_LOG(WARNING) << "Illegal dump config:\n" << config_string << "\n" << buffer.str();
|
||||
return false;
|
||||
}
|
||||
auto group_id = config_white_list_it->second;
|
||||
// Check conflict configs.
|
||||
auto config_groups_it = config_groups.find(group_id);
|
||||
if (config_groups_it != config_groups.end()) {
|
||||
const auto &record_config = config_groups_it->second;
|
||||
MS_LOG(WARNING) << "Dump configs are conflict. Conflict configs: [" << record_config << "] and [" << config_string
|
||||
<< "].\n"
|
||||
<< "Please keep only one of them.";
|
||||
return false;
|
||||
}
|
||||
config_groups[group_id] = config_string;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
DumpConfig GetDumpConfig() {
|
||||
static std::vector<HashSet<std::string>> config_white_list = {
|
||||
{kDumpConfigLineLevel0, kDumpConfigLineLevel1, kDumpConfigLineLevel2},
|
||||
{kDumpConfigDisableBackend},
|
||||
{kDumpConfigEnablePassIR}};
|
||||
static DumpConfig dump_config = DumpConfig();
|
||||
static bool parsed = false;
|
||||
if (parsed) {
|
||||
|
@ -713,9 +767,6 @@ DumpConfig GetDumpConfig() {
|
|||
parsed = true;
|
||||
// Start parse config.
|
||||
std::string str(common::GetEnv("MS_DEV_DUMP_IR_CONFIG"));
|
||||
std::vector<std::shared_ptr<HashSet<std::string>>> configs = {std::make_shared<HashSet<std::string>>(),
|
||||
std::make_shared<HashSet<std::string>>(),
|
||||
std::make_shared<HashSet<std::string>>()};
|
||||
auto constexpr max_string_len = 100;
|
||||
if (str.size() > max_string_len) {
|
||||
MS_LOG(WARNING) << "Dump ir config length exceed max length: " << max_string_len;
|
||||
|
@ -724,45 +775,12 @@ DumpConfig GetDumpConfig() {
|
|||
if (str.empty()) {
|
||||
return dump_config;
|
||||
}
|
||||
size_t start_pos = 0;
|
||||
// if '#' is the last char of str, the str is illegal, so we use '<=' but not '<'.
|
||||
while (start_pos <= str.size()) {
|
||||
auto pos = str.find('#', start_pos);
|
||||
if (pos == std::string::npos) {
|
||||
pos = str.size();
|
||||
}
|
||||
auto substr = str.substr(start_pos, pos - start_pos);
|
||||
start_pos = pos + 1;
|
||||
bool is_illegal_config = true;
|
||||
for (size_t i = 0; i < config_white_list.size(); i++) {
|
||||
if (config_white_list[i].find(substr) != config_white_list[i].end()) {
|
||||
is_illegal_config = false;
|
||||
(void)configs[i]->insert(substr);
|
||||
if (configs[i]->size() > 1) {
|
||||
std::ostringstream buffer;
|
||||
(void)std::for_each(configs[i]->begin(), configs[i]->end(), [&buffer](const std::string &config) {
|
||||
buffer << "\n" << config;
|
||||
});
|
||||
MS_LOG(WARNING) << "Dump configs are conflict. Conflict configs: " << buffer.str() << "\n"
|
||||
<< "Please keep only one of them.";
|
||||
return dump_config;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (is_illegal_config) {
|
||||
std::ostringstream buffer;
|
||||
buffer << "Support configs:\n"
|
||||
<< "[0]: " << kDumpConfigLineLevel0 << "\n"
|
||||
<< "[1]: " << kDumpConfigLineLevel1 << "\n"
|
||||
<< "[2]: " << kDumpConfigLineLevel2 << "\n"
|
||||
<< "[3]: " << kDumpConfigDisableBackend << "\n"
|
||||
<< "[4]: " << kDumpConfigEnablePassIR;
|
||||
MS_LOG(WARNING) << "Illegal dump config:\n" << substr << "\n" << buffer.str();
|
||||
return {};
|
||||
}
|
||||
auto config_strings = GetAllConfigStrings(str);
|
||||
if (!ConfigsAreLegal(config_strings)) {
|
||||
return dump_config;
|
||||
}
|
||||
for (auto &config : configs) {
|
||||
SetDumpConfigByString(*config->begin(), &dump_config);
|
||||
for (const auto &config : *config_strings) {
|
||||
SetDumpConfigByString(config, &dump_config);
|
||||
}
|
||||
return dump_config;
|
||||
}
|
||||
|
|
|
@ -87,7 +87,7 @@ class GraphTupleTransform : public AnfVisitor {
|
|||
GraphTupleParamTransform graph_transform_;
|
||||
};
|
||||
|
||||
// {,kPrimPartial, G, Tuple_Xs}
|
||||
// {PrimPartial, G, Tuple_Xs}
|
||||
// =>
|
||||
// {kPrimPartial, G, TupleGetItem{Tuple_Xs,0}, TupleGetItem{Tuple_Xs,1}, ..., TupleGetItem{Tuple_Xs,n}}
|
||||
// transform partial's tuple binding args to flat inputs.
|
||||
|
@ -102,12 +102,12 @@ class PartialTupleArgTransform : public AnfVisitor {
|
|||
auto partial = node->cast<CNodePtr>();
|
||||
const auto &partial_inputs = partial->inputs();
|
||||
const auto &fg = partial->func_graph();
|
||||
// And primitive and function value node into args.
|
||||
constexpr auto kPartialFirstArgIndex = 2;
|
||||
auto new_args = AnfNodePtrList(partial_inputs.begin(), partial_inputs.begin() + kPartialFirstArgIndex);
|
||||
auto change = FlattenArgs(fg, partial_inputs, kPartialFirstArgIndex, &new_args);
|
||||
// Put ValueNode<kPrimPartial> and ValueNode<FuncGraph> into new_inputs.
|
||||
auto new_inputs = AnfNodePtrList(partial_inputs.begin(), partial_inputs.begin() + kPartialFirstArgIndex);
|
||||
auto change = FlattenArgs(fg, partial_inputs, kPartialFirstArgIndex, &new_inputs);
|
||||
if (change) {
|
||||
auto new_partial = fg->NewCNode(new_args);
|
||||
auto new_partial = fg->NewCNode(new_inputs);
|
||||
new_partial->set_abstract(partial->abstract());
|
||||
return new_partial;
|
||||
}
|
||||
|
@ -132,11 +132,11 @@ class CallTupleArgTransform : public AnfVisitor {
|
|||
const auto &call_inputs = call_node->inputs();
|
||||
const auto &fg = call_node->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
// Add function value node into args.
|
||||
auto new_args = AnfNodePtrList(call_inputs.begin(), call_inputs.begin() + 1);
|
||||
auto change = FlattenArgs(fg, call_inputs, 1, &new_args);
|
||||
// Put ValueNode<FuncGraph> into inputs.
|
||||
auto new_inputs = AnfNodePtrList(call_inputs.begin(), call_inputs.begin() + 1);
|
||||
auto change = FlattenArgs(fg, call_inputs, 1, &new_inputs);
|
||||
if (change) {
|
||||
auto new_call = fg->NewCNode(new_args);
|
||||
auto new_call = fg->NewCNode(new_inputs);
|
||||
new_call->set_abstract(call_node->abstract());
|
||||
return new_call;
|
||||
}
|
||||
|
|
|
@ -0,0 +1,99 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import mindspore as ms
|
||||
from mindspore import context
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.api import ms_function
|
||||
from mindspore.common.tensor import Tensor
|
||||
import mindspore.nn as nn
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
||||
class MAPPOCriticNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.linear1_actor = nn.Dense(54, # input local obs shape
|
||||
64,
|
||||
weight_init='XavierUniform',
|
||||
# paper uses orthogonal with gain 5/3 for every dense123
|
||||
has_bias=False,
|
||||
activation=nn.Tanh())
|
||||
|
||||
def construct(self, x):
|
||||
# Feature Extraction
|
||||
x = self.linear1_actor(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class MAPPOActor(nn.Cell):
|
||||
|
||||
def __init__(self, actor_net):
|
||||
super().__init__()
|
||||
self.actor_net = actor_net
|
||||
|
||||
def construct(self, inputs_data):
|
||||
_, global_obs = inputs_data
|
||||
out = self.actor_net(global_obs)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class TestClass(nn.Cell):
|
||||
def __init__(self, actor_list):
|
||||
super().__init__()
|
||||
self.zero = Tensor(0, ms.int32)
|
||||
self.actor_list = actor_list
|
||||
self.less = P.Less()
|
||||
self.zeros = P.Zeros()
|
||||
|
||||
def train(self):
|
||||
state = Tensor(np.random.random((3, 128, 18)), ms.float32)
|
||||
init_global_obs = self.zeros((128, 54), ms.float32)
|
||||
out = self.test(state, init_global_obs)
|
||||
return out
|
||||
|
||||
@ms_function
|
||||
def test(self, state, init_global_obs):
|
||||
num_agent = self.zero
|
||||
while self.less(num_agent, 3):
|
||||
samples = (state[num_agent], init_global_obs)
|
||||
self.actor_list[num_agent](samples)
|
||||
num_agent += 1
|
||||
|
||||
return num_agent
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_net():
|
||||
"""
|
||||
Feature: Tuple arg transform.
|
||||
Description: Test the pass: transform tuple arg to tensor arg.
|
||||
Expectation: Compile done without error.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=False, save_graphs_path="./graph_ir")
|
||||
actor_list = nn.CellList()
|
||||
for _ in range(3):
|
||||
net = MAPPOCriticNet()
|
||||
actor = MAPPOActor(net)
|
||||
actor_list.append(actor)
|
||||
test = TestClass(actor_list)
|
||||
graph_out = test.train()
|
||||
|
||||
assert np.allclose(graph_out.asnumpy(), graph_out.asnumpy(), 0.0001, 0.0001)
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -22,10 +22,12 @@
|
|||
#include "ir/anf.h"
|
||||
#include "ir/visitor.h"
|
||||
#include "ir/func_graph_cloner.h"
|
||||
#include "frontend/optimizer/optimizer.h"
|
||||
#include "frontend/optimizer/opt.h"
|
||||
#include "frontend/optimizer/anf_visitor.h"
|
||||
#include "frontend/optimizer/irpass.h"
|
||||
#include "frontend/optimizer/irpass/arithmetic_simplify.h"
|
||||
#include "pipeline/jit/action.h"
|
||||
|
||||
#include "debug/draw.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
|
@ -107,6 +109,8 @@ class TestOptOpt : public UT::Common {
|
|||
FuncGraphPairMapEquiv equiv_graph;
|
||||
NodeMapEquiv equiv_node;
|
||||
|
||||
irpass::OptimizeIRPassLib irpass_lib;
|
||||
|
||||
static const PrimitivePtr P;
|
||||
static const PrimitivePtr Q;
|
||||
static const PrimitivePtr R;
|
||||
|
@ -115,6 +119,7 @@ class TestOptOpt : public UT::Common {
|
|||
SubstitutionPtr elim_R;
|
||||
SubstitutionPtr idempotent_P;
|
||||
SubstitutionPtr Qct_to_P;
|
||||
SubstitutionPtr tuple_flatten = irpass_lib.call_graph_tuple_transform_;
|
||||
};
|
||||
|
||||
const PrimitivePtr TestOptOpt::P = std::make_shared<Primitive>("P");
|
||||
|
@ -148,8 +153,8 @@ TEST_F(TestOptOpt, ElimTwo) {
|
|||
}
|
||||
|
||||
TEST_F(TestOptOpt, ElimR) {
|
||||
FuncGraphPtr before = getPyFun.CallAndParseRet("test_elimR", "before_1");
|
||||
FuncGraphPtr after = getPyFun.CallAndParseRet("test_elimR", "after");
|
||||
FuncGraphPtr before = getPyFun.CallAndParseRet("test_elim_r", "before_1");
|
||||
FuncGraphPtr after = getPyFun.CallAndParseRet("test_elim_r", "after");
|
||||
|
||||
ASSERT_TRUE(nullptr != before);
|
||||
ASSERT_TRUE(nullptr != after);
|
||||
|
@ -208,5 +213,125 @@ TEST_F(TestOptOpt, CSE) {
|
|||
ASSERT_EQ(manager2->all_nodes().size(), 12);
|
||||
}
|
||||
|
||||
size_t TupleArgAndParamSum(const FuncGraphPtr &func_graph) {
|
||||
// Check tuple params and tuple args.
|
||||
auto all_nodes = TopoSort(func_graph->return_node(), SuccDeeperSimple, AlwaysInclude);
|
||||
size_t tuple_arg_param_num = 0;
|
||||
auto tuple_accumulate_func = [](size_t prev_num, const AnfNodePtr &node) -> size_t {
|
||||
auto abs = node->abstract();
|
||||
MS_EXCEPTION_IF_NULL(abs);
|
||||
return abs->isa<abstract::AbstractTuple>() ? prev_num + 1 : prev_num;
|
||||
};
|
||||
for (const auto &node : all_nodes) {
|
||||
// Count func graph call tuple args.
|
||||
if (node->isa<CNode>() && !IsValueNode<Primitive>(node->cast<CNodePtr>()->input(0))) {
|
||||
auto call_node = node->cast<CNodePtr>();
|
||||
tuple_arg_param_num = std::accumulate(call_node->inputs().begin() + 1, call_node->inputs().end(),
|
||||
tuple_arg_param_num, tuple_accumulate_func);
|
||||
}
|
||||
// Count partial tuple args.
|
||||
if (IsPrimitiveCNode(node, prim::kPrimPartial)) {
|
||||
auto partial = node->cast<CNodePtr>();
|
||||
constexpr auto kPartialFirstArgIdx = 2;
|
||||
tuple_arg_param_num = std::accumulate(partial->inputs().begin() + kPartialFirstArgIdx, partial->inputs().end(),
|
||||
tuple_arg_param_num, tuple_accumulate_func);
|
||||
}
|
||||
|
||||
// Count tuple params.
|
||||
if (IsValueNode<FuncGraph>(node)) {
|
||||
auto fg = GetValueNode<FuncGraphPtr>(node);
|
||||
tuple_arg_param_num =
|
||||
std::accumulate(fg->parameters().begin(), fg->parameters().end(), tuple_arg_param_num, tuple_accumulate_func);
|
||||
}
|
||||
}
|
||||
return tuple_arg_param_num;
|
||||
}
|
||||
|
||||
// Feature: Switch call tuple arg transform.
|
||||
// Description: Test switch call's tuple arg transform.This case include partial's tuple arg and the call's tuple arg in
|
||||
// the same time.
|
||||
// Expectation: All tuple args are correctly transformed to tensor args.
|
||||
TEST_F(TestOptOpt, SwitchPartialTupleTrans) {
|
||||
FuncGraphPtr test_graph = getPyFun.CallAndParseRet("test_tuple_flatten", "test_flatten_switch_partial_arg");
|
||||
ASSERT_TRUE(nullptr != test_graph);
|
||||
|
||||
FuncGraphManagerPtr manager1 = Manage(test_graph);
|
||||
pipeline::ResourcePtr res = std::make_shared<pipeline::Resource>();
|
||||
std::vector<AbstractBasePtr> args_spec;
|
||||
|
||||
// Renormalize firstly.
|
||||
auto renormalized_fg = pipeline::Renormalize(res, test_graph, args_spec);
|
||||
ASSERT_TRUE(TupleArgAndParamSum(renormalized_fg) != 0);
|
||||
|
||||
// Flatten tuple param and args.
|
||||
OptimizerPtr optimizer = std::make_shared<Optimizer>("ut_test", res);
|
||||
SubstitutionList transform(std::vector<SubstitutionPtr>({tuple_flatten}));
|
||||
transform(renormalized_fg, optimizer);
|
||||
|
||||
// Renormalize again.
|
||||
auto transformed_fg = pipeline::Renormalize(res, renormalized_fg, args_spec);
|
||||
ASSERT_TRUE(TupleArgAndParamSum(transformed_fg) == 0);
|
||||
|
||||
abstract::AnalysisResultCacheMgr::GetInstance().Clear();
|
||||
abstract::AnalysisContext::ClearContext();
|
||||
}
|
||||
|
||||
// Feature: Switch layer call tuple arg transform.
|
||||
// Description: Test switch layer call's tuple arg transform.This case include partial's tuple arg and the partial's
|
||||
// tensor arg in the same time.
|
||||
// Expectation: All tuple args are correctly transformed to tensor args.
|
||||
TEST_F(TestOptOpt, SwitchLayerPartialTupleTrans) {
|
||||
FuncGraphPtr test_graph = getPyFun.CallAndParseRet("test_tuple_flatten", "test_flatten_switch_layer_partial_arg");
|
||||
ASSERT_TRUE(nullptr != test_graph);
|
||||
|
||||
FuncGraphManagerPtr manager1 = Manage(test_graph);
|
||||
pipeline::ResourcePtr res = std::make_shared<pipeline::Resource>();
|
||||
std::vector<AbstractBasePtr> args_spec;
|
||||
|
||||
// Renormalize firstly.
|
||||
auto renormalized_fg = pipeline::Renormalize(res, test_graph, args_spec);
|
||||
ASSERT_TRUE(TupleArgAndParamSum(renormalized_fg) != 0);
|
||||
|
||||
// Flatten tuple param and args.
|
||||
OptimizerPtr optimizer = std::make_shared<Optimizer>("ut_test", res);
|
||||
SubstitutionList transform(std::vector<SubstitutionPtr>({tuple_flatten}));
|
||||
transform(renormalized_fg, optimizer);
|
||||
|
||||
// Renormalize again.
|
||||
auto transformed_fg = pipeline::Renormalize(res, renormalized_fg, args_spec);
|
||||
ASSERT_TRUE(TupleArgAndParamSum(transformed_fg) == 0);
|
||||
|
||||
abstract::AnalysisResultCacheMgr::GetInstance().Clear();
|
||||
abstract::AnalysisContext::ClearContext();
|
||||
}
|
||||
|
||||
// Feature: Single graph call tuple arg transform.
|
||||
// Description: Test single graph call's tuple arg transform.This case include tuple in tuple args.
|
||||
// Expectation: All tuple args are correctly transformed to tensor args.
|
||||
TEST_F(TestOptOpt, SimpleCallTupleTupleTrans) {
|
||||
FuncGraphPtr test_graph =
|
||||
getPyFun.CallAndParseRet("test_tuple_flatten", "test_flatten_simple_call_tuple_in_tuple_arg");
|
||||
ASSERT_TRUE(nullptr != test_graph);
|
||||
|
||||
FuncGraphManagerPtr manager1 = Manage(test_graph);
|
||||
pipeline::ResourcePtr res = std::make_shared<pipeline::Resource>();
|
||||
std::vector<AbstractBasePtr> args_spec;
|
||||
|
||||
// Renormalize firstly.
|
||||
auto renormalized_fg = pipeline::Renormalize(res, test_graph, args_spec);
|
||||
ASSERT_TRUE(TupleArgAndParamSum(renormalized_fg) != 0);
|
||||
|
||||
// Flatten tuple param and args.
|
||||
OptimizerPtr optimizer = std::make_shared<Optimizer>("ut_test", res);
|
||||
SubstitutionList transform(std::vector<SubstitutionPtr>({tuple_flatten}));
|
||||
transform(renormalized_fg, optimizer);
|
||||
|
||||
// Renormalize again.
|
||||
auto transformed_fg = pipeline::Renormalize(res, renormalized_fg, args_spec);
|
||||
ASSERT_TRUE(TupleArgAndParamSum(transformed_fg) == 0);
|
||||
|
||||
abstract::AnalysisResultCacheMgr::GetInstance().Clear();
|
||||
abstract::AnalysisContext::ClearContext();
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,9 +16,11 @@
|
|||
import numpy as np
|
||||
|
||||
from mindspore import Tensor
|
||||
from mindspore import dtype as mstype
|
||||
from mindspore.ops import Primitive
|
||||
from mindspore.ops import _constants as Constants
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops.operations import _grad_ops as G
|
||||
|
||||
# pylint: disable=unused-variable
|
||||
|
@ -68,8 +70,12 @@ def test_add_zero(tag):
|
|||
return fns[tag]
|
||||
|
||||
|
||||
def test_elimR(tag):
|
||||
""" test_elimR """
|
||||
def test_elim_r(tag):
|
||||
"""
|
||||
Feature: optimizer.
|
||||
Description: test elimi R.
|
||||
Expectation: run case with no exception.
|
||||
"""
|
||||
R = Primitive('R')
|
||||
|
||||
fns = FnDict()
|
||||
|
@ -495,6 +501,7 @@ def test_elim_transpose(tag):
|
|||
|
||||
return fns[tag]
|
||||
|
||||
|
||||
def test_elim_depend_value(tag):
|
||||
""" test_elim_depend_value """
|
||||
fns = FnDict()
|
||||
|
@ -1203,3 +1210,85 @@ def test_sparse_tensor(tag):
|
|||
return z
|
||||
|
||||
return fns[tag]
|
||||
|
||||
|
||||
# Test ut for file: call_graph_tuple_transform.h.
|
||||
def test_tuple_flatten(tag):
|
||||
"""
|
||||
Feature: optimizer.
|
||||
Description: test cases for pass: graph_tuple_transform.
|
||||
Expectation: the tuple args and parameters are successfully flattened by the pass.
|
||||
"""
|
||||
fns = FnDict()
|
||||
w = Tensor(np.random.randn(64, 3, 7, 7).astype(np.float32))
|
||||
x = Tensor(np.random.randn(32, 3, 224, 224).astype(np.float32))
|
||||
y = Tensor(np.random.randn(32, 3, 224, 224).astype(np.float32))
|
||||
|
||||
p = Tensor(3, mstype.float32)
|
||||
|
||||
out_channel = 64
|
||||
kernel_size = 7
|
||||
conv = P.Conv2D(out_channel,
|
||||
kernel_size,
|
||||
mode=1,
|
||||
pad_mode="valid",
|
||||
pad=0,
|
||||
stride=1,
|
||||
dilation=1,
|
||||
group=1)
|
||||
pow_ops = P.Pow()
|
||||
|
||||
@fns
|
||||
def test_flatten_switch_partial_arg():
|
||||
def called_graph_with_tuple(tuple_x, tuple_y):
|
||||
return conv(F.tuple_getitem(tuple_x, 0), F.tuple_getitem(tuple_x, 1)) + conv(F.tuple_getitem(tuple_y, 0),
|
||||
F.tuple_getitem(tuple_y, 1))
|
||||
|
||||
# Add tuple args in partial args.
|
||||
func1 = F.partial(called_graph_with_tuple, (pow_ops(x, p), pow_ops(w, p)))
|
||||
func2 = F.partial(called_graph_with_tuple, (pow_ops(x, p), pow_ops(w, p)))
|
||||
cond = x < y
|
||||
|
||||
switch_node = F.switch(cond, func1, func2)
|
||||
# Add tuple args in call args.
|
||||
return switch_node((pow_ops(x, p), pow_ops(w, p)))
|
||||
|
||||
index = Tensor(1, mstype.int32)
|
||||
|
||||
@fns
|
||||
def test_flatten_switch_layer_partial_arg():
|
||||
def called_graph_with_tuple(tuple_x):
|
||||
return conv(F.tuple_getitem(tuple_x, 0), F.tuple_getitem(tuple_x, 1))
|
||||
|
||||
def called_graph_no_tuple(param1, param2):
|
||||
return conv(param1, param2)
|
||||
|
||||
# Add tuple args in partial
|
||||
func1 = F.partial(called_graph_with_tuple, (pow_ops(x, p), pow_ops(w, p)))
|
||||
func2 = F.partial(called_graph_with_tuple, (pow_ops(x, p), pow_ops(w, p)))
|
||||
# Add tensor args in partial
|
||||
func3 = F.partial(called_graph_no_tuple, pow_ops(x, p), pow_ops(w, p))
|
||||
switch_node = F.switch_layer(pow_ops(index, index), (func1, func2, func3))
|
||||
return switch_node()
|
||||
|
||||
@fns
|
||||
def test_flatten_simple_call_tuple_in_tuple_arg():
|
||||
def called_graph_with_tuple(tuple_x, tuple_tuple_y, tensor_z):
|
||||
result1 = conv(F.tuple_getitem(tuple_x, 0), F.tuple_getitem(tuple_x, 1))
|
||||
tuple_0 = F.tuple_getitem(tuple_tuple_y, 0)
|
||||
result2 = conv(F.tuple_getitem(tuple_0, 0), F.tuple_getitem(tuple_0, 1))
|
||||
tensor_1 = F.tuple_getitem(tuple_tuple_y, 1)
|
||||
result3 = conv(tensor_1, tensor_z)
|
||||
return result1 + result2 + result3
|
||||
|
||||
# Tuple arg.
|
||||
tuple_x_arg = (pow_ops(x, p), pow_ops(w, p))
|
||||
# TupleTuple arg.
|
||||
tuple_0_arg = (pow_ops(x, p), pow_ops(w, p))
|
||||
tensor_1_arg = pow_ops(x, p)
|
||||
tuple_tuple_y_arg = (tuple_0_arg, tensor_1_arg)
|
||||
# TensorArg
|
||||
tensor_z_arg = pow_ops(w, p)
|
||||
return called_graph_with_tuple(tuple_x_arg, tuple_tuple_y_arg, tensor_z_arg)
|
||||
|
||||
return fns[tag]
|
||||
|
|
Loading…
Reference in New Issue