add ut for tuple trans

This commit is contained in:
chenfei 2022-03-16 16:14:16 +08:00
parent 433d7335d8
commit 3872e29c3b
5 changed files with 390 additions and 59 deletions

View File

@ -683,6 +683,7 @@ void DumpSubgraph(const OrderedMap<FuncGraphPtr, std::shared_ptr<SubGraphIRInfo>
} }
void SetDumpConfigByString(const std::string &str, DumpConfig *dump_config) { void SetDumpConfigByString(const std::string &str, DumpConfig *dump_config) {
MS_LOG(INFO) << "Set dump config:" << str;
static mindspore::HashMap<std::string, enum LocDumpMode> dump_level_map = { static mindspore::HashMap<std::string, enum LocDumpMode> dump_level_map = {
{kDumpConfigLineLevel0, kOff}, {kDumpConfigLineLevel1, kTopStack}, {kDumpConfigLineLevel2, kWholeStack}}; {kDumpConfigLineLevel0, kOff}, {kDumpConfigLineLevel1, kTopStack}, {kDumpConfigLineLevel2, kWholeStack}};
auto it = dump_level_map.find(str); auto it = dump_level_map.find(str);
@ -700,11 +701,64 @@ void SetDumpConfigByString(const std::string &str, DumpConfig *dump_config) {
} }
} }
std::shared_ptr<OrderedSet<std::string>> GetAllConfigStrings(const std::string &config_full_string) {
size_t start_pos = 0;
auto config_strings = std::make_shared<OrderedSet<std::string>>();
// if '#' is the last char of str, the str is legal, so we use '<=' but not '<'.
while (start_pos <= config_full_string.size()) {
auto pos = config_full_string.find('#', start_pos);
if (pos == std::string::npos) {
pos = config_full_string.size();
}
auto substr = config_full_string.substr(start_pos, pos - start_pos);
// Skip the '#'
start_pos = pos + 1;
if (substr.empty()) {
continue;
}
(void)config_strings->insert(substr);
}
return config_strings;
}
bool ConfigsAreLegal(const std::shared_ptr<OrderedSet<std::string>> &config_strings) {
// Value 'int' is used to mark config group id
HashMap<std::string, int> config_white_list = {{kDumpConfigLineLevel0, 0},
{kDumpConfigLineLevel1, 0},
{kDumpConfigLineLevel2, 0},
{kDumpConfigDisableBackend, 1},
{kDumpConfigEnablePassIR, 2}};
// Key 'int' is config group id, value is the config.
HashMap<int, std::string> config_groups;
for (const auto &config_string : *config_strings) {
auto config_white_list_it = config_white_list.find(config_string);
if (config_white_list_it == config_white_list.end()) {
std::ostringstream buffer;
buffer << "Support configs:\n"
<< "[0]: " << kDumpConfigLineLevel0 << "\n"
<< "[1]: " << kDumpConfigLineLevel1 << "\n"
<< "[2]: " << kDumpConfigLineLevel2 << "\n"
<< "[3]: " << kDumpConfigDisableBackend << "\n"
<< "[4]: " << kDumpConfigEnablePassIR;
MS_LOG(WARNING) << "Illegal dump config:\n" << config_string << "\n" << buffer.str();
return false;
}
auto group_id = config_white_list_it->second;
// Check conflict configs.
auto config_groups_it = config_groups.find(group_id);
if (config_groups_it != config_groups.end()) {
const auto &record_config = config_groups_it->second;
MS_LOG(WARNING) << "Dump configs are conflict. Conflict configs: [" << record_config << "] and [" << config_string
<< "].\n"
<< "Please keep only one of them.";
return false;
}
config_groups[group_id] = config_string;
}
return true;
}
DumpConfig GetDumpConfig() { DumpConfig GetDumpConfig() {
static std::vector<HashSet<std::string>> config_white_list = {
{kDumpConfigLineLevel0, kDumpConfigLineLevel1, kDumpConfigLineLevel2},
{kDumpConfigDisableBackend},
{kDumpConfigEnablePassIR}};
static DumpConfig dump_config = DumpConfig(); static DumpConfig dump_config = DumpConfig();
static bool parsed = false; static bool parsed = false;
if (parsed) { if (parsed) {
@ -713,9 +767,6 @@ DumpConfig GetDumpConfig() {
parsed = true; parsed = true;
// Start parse config. // Start parse config.
std::string str(common::GetEnv("MS_DEV_DUMP_IR_CONFIG")); std::string str(common::GetEnv("MS_DEV_DUMP_IR_CONFIG"));
std::vector<std::shared_ptr<HashSet<std::string>>> configs = {std::make_shared<HashSet<std::string>>(),
std::make_shared<HashSet<std::string>>(),
std::make_shared<HashSet<std::string>>()};
auto constexpr max_string_len = 100; auto constexpr max_string_len = 100;
if (str.size() > max_string_len) { if (str.size() > max_string_len) {
MS_LOG(WARNING) << "Dump ir config length exceed max length: " << max_string_len; MS_LOG(WARNING) << "Dump ir config length exceed max length: " << max_string_len;
@ -724,45 +775,12 @@ DumpConfig GetDumpConfig() {
if (str.empty()) { if (str.empty()) {
return dump_config; return dump_config;
} }
size_t start_pos = 0; auto config_strings = GetAllConfigStrings(str);
// if '#' is the last char of str, the str is illegal, so we use '<=' but not '<'. if (!ConfigsAreLegal(config_strings)) {
while (start_pos <= str.size()) { return dump_config;
auto pos = str.find('#', start_pos);
if (pos == std::string::npos) {
pos = str.size();
}
auto substr = str.substr(start_pos, pos - start_pos);
start_pos = pos + 1;
bool is_illegal_config = true;
for (size_t i = 0; i < config_white_list.size(); i++) {
if (config_white_list[i].find(substr) != config_white_list[i].end()) {
is_illegal_config = false;
(void)configs[i]->insert(substr);
if (configs[i]->size() > 1) {
std::ostringstream buffer;
(void)std::for_each(configs[i]->begin(), configs[i]->end(), [&buffer](const std::string &config) {
buffer << "\n" << config;
});
MS_LOG(WARNING) << "Dump configs are conflict. Conflict configs: " << buffer.str() << "\n"
<< "Please keep only one of them.";
return dump_config;
}
}
}
if (is_illegal_config) {
std::ostringstream buffer;
buffer << "Support configs:\n"
<< "[0]: " << kDumpConfigLineLevel0 << "\n"
<< "[1]: " << kDumpConfigLineLevel1 << "\n"
<< "[2]: " << kDumpConfigLineLevel2 << "\n"
<< "[3]: " << kDumpConfigDisableBackend << "\n"
<< "[4]: " << kDumpConfigEnablePassIR;
MS_LOG(WARNING) << "Illegal dump config:\n" << substr << "\n" << buffer.str();
return {};
}
} }
for (auto &config : configs) { for (const auto &config : *config_strings) {
SetDumpConfigByString(*config->begin(), &dump_config); SetDumpConfigByString(config, &dump_config);
} }
return dump_config; return dump_config;
} }

View File

@ -87,7 +87,7 @@ class GraphTupleTransform : public AnfVisitor {
GraphTupleParamTransform graph_transform_; GraphTupleParamTransform graph_transform_;
}; };
// {,kPrimPartial, G, Tuple_Xs} // {PrimPartial, G, Tuple_Xs}
// => // =>
// {kPrimPartial, G, TupleGetItem{Tuple_Xs,0}, TupleGetItem{Tuple_Xs,1}, ..., TupleGetItem{Tuple_Xs,n}} // {kPrimPartial, G, TupleGetItem{Tuple_Xs,0}, TupleGetItem{Tuple_Xs,1}, ..., TupleGetItem{Tuple_Xs,n}}
// transform partial's tuple binding args to flat inputs. // transform partial's tuple binding args to flat inputs.
@ -102,12 +102,12 @@ class PartialTupleArgTransform : public AnfVisitor {
auto partial = node->cast<CNodePtr>(); auto partial = node->cast<CNodePtr>();
const auto &partial_inputs = partial->inputs(); const auto &partial_inputs = partial->inputs();
const auto &fg = partial->func_graph(); const auto &fg = partial->func_graph();
// And primitive and function value node into args.
constexpr auto kPartialFirstArgIndex = 2; constexpr auto kPartialFirstArgIndex = 2;
auto new_args = AnfNodePtrList(partial_inputs.begin(), partial_inputs.begin() + kPartialFirstArgIndex); // Put ValueNode<kPrimPartial> and ValueNode<FuncGraph> into new_inputs.
auto change = FlattenArgs(fg, partial_inputs, kPartialFirstArgIndex, &new_args); auto new_inputs = AnfNodePtrList(partial_inputs.begin(), partial_inputs.begin() + kPartialFirstArgIndex);
auto change = FlattenArgs(fg, partial_inputs, kPartialFirstArgIndex, &new_inputs);
if (change) { if (change) {
auto new_partial = fg->NewCNode(new_args); auto new_partial = fg->NewCNode(new_inputs);
new_partial->set_abstract(partial->abstract()); new_partial->set_abstract(partial->abstract());
return new_partial; return new_partial;
} }
@ -132,11 +132,11 @@ class CallTupleArgTransform : public AnfVisitor {
const auto &call_inputs = call_node->inputs(); const auto &call_inputs = call_node->inputs();
const auto &fg = call_node->func_graph(); const auto &fg = call_node->func_graph();
MS_EXCEPTION_IF_NULL(fg); MS_EXCEPTION_IF_NULL(fg);
// Add function value node into args. // Put ValueNode<FuncGraph> into inputs.
auto new_args = AnfNodePtrList(call_inputs.begin(), call_inputs.begin() + 1); auto new_inputs = AnfNodePtrList(call_inputs.begin(), call_inputs.begin() + 1);
auto change = FlattenArgs(fg, call_inputs, 1, &new_args); auto change = FlattenArgs(fg, call_inputs, 1, &new_inputs);
if (change) { if (change) {
auto new_call = fg->NewCNode(new_args); auto new_call = fg->NewCNode(new_inputs);
new_call->set_abstract(call_node->abstract()); new_call->set_abstract(call_node->abstract());
return new_call; return new_call;
} }

View File

@ -0,0 +1,99 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore as ms
from mindspore import context
from mindspore.ops import operations as P
from mindspore.common.api import ms_function
from mindspore.common.tensor import Tensor
import mindspore.nn as nn
import numpy as np
import pytest
class MAPPOCriticNet(nn.Cell):
def __init__(self):
super().__init__()
self.linear1_actor = nn.Dense(54, # input local obs shape
64,
weight_init='XavierUniform',
# paper uses orthogonal with gain 5/3 for every dense123
has_bias=False,
activation=nn.Tanh())
def construct(self, x):
# Feature Extraction
x = self.linear1_actor(x)
return x
class MAPPOActor(nn.Cell):
def __init__(self, actor_net):
super().__init__()
self.actor_net = actor_net
def construct(self, inputs_data):
_, global_obs = inputs_data
out = self.actor_net(global_obs)
return out
class TestClass(nn.Cell):
def __init__(self, actor_list):
super().__init__()
self.zero = Tensor(0, ms.int32)
self.actor_list = actor_list
self.less = P.Less()
self.zeros = P.Zeros()
def train(self):
state = Tensor(np.random.random((3, 128, 18)), ms.float32)
init_global_obs = self.zeros((128, 54), ms.float32)
out = self.test(state, init_global_obs)
return out
@ms_function
def test(self, state, init_global_obs):
num_agent = self.zero
while self.less(num_agent, 3):
samples = (state[num_agent], init_global_obs)
self.actor_list[num_agent](samples)
num_agent += 1
return num_agent
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_net():
"""
Feature: Tuple arg transform.
Description: Test the pass: transform tuple arg to tensor arg.
Expectation: Compile done without error.
"""
context.set_context(mode=context.GRAPH_MODE, save_graphs=False, save_graphs_path="./graph_ir")
actor_list = nn.CellList()
for _ in range(3):
net = MAPPOCriticNet()
actor = MAPPOActor(net)
actor_list.append(actor)
test = TestClass(actor_list)
graph_out = test.train()
assert np.allclose(graph_out.asnumpy(), graph_out.asnumpy(), 0.0001, 0.0001)

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -22,10 +22,12 @@
#include "ir/anf.h" #include "ir/anf.h"
#include "ir/visitor.h" #include "ir/visitor.h"
#include "ir/func_graph_cloner.h" #include "ir/func_graph_cloner.h"
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/opt.h" #include "frontend/optimizer/opt.h"
#include "frontend/optimizer/anf_visitor.h" #include "frontend/optimizer/anf_visitor.h"
#include "frontend/optimizer/irpass.h" #include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/irpass/arithmetic_simplify.h" #include "frontend/optimizer/irpass/arithmetic_simplify.h"
#include "pipeline/jit/action.h"
#include "debug/draw.h" #include "debug/draw.h"
#include "frontend/operator/ops.h" #include "frontend/operator/ops.h"
@ -107,6 +109,8 @@ class TestOptOpt : public UT::Common {
FuncGraphPairMapEquiv equiv_graph; FuncGraphPairMapEquiv equiv_graph;
NodeMapEquiv equiv_node; NodeMapEquiv equiv_node;
irpass::OptimizeIRPassLib irpass_lib;
static const PrimitivePtr P; static const PrimitivePtr P;
static const PrimitivePtr Q; static const PrimitivePtr Q;
static const PrimitivePtr R; static const PrimitivePtr R;
@ -115,6 +119,7 @@ class TestOptOpt : public UT::Common {
SubstitutionPtr elim_R; SubstitutionPtr elim_R;
SubstitutionPtr idempotent_P; SubstitutionPtr idempotent_P;
SubstitutionPtr Qct_to_P; SubstitutionPtr Qct_to_P;
SubstitutionPtr tuple_flatten = irpass_lib.call_graph_tuple_transform_;
}; };
const PrimitivePtr TestOptOpt::P = std::make_shared<Primitive>("P"); const PrimitivePtr TestOptOpt::P = std::make_shared<Primitive>("P");
@ -148,8 +153,8 @@ TEST_F(TestOptOpt, ElimTwo) {
} }
TEST_F(TestOptOpt, ElimR) { TEST_F(TestOptOpt, ElimR) {
FuncGraphPtr before = getPyFun.CallAndParseRet("test_elimR", "before_1"); FuncGraphPtr before = getPyFun.CallAndParseRet("test_elim_r", "before_1");
FuncGraphPtr after = getPyFun.CallAndParseRet("test_elimR", "after"); FuncGraphPtr after = getPyFun.CallAndParseRet("test_elim_r", "after");
ASSERT_TRUE(nullptr != before); ASSERT_TRUE(nullptr != before);
ASSERT_TRUE(nullptr != after); ASSERT_TRUE(nullptr != after);
@ -208,5 +213,125 @@ TEST_F(TestOptOpt, CSE) {
ASSERT_EQ(manager2->all_nodes().size(), 12); ASSERT_EQ(manager2->all_nodes().size(), 12);
} }
size_t TupleArgAndParamSum(const FuncGraphPtr &func_graph) {
// Check tuple params and tuple args.
auto all_nodes = TopoSort(func_graph->return_node(), SuccDeeperSimple, AlwaysInclude);
size_t tuple_arg_param_num = 0;
auto tuple_accumulate_func = [](size_t prev_num, const AnfNodePtr &node) -> size_t {
auto abs = node->abstract();
MS_EXCEPTION_IF_NULL(abs);
return abs->isa<abstract::AbstractTuple>() ? prev_num + 1 : prev_num;
};
for (const auto &node : all_nodes) {
// Count func graph call tuple args.
if (node->isa<CNode>() && !IsValueNode<Primitive>(node->cast<CNodePtr>()->input(0))) {
auto call_node = node->cast<CNodePtr>();
tuple_arg_param_num = std::accumulate(call_node->inputs().begin() + 1, call_node->inputs().end(),
tuple_arg_param_num, tuple_accumulate_func);
}
// Count partial tuple args.
if (IsPrimitiveCNode(node, prim::kPrimPartial)) {
auto partial = node->cast<CNodePtr>();
constexpr auto kPartialFirstArgIdx = 2;
tuple_arg_param_num = std::accumulate(partial->inputs().begin() + kPartialFirstArgIdx, partial->inputs().end(),
tuple_arg_param_num, tuple_accumulate_func);
}
// Count tuple params.
if (IsValueNode<FuncGraph>(node)) {
auto fg = GetValueNode<FuncGraphPtr>(node);
tuple_arg_param_num =
std::accumulate(fg->parameters().begin(), fg->parameters().end(), tuple_arg_param_num, tuple_accumulate_func);
}
}
return tuple_arg_param_num;
}
// Feature: Switch call tuple arg transform.
// Description: Test switch call's tuple arg transform.This case include partial's tuple arg and the call's tuple arg in
// the same time.
// Expectation: All tuple args are correctly transformed to tensor args.
TEST_F(TestOptOpt, SwitchPartialTupleTrans) {
FuncGraphPtr test_graph = getPyFun.CallAndParseRet("test_tuple_flatten", "test_flatten_switch_partial_arg");
ASSERT_TRUE(nullptr != test_graph);
FuncGraphManagerPtr manager1 = Manage(test_graph);
pipeline::ResourcePtr res = std::make_shared<pipeline::Resource>();
std::vector<AbstractBasePtr> args_spec;
// Renormalize firstly.
auto renormalized_fg = pipeline::Renormalize(res, test_graph, args_spec);
ASSERT_TRUE(TupleArgAndParamSum(renormalized_fg) != 0);
// Flatten tuple param and args.
OptimizerPtr optimizer = std::make_shared<Optimizer>("ut_test", res);
SubstitutionList transform(std::vector<SubstitutionPtr>({tuple_flatten}));
transform(renormalized_fg, optimizer);
// Renormalize again.
auto transformed_fg = pipeline::Renormalize(res, renormalized_fg, args_spec);
ASSERT_TRUE(TupleArgAndParamSum(transformed_fg) == 0);
abstract::AnalysisResultCacheMgr::GetInstance().Clear();
abstract::AnalysisContext::ClearContext();
}
// Feature: Switch layer call tuple arg transform.
// Description: Test switch layer call's tuple arg transform.This case include partial's tuple arg and the partial's
// tensor arg in the same time.
// Expectation: All tuple args are correctly transformed to tensor args.
TEST_F(TestOptOpt, SwitchLayerPartialTupleTrans) {
FuncGraphPtr test_graph = getPyFun.CallAndParseRet("test_tuple_flatten", "test_flatten_switch_layer_partial_arg");
ASSERT_TRUE(nullptr != test_graph);
FuncGraphManagerPtr manager1 = Manage(test_graph);
pipeline::ResourcePtr res = std::make_shared<pipeline::Resource>();
std::vector<AbstractBasePtr> args_spec;
// Renormalize firstly.
auto renormalized_fg = pipeline::Renormalize(res, test_graph, args_spec);
ASSERT_TRUE(TupleArgAndParamSum(renormalized_fg) != 0);
// Flatten tuple param and args.
OptimizerPtr optimizer = std::make_shared<Optimizer>("ut_test", res);
SubstitutionList transform(std::vector<SubstitutionPtr>({tuple_flatten}));
transform(renormalized_fg, optimizer);
// Renormalize again.
auto transformed_fg = pipeline::Renormalize(res, renormalized_fg, args_spec);
ASSERT_TRUE(TupleArgAndParamSum(transformed_fg) == 0);
abstract::AnalysisResultCacheMgr::GetInstance().Clear();
abstract::AnalysisContext::ClearContext();
}
// Feature: Single graph call tuple arg transform.
// Description: Test single graph call's tuple arg transform.This case include tuple in tuple args.
// Expectation: All tuple args are correctly transformed to tensor args.
TEST_F(TestOptOpt, SimpleCallTupleTupleTrans) {
FuncGraphPtr test_graph =
getPyFun.CallAndParseRet("test_tuple_flatten", "test_flatten_simple_call_tuple_in_tuple_arg");
ASSERT_TRUE(nullptr != test_graph);
FuncGraphManagerPtr manager1 = Manage(test_graph);
pipeline::ResourcePtr res = std::make_shared<pipeline::Resource>();
std::vector<AbstractBasePtr> args_spec;
// Renormalize firstly.
auto renormalized_fg = pipeline::Renormalize(res, test_graph, args_spec);
ASSERT_TRUE(TupleArgAndParamSum(renormalized_fg) != 0);
// Flatten tuple param and args.
OptimizerPtr optimizer = std::make_shared<Optimizer>("ut_test", res);
SubstitutionList transform(std::vector<SubstitutionPtr>({tuple_flatten}));
transform(renormalized_fg, optimizer);
// Renormalize again.
auto transformed_fg = pipeline::Renormalize(res, renormalized_fg, args_spec);
ASSERT_TRUE(TupleArgAndParamSum(transformed_fg) == 0);
abstract::AnalysisResultCacheMgr::GetInstance().Clear();
abstract::AnalysisContext::ClearContext();
}
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore

View File

@ -16,9 +16,11 @@
import numpy as np import numpy as np
from mindspore import Tensor from mindspore import Tensor
from mindspore import dtype as mstype
from mindspore.ops import Primitive from mindspore.ops import Primitive
from mindspore.ops import _constants as Constants from mindspore.ops import _constants as Constants
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops.operations import _grad_ops as G from mindspore.ops.operations import _grad_ops as G
# pylint: disable=unused-variable # pylint: disable=unused-variable
@ -68,8 +70,12 @@ def test_add_zero(tag):
return fns[tag] return fns[tag]
def test_elimR(tag): def test_elim_r(tag):
""" test_elimR """ """
Feature: optimizer.
Description: test elimi R.
Expectation: run case with no exception.
"""
R = Primitive('R') R = Primitive('R')
fns = FnDict() fns = FnDict()
@ -495,6 +501,7 @@ def test_elim_transpose(tag):
return fns[tag] return fns[tag]
def test_elim_depend_value(tag): def test_elim_depend_value(tag):
""" test_elim_depend_value """ """ test_elim_depend_value """
fns = FnDict() fns = FnDict()
@ -1203,3 +1210,85 @@ def test_sparse_tensor(tag):
return z return z
return fns[tag] return fns[tag]
# Test ut for file: call_graph_tuple_transform.h.
def test_tuple_flatten(tag):
"""
Feature: optimizer.
Description: test cases for pass: graph_tuple_transform.
Expectation: the tuple args and parameters are successfully flattened by the pass.
"""
fns = FnDict()
w = Tensor(np.random.randn(64, 3, 7, 7).astype(np.float32))
x = Tensor(np.random.randn(32, 3, 224, 224).astype(np.float32))
y = Tensor(np.random.randn(32, 3, 224, 224).astype(np.float32))
p = Tensor(3, mstype.float32)
out_channel = 64
kernel_size = 7
conv = P.Conv2D(out_channel,
kernel_size,
mode=1,
pad_mode="valid",
pad=0,
stride=1,
dilation=1,
group=1)
pow_ops = P.Pow()
@fns
def test_flatten_switch_partial_arg():
def called_graph_with_tuple(tuple_x, tuple_y):
return conv(F.tuple_getitem(tuple_x, 0), F.tuple_getitem(tuple_x, 1)) + conv(F.tuple_getitem(tuple_y, 0),
F.tuple_getitem(tuple_y, 1))
# Add tuple args in partial args.
func1 = F.partial(called_graph_with_tuple, (pow_ops(x, p), pow_ops(w, p)))
func2 = F.partial(called_graph_with_tuple, (pow_ops(x, p), pow_ops(w, p)))
cond = x < y
switch_node = F.switch(cond, func1, func2)
# Add tuple args in call args.
return switch_node((pow_ops(x, p), pow_ops(w, p)))
index = Tensor(1, mstype.int32)
@fns
def test_flatten_switch_layer_partial_arg():
def called_graph_with_tuple(tuple_x):
return conv(F.tuple_getitem(tuple_x, 0), F.tuple_getitem(tuple_x, 1))
def called_graph_no_tuple(param1, param2):
return conv(param1, param2)
# Add tuple args in partial
func1 = F.partial(called_graph_with_tuple, (pow_ops(x, p), pow_ops(w, p)))
func2 = F.partial(called_graph_with_tuple, (pow_ops(x, p), pow_ops(w, p)))
# Add tensor args in partial
func3 = F.partial(called_graph_no_tuple, pow_ops(x, p), pow_ops(w, p))
switch_node = F.switch_layer(pow_ops(index, index), (func1, func2, func3))
return switch_node()
@fns
def test_flatten_simple_call_tuple_in_tuple_arg():
def called_graph_with_tuple(tuple_x, tuple_tuple_y, tensor_z):
result1 = conv(F.tuple_getitem(tuple_x, 0), F.tuple_getitem(tuple_x, 1))
tuple_0 = F.tuple_getitem(tuple_tuple_y, 0)
result2 = conv(F.tuple_getitem(tuple_0, 0), F.tuple_getitem(tuple_0, 1))
tensor_1 = F.tuple_getitem(tuple_tuple_y, 1)
result3 = conv(tensor_1, tensor_z)
return result1 + result2 + result3
# Tuple arg.
tuple_x_arg = (pow_ops(x, p), pow_ops(w, p))
# TupleTuple arg.
tuple_0_arg = (pow_ops(x, p), pow_ops(w, p))
tensor_1_arg = pow_ops(x, p)
tuple_tuple_y_arg = (tuple_0_arg, tensor_1_arg)
# TensorArg
tensor_z_arg = pow_ops(w, p)
return called_graph_with_tuple(tuple_x_arg, tuple_tuple_y_arg, tensor_z_arg)
return fns[tag]