add ut for tuple trans
This commit is contained in:
parent
433d7335d8
commit
3872e29c3b
|
@ -683,6 +683,7 @@ void DumpSubgraph(const OrderedMap<FuncGraphPtr, std::shared_ptr<SubGraphIRInfo>
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetDumpConfigByString(const std::string &str, DumpConfig *dump_config) {
|
void SetDumpConfigByString(const std::string &str, DumpConfig *dump_config) {
|
||||||
|
MS_LOG(INFO) << "Set dump config:" << str;
|
||||||
static mindspore::HashMap<std::string, enum LocDumpMode> dump_level_map = {
|
static mindspore::HashMap<std::string, enum LocDumpMode> dump_level_map = {
|
||||||
{kDumpConfigLineLevel0, kOff}, {kDumpConfigLineLevel1, kTopStack}, {kDumpConfigLineLevel2, kWholeStack}};
|
{kDumpConfigLineLevel0, kOff}, {kDumpConfigLineLevel1, kTopStack}, {kDumpConfigLineLevel2, kWholeStack}};
|
||||||
auto it = dump_level_map.find(str);
|
auto it = dump_level_map.find(str);
|
||||||
|
@ -700,11 +701,64 @@ void SetDumpConfigByString(const std::string &str, DumpConfig *dump_config) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<OrderedSet<std::string>> GetAllConfigStrings(const std::string &config_full_string) {
|
||||||
|
size_t start_pos = 0;
|
||||||
|
auto config_strings = std::make_shared<OrderedSet<std::string>>();
|
||||||
|
// if '#' is the last char of str, the str is legal, so we use '<=' but not '<'.
|
||||||
|
while (start_pos <= config_full_string.size()) {
|
||||||
|
auto pos = config_full_string.find('#', start_pos);
|
||||||
|
if (pos == std::string::npos) {
|
||||||
|
pos = config_full_string.size();
|
||||||
|
}
|
||||||
|
auto substr = config_full_string.substr(start_pos, pos - start_pos);
|
||||||
|
// Skip the '#'
|
||||||
|
start_pos = pos + 1;
|
||||||
|
if (substr.empty()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
(void)config_strings->insert(substr);
|
||||||
|
}
|
||||||
|
return config_strings;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ConfigsAreLegal(const std::shared_ptr<OrderedSet<std::string>> &config_strings) {
|
||||||
|
// Value 'int' is used to mark config group id
|
||||||
|
HashMap<std::string, int> config_white_list = {{kDumpConfigLineLevel0, 0},
|
||||||
|
{kDumpConfigLineLevel1, 0},
|
||||||
|
{kDumpConfigLineLevel2, 0},
|
||||||
|
{kDumpConfigDisableBackend, 1},
|
||||||
|
{kDumpConfigEnablePassIR, 2}};
|
||||||
|
// Key 'int' is config group id, value is the config.
|
||||||
|
HashMap<int, std::string> config_groups;
|
||||||
|
for (const auto &config_string : *config_strings) {
|
||||||
|
auto config_white_list_it = config_white_list.find(config_string);
|
||||||
|
if (config_white_list_it == config_white_list.end()) {
|
||||||
|
std::ostringstream buffer;
|
||||||
|
buffer << "Support configs:\n"
|
||||||
|
<< "[0]: " << kDumpConfigLineLevel0 << "\n"
|
||||||
|
<< "[1]: " << kDumpConfigLineLevel1 << "\n"
|
||||||
|
<< "[2]: " << kDumpConfigLineLevel2 << "\n"
|
||||||
|
<< "[3]: " << kDumpConfigDisableBackend << "\n"
|
||||||
|
<< "[4]: " << kDumpConfigEnablePassIR;
|
||||||
|
MS_LOG(WARNING) << "Illegal dump config:\n" << config_string << "\n" << buffer.str();
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
auto group_id = config_white_list_it->second;
|
||||||
|
// Check conflict configs.
|
||||||
|
auto config_groups_it = config_groups.find(group_id);
|
||||||
|
if (config_groups_it != config_groups.end()) {
|
||||||
|
const auto &record_config = config_groups_it->second;
|
||||||
|
MS_LOG(WARNING) << "Dump configs are conflict. Conflict configs: [" << record_config << "] and [" << config_string
|
||||||
|
<< "].\n"
|
||||||
|
<< "Please keep only one of them.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
config_groups[group_id] = config_string;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
DumpConfig GetDumpConfig() {
|
DumpConfig GetDumpConfig() {
|
||||||
static std::vector<HashSet<std::string>> config_white_list = {
|
|
||||||
{kDumpConfigLineLevel0, kDumpConfigLineLevel1, kDumpConfigLineLevel2},
|
|
||||||
{kDumpConfigDisableBackend},
|
|
||||||
{kDumpConfigEnablePassIR}};
|
|
||||||
static DumpConfig dump_config = DumpConfig();
|
static DumpConfig dump_config = DumpConfig();
|
||||||
static bool parsed = false;
|
static bool parsed = false;
|
||||||
if (parsed) {
|
if (parsed) {
|
||||||
|
@ -713,9 +767,6 @@ DumpConfig GetDumpConfig() {
|
||||||
parsed = true;
|
parsed = true;
|
||||||
// Start parse config.
|
// Start parse config.
|
||||||
std::string str(common::GetEnv("MS_DEV_DUMP_IR_CONFIG"));
|
std::string str(common::GetEnv("MS_DEV_DUMP_IR_CONFIG"));
|
||||||
std::vector<std::shared_ptr<HashSet<std::string>>> configs = {std::make_shared<HashSet<std::string>>(),
|
|
||||||
std::make_shared<HashSet<std::string>>(),
|
|
||||||
std::make_shared<HashSet<std::string>>()};
|
|
||||||
auto constexpr max_string_len = 100;
|
auto constexpr max_string_len = 100;
|
||||||
if (str.size() > max_string_len) {
|
if (str.size() > max_string_len) {
|
||||||
MS_LOG(WARNING) << "Dump ir config length exceed max length: " << max_string_len;
|
MS_LOG(WARNING) << "Dump ir config length exceed max length: " << max_string_len;
|
||||||
|
@ -724,45 +775,12 @@ DumpConfig GetDumpConfig() {
|
||||||
if (str.empty()) {
|
if (str.empty()) {
|
||||||
return dump_config;
|
return dump_config;
|
||||||
}
|
}
|
||||||
size_t start_pos = 0;
|
auto config_strings = GetAllConfigStrings(str);
|
||||||
// if '#' is the last char of str, the str is illegal, so we use '<=' but not '<'.
|
if (!ConfigsAreLegal(config_strings)) {
|
||||||
while (start_pos <= str.size()) {
|
return dump_config;
|
||||||
auto pos = str.find('#', start_pos);
|
|
||||||
if (pos == std::string::npos) {
|
|
||||||
pos = str.size();
|
|
||||||
}
|
|
||||||
auto substr = str.substr(start_pos, pos - start_pos);
|
|
||||||
start_pos = pos + 1;
|
|
||||||
bool is_illegal_config = true;
|
|
||||||
for (size_t i = 0; i < config_white_list.size(); i++) {
|
|
||||||
if (config_white_list[i].find(substr) != config_white_list[i].end()) {
|
|
||||||
is_illegal_config = false;
|
|
||||||
(void)configs[i]->insert(substr);
|
|
||||||
if (configs[i]->size() > 1) {
|
|
||||||
std::ostringstream buffer;
|
|
||||||
(void)std::for_each(configs[i]->begin(), configs[i]->end(), [&buffer](const std::string &config) {
|
|
||||||
buffer << "\n" << config;
|
|
||||||
});
|
|
||||||
MS_LOG(WARNING) << "Dump configs are conflict. Conflict configs: " << buffer.str() << "\n"
|
|
||||||
<< "Please keep only one of them.";
|
|
||||||
return dump_config;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (is_illegal_config) {
|
|
||||||
std::ostringstream buffer;
|
|
||||||
buffer << "Support configs:\n"
|
|
||||||
<< "[0]: " << kDumpConfigLineLevel0 << "\n"
|
|
||||||
<< "[1]: " << kDumpConfigLineLevel1 << "\n"
|
|
||||||
<< "[2]: " << kDumpConfigLineLevel2 << "\n"
|
|
||||||
<< "[3]: " << kDumpConfigDisableBackend << "\n"
|
|
||||||
<< "[4]: " << kDumpConfigEnablePassIR;
|
|
||||||
MS_LOG(WARNING) << "Illegal dump config:\n" << substr << "\n" << buffer.str();
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
for (auto &config : configs) {
|
for (const auto &config : *config_strings) {
|
||||||
SetDumpConfigByString(*config->begin(), &dump_config);
|
SetDumpConfigByString(config, &dump_config);
|
||||||
}
|
}
|
||||||
return dump_config;
|
return dump_config;
|
||||||
}
|
}
|
||||||
|
|
|
@ -87,7 +87,7 @@ class GraphTupleTransform : public AnfVisitor {
|
||||||
GraphTupleParamTransform graph_transform_;
|
GraphTupleParamTransform graph_transform_;
|
||||||
};
|
};
|
||||||
|
|
||||||
// {,kPrimPartial, G, Tuple_Xs}
|
// {PrimPartial, G, Tuple_Xs}
|
||||||
// =>
|
// =>
|
||||||
// {kPrimPartial, G, TupleGetItem{Tuple_Xs,0}, TupleGetItem{Tuple_Xs,1}, ..., TupleGetItem{Tuple_Xs,n}}
|
// {kPrimPartial, G, TupleGetItem{Tuple_Xs,0}, TupleGetItem{Tuple_Xs,1}, ..., TupleGetItem{Tuple_Xs,n}}
|
||||||
// transform partial's tuple binding args to flat inputs.
|
// transform partial's tuple binding args to flat inputs.
|
||||||
|
@ -102,12 +102,12 @@ class PartialTupleArgTransform : public AnfVisitor {
|
||||||
auto partial = node->cast<CNodePtr>();
|
auto partial = node->cast<CNodePtr>();
|
||||||
const auto &partial_inputs = partial->inputs();
|
const auto &partial_inputs = partial->inputs();
|
||||||
const auto &fg = partial->func_graph();
|
const auto &fg = partial->func_graph();
|
||||||
// And primitive and function value node into args.
|
|
||||||
constexpr auto kPartialFirstArgIndex = 2;
|
constexpr auto kPartialFirstArgIndex = 2;
|
||||||
auto new_args = AnfNodePtrList(partial_inputs.begin(), partial_inputs.begin() + kPartialFirstArgIndex);
|
// Put ValueNode<kPrimPartial> and ValueNode<FuncGraph> into new_inputs.
|
||||||
auto change = FlattenArgs(fg, partial_inputs, kPartialFirstArgIndex, &new_args);
|
auto new_inputs = AnfNodePtrList(partial_inputs.begin(), partial_inputs.begin() + kPartialFirstArgIndex);
|
||||||
|
auto change = FlattenArgs(fg, partial_inputs, kPartialFirstArgIndex, &new_inputs);
|
||||||
if (change) {
|
if (change) {
|
||||||
auto new_partial = fg->NewCNode(new_args);
|
auto new_partial = fg->NewCNode(new_inputs);
|
||||||
new_partial->set_abstract(partial->abstract());
|
new_partial->set_abstract(partial->abstract());
|
||||||
return new_partial;
|
return new_partial;
|
||||||
}
|
}
|
||||||
|
@ -132,11 +132,11 @@ class CallTupleArgTransform : public AnfVisitor {
|
||||||
const auto &call_inputs = call_node->inputs();
|
const auto &call_inputs = call_node->inputs();
|
||||||
const auto &fg = call_node->func_graph();
|
const auto &fg = call_node->func_graph();
|
||||||
MS_EXCEPTION_IF_NULL(fg);
|
MS_EXCEPTION_IF_NULL(fg);
|
||||||
// Add function value node into args.
|
// Put ValueNode<FuncGraph> into inputs.
|
||||||
auto new_args = AnfNodePtrList(call_inputs.begin(), call_inputs.begin() + 1);
|
auto new_inputs = AnfNodePtrList(call_inputs.begin(), call_inputs.begin() + 1);
|
||||||
auto change = FlattenArgs(fg, call_inputs, 1, &new_args);
|
auto change = FlattenArgs(fg, call_inputs, 1, &new_inputs);
|
||||||
if (change) {
|
if (change) {
|
||||||
auto new_call = fg->NewCNode(new_args);
|
auto new_call = fg->NewCNode(new_inputs);
|
||||||
new_call->set_abstract(call_node->abstract());
|
new_call->set_abstract(call_node->abstract());
|
||||||
return new_call;
|
return new_call;
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,99 @@
|
||||||
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
import mindspore as ms
|
||||||
|
from mindspore import context
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.common.api import ms_function
|
||||||
|
from mindspore.common.tensor import Tensor
|
||||||
|
import mindspore.nn as nn
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
class MAPPOCriticNet(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.linear1_actor = nn.Dense(54, # input local obs shape
|
||||||
|
64,
|
||||||
|
weight_init='XavierUniform',
|
||||||
|
# paper uses orthogonal with gain 5/3 for every dense123
|
||||||
|
has_bias=False,
|
||||||
|
activation=nn.Tanh())
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
# Feature Extraction
|
||||||
|
x = self.linear1_actor(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class MAPPOActor(nn.Cell):
|
||||||
|
|
||||||
|
def __init__(self, actor_net):
|
||||||
|
super().__init__()
|
||||||
|
self.actor_net = actor_net
|
||||||
|
|
||||||
|
def construct(self, inputs_data):
|
||||||
|
_, global_obs = inputs_data
|
||||||
|
out = self.actor_net(global_obs)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class TestClass(nn.Cell):
|
||||||
|
def __init__(self, actor_list):
|
||||||
|
super().__init__()
|
||||||
|
self.zero = Tensor(0, ms.int32)
|
||||||
|
self.actor_list = actor_list
|
||||||
|
self.less = P.Less()
|
||||||
|
self.zeros = P.Zeros()
|
||||||
|
|
||||||
|
def train(self):
|
||||||
|
state = Tensor(np.random.random((3, 128, 18)), ms.float32)
|
||||||
|
init_global_obs = self.zeros((128, 54), ms.float32)
|
||||||
|
out = self.test(state, init_global_obs)
|
||||||
|
return out
|
||||||
|
|
||||||
|
@ms_function
|
||||||
|
def test(self, state, init_global_obs):
|
||||||
|
num_agent = self.zero
|
||||||
|
while self.less(num_agent, 3):
|
||||||
|
samples = (state[num_agent], init_global_obs)
|
||||||
|
self.actor_list[num_agent](samples)
|
||||||
|
num_agent += 1
|
||||||
|
|
||||||
|
return num_agent
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_net():
|
||||||
|
"""
|
||||||
|
Feature: Tuple arg transform.
|
||||||
|
Description: Test the pass: transform tuple arg to tensor arg.
|
||||||
|
Expectation: Compile done without error.
|
||||||
|
"""
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, save_graphs=False, save_graphs_path="./graph_ir")
|
||||||
|
actor_list = nn.CellList()
|
||||||
|
for _ in range(3):
|
||||||
|
net = MAPPOCriticNet()
|
||||||
|
actor = MAPPOActor(net)
|
||||||
|
actor_list.append(actor)
|
||||||
|
test = TestClass(actor_list)
|
||||||
|
graph_out = test.train()
|
||||||
|
|
||||||
|
assert np.allclose(graph_out.asnumpy(), graph_out.asnumpy(), 0.0001, 0.0001)
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -22,10 +22,12 @@
|
||||||
#include "ir/anf.h"
|
#include "ir/anf.h"
|
||||||
#include "ir/visitor.h"
|
#include "ir/visitor.h"
|
||||||
#include "ir/func_graph_cloner.h"
|
#include "ir/func_graph_cloner.h"
|
||||||
|
#include "frontend/optimizer/optimizer.h"
|
||||||
#include "frontend/optimizer/opt.h"
|
#include "frontend/optimizer/opt.h"
|
||||||
#include "frontend/optimizer/anf_visitor.h"
|
#include "frontend/optimizer/anf_visitor.h"
|
||||||
#include "frontend/optimizer/irpass.h"
|
#include "frontend/optimizer/irpass.h"
|
||||||
#include "frontend/optimizer/irpass/arithmetic_simplify.h"
|
#include "frontend/optimizer/irpass/arithmetic_simplify.h"
|
||||||
|
#include "pipeline/jit/action.h"
|
||||||
|
|
||||||
#include "debug/draw.h"
|
#include "debug/draw.h"
|
||||||
#include "frontend/operator/ops.h"
|
#include "frontend/operator/ops.h"
|
||||||
|
@ -107,6 +109,8 @@ class TestOptOpt : public UT::Common {
|
||||||
FuncGraphPairMapEquiv equiv_graph;
|
FuncGraphPairMapEquiv equiv_graph;
|
||||||
NodeMapEquiv equiv_node;
|
NodeMapEquiv equiv_node;
|
||||||
|
|
||||||
|
irpass::OptimizeIRPassLib irpass_lib;
|
||||||
|
|
||||||
static const PrimitivePtr P;
|
static const PrimitivePtr P;
|
||||||
static const PrimitivePtr Q;
|
static const PrimitivePtr Q;
|
||||||
static const PrimitivePtr R;
|
static const PrimitivePtr R;
|
||||||
|
@ -115,6 +119,7 @@ class TestOptOpt : public UT::Common {
|
||||||
SubstitutionPtr elim_R;
|
SubstitutionPtr elim_R;
|
||||||
SubstitutionPtr idempotent_P;
|
SubstitutionPtr idempotent_P;
|
||||||
SubstitutionPtr Qct_to_P;
|
SubstitutionPtr Qct_to_P;
|
||||||
|
SubstitutionPtr tuple_flatten = irpass_lib.call_graph_tuple_transform_;
|
||||||
};
|
};
|
||||||
|
|
||||||
const PrimitivePtr TestOptOpt::P = std::make_shared<Primitive>("P");
|
const PrimitivePtr TestOptOpt::P = std::make_shared<Primitive>("P");
|
||||||
|
@ -148,8 +153,8 @@ TEST_F(TestOptOpt, ElimTwo) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestOptOpt, ElimR) {
|
TEST_F(TestOptOpt, ElimR) {
|
||||||
FuncGraphPtr before = getPyFun.CallAndParseRet("test_elimR", "before_1");
|
FuncGraphPtr before = getPyFun.CallAndParseRet("test_elim_r", "before_1");
|
||||||
FuncGraphPtr after = getPyFun.CallAndParseRet("test_elimR", "after");
|
FuncGraphPtr after = getPyFun.CallAndParseRet("test_elim_r", "after");
|
||||||
|
|
||||||
ASSERT_TRUE(nullptr != before);
|
ASSERT_TRUE(nullptr != before);
|
||||||
ASSERT_TRUE(nullptr != after);
|
ASSERT_TRUE(nullptr != after);
|
||||||
|
@ -208,5 +213,125 @@ TEST_F(TestOptOpt, CSE) {
|
||||||
ASSERT_EQ(manager2->all_nodes().size(), 12);
|
ASSERT_EQ(manager2->all_nodes().size(), 12);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
size_t TupleArgAndParamSum(const FuncGraphPtr &func_graph) {
|
||||||
|
// Check tuple params and tuple args.
|
||||||
|
auto all_nodes = TopoSort(func_graph->return_node(), SuccDeeperSimple, AlwaysInclude);
|
||||||
|
size_t tuple_arg_param_num = 0;
|
||||||
|
auto tuple_accumulate_func = [](size_t prev_num, const AnfNodePtr &node) -> size_t {
|
||||||
|
auto abs = node->abstract();
|
||||||
|
MS_EXCEPTION_IF_NULL(abs);
|
||||||
|
return abs->isa<abstract::AbstractTuple>() ? prev_num + 1 : prev_num;
|
||||||
|
};
|
||||||
|
for (const auto &node : all_nodes) {
|
||||||
|
// Count func graph call tuple args.
|
||||||
|
if (node->isa<CNode>() && !IsValueNode<Primitive>(node->cast<CNodePtr>()->input(0))) {
|
||||||
|
auto call_node = node->cast<CNodePtr>();
|
||||||
|
tuple_arg_param_num = std::accumulate(call_node->inputs().begin() + 1, call_node->inputs().end(),
|
||||||
|
tuple_arg_param_num, tuple_accumulate_func);
|
||||||
|
}
|
||||||
|
// Count partial tuple args.
|
||||||
|
if (IsPrimitiveCNode(node, prim::kPrimPartial)) {
|
||||||
|
auto partial = node->cast<CNodePtr>();
|
||||||
|
constexpr auto kPartialFirstArgIdx = 2;
|
||||||
|
tuple_arg_param_num = std::accumulate(partial->inputs().begin() + kPartialFirstArgIdx, partial->inputs().end(),
|
||||||
|
tuple_arg_param_num, tuple_accumulate_func);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Count tuple params.
|
||||||
|
if (IsValueNode<FuncGraph>(node)) {
|
||||||
|
auto fg = GetValueNode<FuncGraphPtr>(node);
|
||||||
|
tuple_arg_param_num =
|
||||||
|
std::accumulate(fg->parameters().begin(), fg->parameters().end(), tuple_arg_param_num, tuple_accumulate_func);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return tuple_arg_param_num;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Feature: Switch call tuple arg transform.
|
||||||
|
// Description: Test switch call's tuple arg transform.This case include partial's tuple arg and the call's tuple arg in
|
||||||
|
// the same time.
|
||||||
|
// Expectation: All tuple args are correctly transformed to tensor args.
|
||||||
|
TEST_F(TestOptOpt, SwitchPartialTupleTrans) {
|
||||||
|
FuncGraphPtr test_graph = getPyFun.CallAndParseRet("test_tuple_flatten", "test_flatten_switch_partial_arg");
|
||||||
|
ASSERT_TRUE(nullptr != test_graph);
|
||||||
|
|
||||||
|
FuncGraphManagerPtr manager1 = Manage(test_graph);
|
||||||
|
pipeline::ResourcePtr res = std::make_shared<pipeline::Resource>();
|
||||||
|
std::vector<AbstractBasePtr> args_spec;
|
||||||
|
|
||||||
|
// Renormalize firstly.
|
||||||
|
auto renormalized_fg = pipeline::Renormalize(res, test_graph, args_spec);
|
||||||
|
ASSERT_TRUE(TupleArgAndParamSum(renormalized_fg) != 0);
|
||||||
|
|
||||||
|
// Flatten tuple param and args.
|
||||||
|
OptimizerPtr optimizer = std::make_shared<Optimizer>("ut_test", res);
|
||||||
|
SubstitutionList transform(std::vector<SubstitutionPtr>({tuple_flatten}));
|
||||||
|
transform(renormalized_fg, optimizer);
|
||||||
|
|
||||||
|
// Renormalize again.
|
||||||
|
auto transformed_fg = pipeline::Renormalize(res, renormalized_fg, args_spec);
|
||||||
|
ASSERT_TRUE(TupleArgAndParamSum(transformed_fg) == 0);
|
||||||
|
|
||||||
|
abstract::AnalysisResultCacheMgr::GetInstance().Clear();
|
||||||
|
abstract::AnalysisContext::ClearContext();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Feature: Switch layer call tuple arg transform.
|
||||||
|
// Description: Test switch layer call's tuple arg transform.This case include partial's tuple arg and the partial's
|
||||||
|
// tensor arg in the same time.
|
||||||
|
// Expectation: All tuple args are correctly transformed to tensor args.
|
||||||
|
TEST_F(TestOptOpt, SwitchLayerPartialTupleTrans) {
|
||||||
|
FuncGraphPtr test_graph = getPyFun.CallAndParseRet("test_tuple_flatten", "test_flatten_switch_layer_partial_arg");
|
||||||
|
ASSERT_TRUE(nullptr != test_graph);
|
||||||
|
|
||||||
|
FuncGraphManagerPtr manager1 = Manage(test_graph);
|
||||||
|
pipeline::ResourcePtr res = std::make_shared<pipeline::Resource>();
|
||||||
|
std::vector<AbstractBasePtr> args_spec;
|
||||||
|
|
||||||
|
// Renormalize firstly.
|
||||||
|
auto renormalized_fg = pipeline::Renormalize(res, test_graph, args_spec);
|
||||||
|
ASSERT_TRUE(TupleArgAndParamSum(renormalized_fg) != 0);
|
||||||
|
|
||||||
|
// Flatten tuple param and args.
|
||||||
|
OptimizerPtr optimizer = std::make_shared<Optimizer>("ut_test", res);
|
||||||
|
SubstitutionList transform(std::vector<SubstitutionPtr>({tuple_flatten}));
|
||||||
|
transform(renormalized_fg, optimizer);
|
||||||
|
|
||||||
|
// Renormalize again.
|
||||||
|
auto transformed_fg = pipeline::Renormalize(res, renormalized_fg, args_spec);
|
||||||
|
ASSERT_TRUE(TupleArgAndParamSum(transformed_fg) == 0);
|
||||||
|
|
||||||
|
abstract::AnalysisResultCacheMgr::GetInstance().Clear();
|
||||||
|
abstract::AnalysisContext::ClearContext();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Feature: Single graph call tuple arg transform.
|
||||||
|
// Description: Test single graph call's tuple arg transform.This case include tuple in tuple args.
|
||||||
|
// Expectation: All tuple args are correctly transformed to tensor args.
|
||||||
|
TEST_F(TestOptOpt, SimpleCallTupleTupleTrans) {
|
||||||
|
FuncGraphPtr test_graph =
|
||||||
|
getPyFun.CallAndParseRet("test_tuple_flatten", "test_flatten_simple_call_tuple_in_tuple_arg");
|
||||||
|
ASSERT_TRUE(nullptr != test_graph);
|
||||||
|
|
||||||
|
FuncGraphManagerPtr manager1 = Manage(test_graph);
|
||||||
|
pipeline::ResourcePtr res = std::make_shared<pipeline::Resource>();
|
||||||
|
std::vector<AbstractBasePtr> args_spec;
|
||||||
|
|
||||||
|
// Renormalize firstly.
|
||||||
|
auto renormalized_fg = pipeline::Renormalize(res, test_graph, args_spec);
|
||||||
|
ASSERT_TRUE(TupleArgAndParamSum(renormalized_fg) != 0);
|
||||||
|
|
||||||
|
// Flatten tuple param and args.
|
||||||
|
OptimizerPtr optimizer = std::make_shared<Optimizer>("ut_test", res);
|
||||||
|
SubstitutionList transform(std::vector<SubstitutionPtr>({tuple_flatten}));
|
||||||
|
transform(renormalized_fg, optimizer);
|
||||||
|
|
||||||
|
// Renormalize again.
|
||||||
|
auto transformed_fg = pipeline::Renormalize(res, renormalized_fg, args_spec);
|
||||||
|
ASSERT_TRUE(TupleArgAndParamSum(transformed_fg) == 0);
|
||||||
|
|
||||||
|
abstract::AnalysisResultCacheMgr::GetInstance().Clear();
|
||||||
|
abstract::AnalysisContext::ClearContext();
|
||||||
|
}
|
||||||
} // namespace opt
|
} // namespace opt
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -16,9 +16,11 @@
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from mindspore import Tensor
|
from mindspore import Tensor
|
||||||
|
from mindspore import dtype as mstype
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import Primitive
|
||||||
from mindspore.ops import _constants as Constants
|
from mindspore.ops import _constants as Constants
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops import functional as F
|
||||||
from mindspore.ops.operations import _grad_ops as G
|
from mindspore.ops.operations import _grad_ops as G
|
||||||
|
|
||||||
# pylint: disable=unused-variable
|
# pylint: disable=unused-variable
|
||||||
|
@ -68,8 +70,12 @@ def test_add_zero(tag):
|
||||||
return fns[tag]
|
return fns[tag]
|
||||||
|
|
||||||
|
|
||||||
def test_elimR(tag):
|
def test_elim_r(tag):
|
||||||
""" test_elimR """
|
"""
|
||||||
|
Feature: optimizer.
|
||||||
|
Description: test elimi R.
|
||||||
|
Expectation: run case with no exception.
|
||||||
|
"""
|
||||||
R = Primitive('R')
|
R = Primitive('R')
|
||||||
|
|
||||||
fns = FnDict()
|
fns = FnDict()
|
||||||
|
@ -495,6 +501,7 @@ def test_elim_transpose(tag):
|
||||||
|
|
||||||
return fns[tag]
|
return fns[tag]
|
||||||
|
|
||||||
|
|
||||||
def test_elim_depend_value(tag):
|
def test_elim_depend_value(tag):
|
||||||
""" test_elim_depend_value """
|
""" test_elim_depend_value """
|
||||||
fns = FnDict()
|
fns = FnDict()
|
||||||
|
@ -1203,3 +1210,85 @@ def test_sparse_tensor(tag):
|
||||||
return z
|
return z
|
||||||
|
|
||||||
return fns[tag]
|
return fns[tag]
|
||||||
|
|
||||||
|
|
||||||
|
# Test ut for file: call_graph_tuple_transform.h.
|
||||||
|
def test_tuple_flatten(tag):
|
||||||
|
"""
|
||||||
|
Feature: optimizer.
|
||||||
|
Description: test cases for pass: graph_tuple_transform.
|
||||||
|
Expectation: the tuple args and parameters are successfully flattened by the pass.
|
||||||
|
"""
|
||||||
|
fns = FnDict()
|
||||||
|
w = Tensor(np.random.randn(64, 3, 7, 7).astype(np.float32))
|
||||||
|
x = Tensor(np.random.randn(32, 3, 224, 224).astype(np.float32))
|
||||||
|
y = Tensor(np.random.randn(32, 3, 224, 224).astype(np.float32))
|
||||||
|
|
||||||
|
p = Tensor(3, mstype.float32)
|
||||||
|
|
||||||
|
out_channel = 64
|
||||||
|
kernel_size = 7
|
||||||
|
conv = P.Conv2D(out_channel,
|
||||||
|
kernel_size,
|
||||||
|
mode=1,
|
||||||
|
pad_mode="valid",
|
||||||
|
pad=0,
|
||||||
|
stride=1,
|
||||||
|
dilation=1,
|
||||||
|
group=1)
|
||||||
|
pow_ops = P.Pow()
|
||||||
|
|
||||||
|
@fns
|
||||||
|
def test_flatten_switch_partial_arg():
|
||||||
|
def called_graph_with_tuple(tuple_x, tuple_y):
|
||||||
|
return conv(F.tuple_getitem(tuple_x, 0), F.tuple_getitem(tuple_x, 1)) + conv(F.tuple_getitem(tuple_y, 0),
|
||||||
|
F.tuple_getitem(tuple_y, 1))
|
||||||
|
|
||||||
|
# Add tuple args in partial args.
|
||||||
|
func1 = F.partial(called_graph_with_tuple, (pow_ops(x, p), pow_ops(w, p)))
|
||||||
|
func2 = F.partial(called_graph_with_tuple, (pow_ops(x, p), pow_ops(w, p)))
|
||||||
|
cond = x < y
|
||||||
|
|
||||||
|
switch_node = F.switch(cond, func1, func2)
|
||||||
|
# Add tuple args in call args.
|
||||||
|
return switch_node((pow_ops(x, p), pow_ops(w, p)))
|
||||||
|
|
||||||
|
index = Tensor(1, mstype.int32)
|
||||||
|
|
||||||
|
@fns
|
||||||
|
def test_flatten_switch_layer_partial_arg():
|
||||||
|
def called_graph_with_tuple(tuple_x):
|
||||||
|
return conv(F.tuple_getitem(tuple_x, 0), F.tuple_getitem(tuple_x, 1))
|
||||||
|
|
||||||
|
def called_graph_no_tuple(param1, param2):
|
||||||
|
return conv(param1, param2)
|
||||||
|
|
||||||
|
# Add tuple args in partial
|
||||||
|
func1 = F.partial(called_graph_with_tuple, (pow_ops(x, p), pow_ops(w, p)))
|
||||||
|
func2 = F.partial(called_graph_with_tuple, (pow_ops(x, p), pow_ops(w, p)))
|
||||||
|
# Add tensor args in partial
|
||||||
|
func3 = F.partial(called_graph_no_tuple, pow_ops(x, p), pow_ops(w, p))
|
||||||
|
switch_node = F.switch_layer(pow_ops(index, index), (func1, func2, func3))
|
||||||
|
return switch_node()
|
||||||
|
|
||||||
|
@fns
|
||||||
|
def test_flatten_simple_call_tuple_in_tuple_arg():
|
||||||
|
def called_graph_with_tuple(tuple_x, tuple_tuple_y, tensor_z):
|
||||||
|
result1 = conv(F.tuple_getitem(tuple_x, 0), F.tuple_getitem(tuple_x, 1))
|
||||||
|
tuple_0 = F.tuple_getitem(tuple_tuple_y, 0)
|
||||||
|
result2 = conv(F.tuple_getitem(tuple_0, 0), F.tuple_getitem(tuple_0, 1))
|
||||||
|
tensor_1 = F.tuple_getitem(tuple_tuple_y, 1)
|
||||||
|
result3 = conv(tensor_1, tensor_z)
|
||||||
|
return result1 + result2 + result3
|
||||||
|
|
||||||
|
# Tuple arg.
|
||||||
|
tuple_x_arg = (pow_ops(x, p), pow_ops(w, p))
|
||||||
|
# TupleTuple arg.
|
||||||
|
tuple_0_arg = (pow_ops(x, p), pow_ops(w, p))
|
||||||
|
tensor_1_arg = pow_ops(x, p)
|
||||||
|
tuple_tuple_y_arg = (tuple_0_arg, tensor_1_arg)
|
||||||
|
# TensorArg
|
||||||
|
tensor_z_arg = pow_ops(w, p)
|
||||||
|
return called_graph_with_tuple(tuple_x_arg, tuple_tuple_y_arg, tensor_z_arg)
|
||||||
|
|
||||||
|
return fns[tag]
|
||||||
|
|
Loading…
Reference in New Issue