add ut for tuple trans

chenfei 2022-03-16 16:14:16 +08:00
parent 433d7335d8
commit 3872e29c3b
5 changed files with 390 additions and 59 deletions

View File

@@ -683,6 +683,7 @@ void DumpSubgraph(const OrderedMap<FuncGraphPtr, std::shared_ptr<SubGraphIRInfo>
}
void SetDumpConfigByString(const std::string &str, DumpConfig *dump_config) {
MS_LOG(INFO) << "Set dump config:" << str;
static mindspore::HashMap<std::string, enum LocDumpMode> dump_level_map = {
{kDumpConfigLineLevel0, kOff}, {kDumpConfigLineLevel1, kTopStack}, {kDumpConfigLineLevel2, kWholeStack}};
auto it = dump_level_map.find(str);
@@ -700,11 +701,64 @@ void SetDumpConfigByString(const std::string &str, DumpConfig *dump_config) {
}
}
std::shared_ptr<OrderedSet<std::string>> GetAllConfigStrings(const std::string &config_full_string) {
size_t start_pos = 0;
auto config_strings = std::make_shared<OrderedSet<std::string>>();
// if '#' is the last char of str, the str is legal, so we use '<=' but not '<'.
while (start_pos <= config_full_string.size()) {
auto pos = config_full_string.find('#', start_pos);
if (pos == std::string::npos) {
pos = config_full_string.size();
}
auto substr = config_full_string.substr(start_pos, pos - start_pos);
// Skip the '#'
start_pos = pos + 1;
if (substr.empty()) {
continue;
}
(void)config_strings->insert(substr);
}
return config_strings;
}
bool ConfigsAreLegal(const std::shared_ptr<OrderedSet<std::string>> &config_strings) {
// Value 'int' is used to mark config group id
HashMap<std::string, int> config_white_list = {{kDumpConfigLineLevel0, 0},
{kDumpConfigLineLevel1, 0},
{kDumpConfigLineLevel2, 0},
{kDumpConfigDisableBackend, 1},
{kDumpConfigEnablePassIR, 2}};
// Key 'int' is config group id, value is the config.
HashMap<int, std::string> config_groups;
for (const auto &config_string : *config_strings) {
auto config_white_list_it = config_white_list.find(config_string);
if (config_white_list_it == config_white_list.end()) {
std::ostringstream buffer;
buffer << "Support configs:\n"
<< "[0]: " << kDumpConfigLineLevel0 << "\n"
<< "[1]: " << kDumpConfigLineLevel1 << "\n"
<< "[2]: " << kDumpConfigLineLevel2 << "\n"
<< "[3]: " << kDumpConfigDisableBackend << "\n"
<< "[4]: " << kDumpConfigEnablePassIR;
MS_LOG(WARNING) << "Illegal dump config:\n" << config_string << "\n" << buffer.str();
return false;
}
auto group_id = config_white_list_it->second;
// Check conflict configs.
auto config_groups_it = config_groups.find(group_id);
if (config_groups_it != config_groups.end()) {
const auto &record_config = config_groups_it->second;
MS_LOG(WARNING) << "Dump configs are conflict. Conflict configs: [" << record_config << "] and [" << config_string
<< "].\n"
<< "Please keep only one of them.";
return false;
}
config_groups[group_id] = config_string;
}
return true;
}
DumpConfig GetDumpConfig() {
static std::vector<HashSet<std::string>> config_white_list = {
{kDumpConfigLineLevel0, kDumpConfigLineLevel1, kDumpConfigLineLevel2},
{kDumpConfigDisableBackend},
{kDumpConfigEnablePassIR}};
static DumpConfig dump_config = DumpConfig();
static bool parsed = false;
if (parsed) {
@@ -713,9 +767,6 @@ DumpConfig GetDumpConfig() {
parsed = true;
// Start parse config.
std::string str(common::GetEnv("MS_DEV_DUMP_IR_CONFIG"));
std::vector<std::shared_ptr<HashSet<std::string>>> configs = {std::make_shared<HashSet<std::string>>(),
std::make_shared<HashSet<std::string>>(),
std::make_shared<HashSet<std::string>>()};
auto constexpr max_string_len = 100;
if (str.size() > max_string_len) {
MS_LOG(WARNING) << "Dump ir config length exceed max length: " << max_string_len;
@@ -724,45 +775,12 @@ DumpConfig GetDumpConfig() {
if (str.empty()) {
return dump_config;
}
size_t start_pos = 0;
// if '#' is the last char of str, the str is illegal, so we use '<=' but not '<'.
while (start_pos <= str.size()) {
auto pos = str.find('#', start_pos);
if (pos == std::string::npos) {
pos = str.size();
}
auto substr = str.substr(start_pos, pos - start_pos);
start_pos = pos + 1;
bool is_illegal_config = true;
for (size_t i = 0; i < config_white_list.size(); i++) {
if (config_white_list[i].find(substr) != config_white_list[i].end()) {
is_illegal_config = false;
(void)configs[i]->insert(substr);
if (configs[i]->size() > 1) {
std::ostringstream buffer;
(void)std::for_each(configs[i]->begin(), configs[i]->end(), [&buffer](const std::string &config) {
buffer << "\n" << config;
});
MS_LOG(WARNING) << "Dump configs are conflict. Conflict configs: " << buffer.str() << "\n"
<< "Please keep only one of them.";
return dump_config;
}
}
}
if (is_illegal_config) {
std::ostringstream buffer;
buffer << "Support configs:\n"
<< "[0]: " << kDumpConfigLineLevel0 << "\n"
<< "[1]: " << kDumpConfigLineLevel1 << "\n"
<< "[2]: " << kDumpConfigLineLevel2 << "\n"
<< "[3]: " << kDumpConfigDisableBackend << "\n"
<< "[4]: " << kDumpConfigEnablePassIR;
MS_LOG(WARNING) << "Illegal dump config:\n" << substr << "\n" << buffer.str();
return {};
}
auto config_strings = GetAllConfigStrings(str);
if (!ConfigsAreLegal(config_strings)) {
return dump_config;
}
for (auto &config : configs) {
SetDumpConfigByString(*config->begin(), &dump_config);
for (const auto &config : *config_strings) {
SetDumpConfigByString(config, &dump_config);
}
return dump_config;
}
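For reference, the refactoring above moves the '#'-splitting into GetAllConfigStrings and the whitelist/conflict check into ConfigsAreLegal. Below is a minimal standalone Python sketch of that parsing logic; the option strings and group ids are illustrative assumptions, not the actual values of the kDumpConfig* constants.

# Standalone sketch, assuming option names; mirrors GetAllConfigStrings + ConfigsAreLegal.
CONFIG_GROUPS = {
    "LINE_LEVEL0": 0, "LINE_LEVEL1": 0, "LINE_LEVEL2": 0,  # only one location-dump level allowed
    "DISABLE_BACKEND": 1,
    "ENABLE_PASS_IR": 2,
}

def parse_dump_config(config_full_string):
    # Split on '#' and drop empty tokens (GetAllConfigStrings).
    tokens = [t for t in config_full_string.split("#") if t]
    chosen = {}  # group id -> config already seen in that group (ConfigsAreLegal)
    for token in tokens:
        if token not in CONFIG_GROUPS:
            raise ValueError("Illegal dump config: " + token)
        group = CONFIG_GROUPS[token]
        if group in chosen:
            raise ValueError("Conflicting configs: [%s] and [%s]" % (chosen[group], token))
        chosen[group] = token
    return tokens

print(parse_dump_config("LINE_LEVEL1#DISABLE_BACKEND"))  # ['LINE_LEVEL1', 'DISABLE_BACKEND']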

View File

@@ -87,7 +87,7 @@ class GraphTupleTransform : public AnfVisitor {
GraphTupleParamTransform graph_transform_;
};
// {,kPrimPartial, G, Tuple_Xs}
// {PrimPartial, G, Tuple_Xs}
// =>
// {kPrimPartial, G, TupleGetItem{Tuple_Xs,0}, TupleGetItem{Tuple_Xs,1}, ..., TupleGetItem{Tuple_Xs,n}}
// transform partial's tuple binding args to flat inputs.
@@ -102,12 +102,12 @@ class PartialTupleArgTransform : public AnfVisitor {
auto partial = node->cast<CNodePtr>();
const auto &partial_inputs = partial->inputs();
const auto &fg = partial->func_graph();
// And primitive and function value node into args.
constexpr auto kPartialFirstArgIndex = 2;
auto new_args = AnfNodePtrList(partial_inputs.begin(), partial_inputs.begin() + kPartialFirstArgIndex);
auto change = FlattenArgs(fg, partial_inputs, kPartialFirstArgIndex, &new_args);
// Put ValueNode<kPrimPartial> and ValueNode<FuncGraph> into new_inputs.
auto new_inputs = AnfNodePtrList(partial_inputs.begin(), partial_inputs.begin() + kPartialFirstArgIndex);
auto change = FlattenArgs(fg, partial_inputs, kPartialFirstArgIndex, &new_inputs);
if (change) {
auto new_partial = fg->NewCNode(new_args);
auto new_partial = fg->NewCNode(new_inputs);
new_partial->set_abstract(partial->abstract());
return new_partial;
}
@@ -132,11 +132,11 @@ class CallTupleArgTransform : public AnfVisitor {
const auto &call_inputs = call_node->inputs();
const auto &fg = call_node->func_graph();
MS_EXCEPTION_IF_NULL(fg);
// Add function value node into args.
auto new_args = AnfNodePtrList(call_inputs.begin(), call_inputs.begin() + 1);
auto change = FlattenArgs(fg, call_inputs, 1, &new_args);
// Put ValueNode<FuncGraph> into inputs.
auto new_inputs = AnfNodePtrList(call_inputs.begin(), call_inputs.begin() + 1);
auto change = FlattenArgs(fg, call_inputs, 1, &new_inputs);
if (change) {
auto new_call = fg->NewCNode(new_args);
auto new_call = fg->NewCNode(new_inputs);
new_call->set_abstract(call_node->abstract());
return new_call;
}
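Both visitors above copy the leading primitive/function inputs and then let FlattenArgs expand each tuple argument into per-element inputs (TupleGetItem nodes in the IR). The following is a rough, illustration-only Python sketch of that expansion on a plain argument list; flatten_args is a hypothetical helper, not a MindSpore API, and nested tuples are expanded recursively here for simplicity.

# Illustration only: tuples in the argument list are replaced by their elements,
# roughly what FlattenArgs does by inserting TupleGetItem nodes.
def flatten_args(args):
    flat, changed = [], False
    for arg in args:
        if isinstance(arg, tuple):
            changed = True
            inner_flat, _ = flatten_args(list(arg))
            flat.extend(inner_flat)
        else:
            flat.append(arg)
    return flat, changed

print(flatten_args(["x", ("a", "b"), (("c", "d"), "e")]))
# (['x', 'a', 'b', 'c', 'd', 'e'], True)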

View File

@@ -0,0 +1,99 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore as ms
from mindspore import context
from mindspore.ops import operations as P
from mindspore.common.api import ms_function
from mindspore.common.tensor import Tensor
import mindspore.nn as nn
import numpy as np
import pytest
class MAPPOCriticNet(nn.Cell):
    def __init__(self):
        super().__init__()
        self.linear1_actor = nn.Dense(54,  # input local obs shape
                                      64,
                                      weight_init='XavierUniform',
                                      # paper uses orthogonal with gain 5/3 for every dense123
                                      has_bias=False,
                                      activation=nn.Tanh())

    def construct(self, x):
        # Feature Extraction
        x = self.linear1_actor(x)
        return x


class MAPPOActor(nn.Cell):
    def __init__(self, actor_net):
        super().__init__()
        self.actor_net = actor_net

    def construct(self, inputs_data):
        _, global_obs = inputs_data
        out = self.actor_net(global_obs)
        return out


class TestClass(nn.Cell):
    def __init__(self, actor_list):
        super().__init__()
        self.zero = Tensor(0, ms.int32)
        self.actor_list = actor_list
        self.less = P.Less()
        self.zeros = P.Zeros()

    def train(self):
        state = Tensor(np.random.random((3, 128, 18)), ms.float32)
        init_global_obs = self.zeros((128, 54), ms.float32)
        out = self.test(state, init_global_obs)
        return out

    @ms_function
    def test(self, state, init_global_obs):
        num_agent = self.zero
        while self.less(num_agent, 3):
            samples = (state[num_agent], init_global_obs)
            self.actor_list[num_agent](samples)
            num_agent += 1
        return num_agent


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_net():
    """
    Feature: Tuple arg transform.
    Description: Test the pass: transform tuple arg to tensor arg.
    Expectation: Compile done without error.
    """
    context.set_context(mode=context.GRAPH_MODE, save_graphs=False, save_graphs_path="./graph_ir")
    actor_list = nn.CellList()
    for _ in range(3):
        net = MAPPOCriticNet()
        actor = MAPPOActor(net)
        actor_list.append(actor)
    test = TestClass(actor_list)
    graph_out = test.train()
    assert np.allclose(graph_out.asnumpy(), graph_out.asnumpy(), 0.0001, 0.0001)

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -22,10 +22,12 @@
#include "ir/anf.h"
#include "ir/visitor.h"
#include "ir/func_graph_cloner.h"
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/opt.h"
#include "frontend/optimizer/anf_visitor.h"
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/irpass/arithmetic_simplify.h"
#include "pipeline/jit/action.h"
#include "debug/draw.h"
#include "frontend/operator/ops.h"
@@ -107,6 +109,8 @@ class TestOptOpt : public UT::Common {
FuncGraphPairMapEquiv equiv_graph;
NodeMapEquiv equiv_node;
irpass::OptimizeIRPassLib irpass_lib;
static const PrimitivePtr P;
static const PrimitivePtr Q;
static const PrimitivePtr R;
@@ -115,6 +119,7 @@ class TestOptOpt : public UT::Common {
SubstitutionPtr elim_R;
SubstitutionPtr idempotent_P;
SubstitutionPtr Qct_to_P;
SubstitutionPtr tuple_flatten = irpass_lib.call_graph_tuple_transform_;
};
const PrimitivePtr TestOptOpt::P = std::make_shared<Primitive>("P");
@@ -148,8 +153,8 @@ TEST_F(TestOptOpt, ElimTwo) {
}
TEST_F(TestOptOpt, ElimR) {
FuncGraphPtr before = getPyFun.CallAndParseRet("test_elimR", "before_1");
FuncGraphPtr after = getPyFun.CallAndParseRet("test_elimR", "after");
FuncGraphPtr before = getPyFun.CallAndParseRet("test_elim_r", "before_1");
FuncGraphPtr after = getPyFun.CallAndParseRet("test_elim_r", "after");
ASSERT_TRUE(nullptr != before);
ASSERT_TRUE(nullptr != after);
@@ -208,5 +213,125 @@ TEST_F(TestOptOpt, CSE) {
ASSERT_EQ(manager2->all_nodes().size(), 12);
}
size_t TupleArgAndParamSum(const FuncGraphPtr &func_graph) {
// Check tuple params and tuple args.
auto all_nodes = TopoSort(func_graph->return_node(), SuccDeeperSimple, AlwaysInclude);
size_t tuple_arg_param_num = 0;
auto tuple_accumulate_func = [](size_t prev_num, const AnfNodePtr &node) -> size_t {
auto abs = node->abstract();
MS_EXCEPTION_IF_NULL(abs);
return abs->isa<abstract::AbstractTuple>() ? prev_num + 1 : prev_num;
};
for (const auto &node : all_nodes) {
// Count func graph call tuple args.
if (node->isa<CNode>() && !IsValueNode<Primitive>(node->cast<CNodePtr>()->input(0))) {
auto call_node = node->cast<CNodePtr>();
tuple_arg_param_num = std::accumulate(call_node->inputs().begin() + 1, call_node->inputs().end(),
tuple_arg_param_num, tuple_accumulate_func);
}
// Count partial tuple args.
if (IsPrimitiveCNode(node, prim::kPrimPartial)) {
auto partial = node->cast<CNodePtr>();
constexpr auto kPartialFirstArgIdx = 2;
tuple_arg_param_num = std::accumulate(partial->inputs().begin() + kPartialFirstArgIdx, partial->inputs().end(),
tuple_arg_param_num, tuple_accumulate_func);
}
// Count tuple params.
if (IsValueNode<FuncGraph>(node)) {
auto fg = GetValueNode<FuncGraphPtr>(node);
tuple_arg_param_num =
std::accumulate(fg->parameters().begin(), fg->parameters().end(), tuple_arg_param_num, tuple_accumulate_func);
}
}
return tuple_arg_param_num;
}
// Feature: Switch call tuple arg transform.
// Description: Test switch call's tuple arg transform. This case includes a partial's tuple arg and the call's tuple
// arg at the same time.
// Expectation: All tuple args are correctly transformed to tensor args.
TEST_F(TestOptOpt, SwitchPartialTupleTrans) {
FuncGraphPtr test_graph = getPyFun.CallAndParseRet("test_tuple_flatten", "test_flatten_switch_partial_arg");
ASSERT_TRUE(nullptr != test_graph);
FuncGraphManagerPtr manager1 = Manage(test_graph);
pipeline::ResourcePtr res = std::make_shared<pipeline::Resource>();
std::vector<AbstractBasePtr> args_spec;
// Renormalize firstly.
auto renormalized_fg = pipeline::Renormalize(res, test_graph, args_spec);
ASSERT_TRUE(TupleArgAndParamSum(renormalized_fg) != 0);
// Flatten tuple param and args.
OptimizerPtr optimizer = std::make_shared<Optimizer>("ut_test", res);
SubstitutionList transform(std::vector<SubstitutionPtr>({tuple_flatten}));
transform(renormalized_fg, optimizer);
// Renormalize again.
auto transformed_fg = pipeline::Renormalize(res, renormalized_fg, args_spec);
ASSERT_TRUE(TupleArgAndParamSum(transformed_fg) == 0);
abstract::AnalysisResultCacheMgr::GetInstance().Clear();
abstract::AnalysisContext::ClearContext();
}
// Feature: Switch layer call tuple arg transform.
// Description: Test switch layer call's tuple arg transform. This case includes a partial's tuple arg and a partial's
// tensor arg at the same time.
// Expectation: All tuple args are correctly transformed to tensor args.
TEST_F(TestOptOpt, SwitchLayerPartialTupleTrans) {
FuncGraphPtr test_graph = getPyFun.CallAndParseRet("test_tuple_flatten", "test_flatten_switch_layer_partial_arg");
ASSERT_TRUE(nullptr != test_graph);
FuncGraphManagerPtr manager1 = Manage(test_graph);
pipeline::ResourcePtr res = std::make_shared<pipeline::Resource>();
std::vector<AbstractBasePtr> args_spec;
// Renormalize firstly.
auto renormalized_fg = pipeline::Renormalize(res, test_graph, args_spec);
ASSERT_TRUE(TupleArgAndParamSum(renormalized_fg) != 0);
// Flatten tuple param and args.
OptimizerPtr optimizer = std::make_shared<Optimizer>("ut_test", res);
SubstitutionList transform(std::vector<SubstitutionPtr>({tuple_flatten}));
transform(renormalized_fg, optimizer);
// Renormalize again.
auto transformed_fg = pipeline::Renormalize(res, renormalized_fg, args_spec);
ASSERT_TRUE(TupleArgAndParamSum(transformed_fg) == 0);
abstract::AnalysisResultCacheMgr::GetInstance().Clear();
abstract::AnalysisContext::ClearContext();
}
// Feature: Single graph call tuple arg transform.
// Description: Test single graph call's tuple arg transform. This case includes tuple-in-tuple args.
// Expectation: All tuple args are correctly transformed to tensor args.
TEST_F(TestOptOpt, SimpleCallTupleTupleTrans) {
FuncGraphPtr test_graph =
getPyFun.CallAndParseRet("test_tuple_flatten", "test_flatten_simple_call_tuple_in_tuple_arg");
ASSERT_TRUE(nullptr != test_graph);
FuncGraphManagerPtr manager1 = Manage(test_graph);
pipeline::ResourcePtr res = std::make_shared<pipeline::Resource>();
std::vector<AbstractBasePtr> args_spec;
// Renormalize firstly.
auto renormalized_fg = pipeline::Renormalize(res, test_graph, args_spec);
ASSERT_TRUE(TupleArgAndParamSum(renormalized_fg) != 0);
// Flatten tuple param and args.
OptimizerPtr optimizer = std::make_shared<Optimizer>("ut_test", res);
SubstitutionList transform(std::vector<SubstitutionPtr>({tuple_flatten}));
transform(renormalized_fg, optimizer);
// Renormalize again.
auto transformed_fg = pipeline::Renormalize(res, renormalized_fg, args_spec);
ASSERT_TRUE(TupleArgAndParamSum(transformed_fg) == 0);
abstract::AnalysisResultCacheMgr::GetInstance().Clear();
abstract::AnalysisContext::ClearContext();
}
} // namespace opt
} // namespace mindspore

View File

@@ -16,9 +16,11 @@
import numpy as np
from mindspore import Tensor
from mindspore import dtype as mstype
from mindspore.ops import Primitive
from mindspore.ops import _constants as Constants
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops.operations import _grad_ops as G
# pylint: disable=unused-variable
@@ -68,8 +70,12 @@ def test_add_zero(tag):
    return fns[tag]
def test_elimR(tag):
    """ test_elimR """
def test_elim_r(tag):
    """
    Feature: optimizer.
    Description: test elim R.
    Expectation: run case with no exception.
    """
    R = Primitive('R')
    fns = FnDict()
@@ -495,6 +501,7 @@ def test_elim_transpose(tag):
    return fns[tag]
def test_elim_depend_value(tag):
    """ test_elim_depend_value """
    fns = FnDict()
@@ -1203,3 +1210,85 @@ def test_sparse_tensor(tag):
        return z
    return fns[tag]
# Test ut for file: call_graph_tuple_transform.h.
def test_tuple_flatten(tag):
    """
    Feature: optimizer.
    Description: test cases for pass: graph_tuple_transform.
    Expectation: the tuple args and parameters are successfully flattened by the pass.
    """
    fns = FnDict()
    w = Tensor(np.random.randn(64, 3, 7, 7).astype(np.float32))
    x = Tensor(np.random.randn(32, 3, 224, 224).astype(np.float32))
    y = Tensor(np.random.randn(32, 3, 224, 224).astype(np.float32))
    p = Tensor(3, mstype.float32)
    out_channel = 64
    kernel_size = 7
    conv = P.Conv2D(out_channel,
                    kernel_size,
                    mode=1,
                    pad_mode="valid",
                    pad=0,
                    stride=1,
                    dilation=1,
                    group=1)
    pow_ops = P.Pow()

    @fns
    def test_flatten_switch_partial_arg():
        def called_graph_with_tuple(tuple_x, tuple_y):
            return conv(F.tuple_getitem(tuple_x, 0), F.tuple_getitem(tuple_x, 1)) + conv(F.tuple_getitem(tuple_y, 0),
                                                                                          F.tuple_getitem(tuple_y, 1))

        # Add tuple args in partial args.
        func1 = F.partial(called_graph_with_tuple, (pow_ops(x, p), pow_ops(w, p)))
        func2 = F.partial(called_graph_with_tuple, (pow_ops(x, p), pow_ops(w, p)))
        cond = x < y
        switch_node = F.switch(cond, func1, func2)
        # Add tuple args in call args.
        return switch_node((pow_ops(x, p), pow_ops(w, p)))

    index = Tensor(1, mstype.int32)

    @fns
    def test_flatten_switch_layer_partial_arg():
        def called_graph_with_tuple(tuple_x):
            return conv(F.tuple_getitem(tuple_x, 0), F.tuple_getitem(tuple_x, 1))

        def called_graph_no_tuple(param1, param2):
            return conv(param1, param2)

        # Add tuple args in partial
        func1 = F.partial(called_graph_with_tuple, (pow_ops(x, p), pow_ops(w, p)))
        func2 = F.partial(called_graph_with_tuple, (pow_ops(x, p), pow_ops(w, p)))
        # Add tensor args in partial
        func3 = F.partial(called_graph_no_tuple, pow_ops(x, p), pow_ops(w, p))
        switch_node = F.switch_layer(pow_ops(index, index), (func1, func2, func3))
        return switch_node()

    @fns
    def test_flatten_simple_call_tuple_in_tuple_arg():
        def called_graph_with_tuple(tuple_x, tuple_tuple_y, tensor_z):
            result1 = conv(F.tuple_getitem(tuple_x, 0), F.tuple_getitem(tuple_x, 1))
            tuple_0 = F.tuple_getitem(tuple_tuple_y, 0)
            result2 = conv(F.tuple_getitem(tuple_0, 0), F.tuple_getitem(tuple_0, 1))
            tensor_1 = F.tuple_getitem(tuple_tuple_y, 1)
            result3 = conv(tensor_1, tensor_z)
            return result1 + result2 + result3

        # Tuple arg.
        tuple_x_arg = (pow_ops(x, p), pow_ops(w, p))
        # TupleTuple arg.
        tuple_0_arg = (pow_ops(x, p), pow_ops(w, p))
        tensor_1_arg = pow_ops(x, p)
        tuple_tuple_y_arg = (tuple_0_arg, tensor_1_arg)
        # TensorArg
        tensor_z_arg = pow_ops(w, p)
        return called_graph_with_tuple(tuple_x_arg, tuple_tuple_y_arg, tensor_z_arg)

    return fns[tag]
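To make the expected rewrite concrete: after the flatten pass, a callee that took a tuple parameter takes the tuple's leaves as separate parameters, and call sites pass the elements directly, so the computed value is unchanged. Below is a tiny numpy-only sanity check of that equivalence; before and after are hand-written stand-ins for illustration, not IR produced by the pass.

import numpy as np

def before(tuple_x):                     # tuple parameter, unpacked inside the callee
    return tuple_x[0] + tuple_x[1]

def after(tuple_x_0, tuple_x_1):         # flattened tensor parameters
    return tuple_x_0 + tuple_x_1

x = np.random.randn(2, 2)
w = np.random.randn(2, 2)
assert np.allclose(before((x ** 3, w ** 3)), after(x ** 3, w ** 3))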