!26914 fix some error report and remove some useless code

Merge pull request !26914 from lianliguang/master
This commit is contained in:
i-robot 2021-11-29 12:34:41 +00:00 committed by Gitee
commit b9d56b9cce
6 changed files with 6 additions and 191 deletions

View File

@ -165,7 +165,7 @@ def check_number(arg_value, value, rel, arg_type=int, arg_name=None, prim_name=N
type_except = TypeError if type_mismatch else ValueError
if type_mismatch or not rel_fn(arg_value, value):
rel_str = Rel.get_strs(rel).format(value)
raise type_except(f'{arg_name} {prim_name} should be an {arg_type.__name__} and must {rel_str}, '
raise type_except(f'{arg_name} {prim_name} should be {arg_type.__name__} and must {rel_str}, '
f'but got `{arg_value}` with type `{type(arg_value).__name__}`.')
return arg_value
@ -589,8 +589,8 @@ class Validator:
num_types = len(valid_types)
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {"one of " if num_types > 1 else ""}'
f'{type_names if num_types > 1 else type_names[0]}, '
f'but got {arg_value} with type {type(arg_value).__name__}.')
f'\'{type_names if num_types > 1 else type_names[0]}\', '
f'but got \'{arg_value}\' with type \'{type(arg_value).__name__}\'.')
# Notice: bool is subclass of int, so `check_value_type('x', True, [int])` will check fail, and
# `check_value_type('x', True, [bool, int])` will check pass

View File

@ -547,153 +547,5 @@ bool CleanAfterOptA(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager
}
return changed;
}
// expand tuples in graph parameters
static std::vector<AnfNodePtr> ExpandTuplesP(const FuncGraphManagerPtr &mng, const FuncGraphPtr &func_graph,
const std::vector<AnfNodePtr> &params) {
MS_EXCEPTION_IF_NULL(mng);
MS_EXCEPTION_IF_NULL(func_graph);
std::vector<AnfNodePtr> new_params;
for (const auto &param : params) {
MS_EXCEPTION_IF_NULL(param);
auto param_abs = param->abstract();
MS_EXCEPTION_IF_NULL(param_abs);
if (param_abs->isa<AbstractJTagged>()) {
MS_LOG(EXCEPTION) << "Not Implemented Error NodeInfo: " << trace::GetDebugInfo(param->debug_info());
}
if (!param_abs->isa<AbstractTuple>()) {
new_params.emplace_back(param);
continue;
}
std::vector<AnfNodePtr> new_param;
std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple)};
auto abs_tuple = dyn_cast<AbstractTuple>(param_abs);
for (auto &elem : abs_tuple->elements()) {
auto np = std::make_shared<Parameter>(func_graph);
np->set_abstract(elem);
new_param.emplace_back(np);
}
(void)inputs.insert(inputs.end(), new_param.begin(), new_param.end());
auto new_tuple = func_graph->NewCNode(inputs);
(void)mng->Replace(param, new_tuple);
auto expand_param = ExpandTuplesP(mng, func_graph, new_param);
(void)new_params.insert(new_params.end(), expand_param.begin(), expand_param.end());
}
return new_params;
}
// expand tuples in graph applies
static std::vector<AnfNodePtr> ExpandTuplesC(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &inputs) {
MS_EXCEPTION_IF_NULL(graph);
std::vector<AnfNodePtr> new_inputs;
for (const auto &input : inputs) {
MS_EXCEPTION_IF_NULL(input);
auto input_abs = input->abstract();
MS_EXCEPTION_IF_NULL(input_abs);
if (input_abs->isa<AbstractJTagged>()) {
auto abstract_tag = dyn_cast<AbstractJTagged>(input_abs);
if (abstract_tag->element()->isa<AbstractTuple>()) {
MS_LOG(EXCEPTION) << "Not Implemented Error JTagged NodeInfo: " << trace::GetDebugInfo(input->debug_info());
}
}
if (!input_abs->isa<AbstractTuple>()) {
new_inputs.emplace_back(input);
continue;
}
int64_t idx = 0;
std::vector<AnfNodePtr> new_input;
auto abs_tuple = dyn_cast<AbstractTuple>(input_abs);
for (auto &elem : abs_tuple->elements()) {
auto c_node = graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), input, NewValueNode(idx)});
AbstractBasePtr aptr = std::make_shared<AbstractScalar>(std::make_shared<Int64Imm>(idx));
constexpr size_t scalar_index = 2;
c_node->input(scalar_index)->set_abstract(aptr);
c_node->set_abstract(elem);
new_input.emplace_back(c_node);
idx++;
}
auto expand_tuple = ExpandTuplesC(graph, new_input);
(void)new_inputs.insert(new_inputs.end(), expand_tuple.begin(), expand_tuple.end());
}
return new_inputs;
}
// remove most uses of tuples from the graph parameters & apply inputs
// tuples that are returned will be kept
// tuples in CNode's inputs: AbstractTuple (a, b ,c) -->
// CNode("tuple_getitem", (a,b,c), 0)
// CNode("tuple_getitem", (a,b,c), 1)
// CNode("tuple_getitem", (a,b,c), 2)
// tuples in Graph's parameters: AbstractTuple (a, b, c) -->
// CNode("make_tuple", Parameter(a), Parameter(b), Parameter(c))
// cppcheck-suppress unusedFunction
void EraseTuple(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
MS_EXCEPTION_IF_NULL(manager);
manager->AddFuncGraph(root);
// NOTICE: since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var
AnfNodeSet all_node = manager->all_nodes();
for (auto &node : all_node) {
auto cnode = node->cast<CNodePtr>();
if (cnode == nullptr) {
continue;
}
const auto &inputs = cnode->inputs();
// Bypass the first input in inputs as it's fn.
if (!IsValueNode<Primitive>(inputs[0])) {
std::vector<AnfNodePtr> expand_inputs;
(void)expand_inputs.insert(expand_inputs.end(), inputs.begin() + 1, inputs.end());
auto new_inputs = ExpandTuplesC(cnode->func_graph(), expand_inputs);
if (new_inputs != expand_inputs) {
std::vector<AnfNodePtr> cnode_inputs{inputs[0]};
(void)cnode_inputs.insert(cnode_inputs.end(), new_inputs.begin(), new_inputs.end());
MS_EXCEPTION_IF_NULL(node->func_graph());
auto new_node = node->func_graph()->NewCNode(cnode_inputs);
new_node->set_abstract(node->abstract());
(void)manager->Replace(node, new_node);
}
// Bypass the first 2 inputs in inputs as it's [partial, fn].
} else if (cnode->IsApply(prim::kPrimPartial) && !IsValueNode<Primitive>(inputs[1])) {
std::vector<AnfNodePtr> expand_inputs;
(void)expand_inputs.insert(expand_inputs.end(), inputs.begin() + 2, inputs.end());
auto new_inputs = ExpandTuplesC(cnode->func_graph(), expand_inputs);
if (new_inputs != expand_inputs) {
std::vector<AnfNodePtr> cnode_inputs{inputs[0], inputs[1]};
(void)cnode_inputs.insert(cnode_inputs.end(), new_inputs.begin(), new_inputs.end());
MS_EXCEPTION_IF_NULL(cnode->func_graph());
auto new_node = cnode->func_graph()->NewCNode(cnode_inputs);
new_node->set_abstract(cnode->abstract());
(void)manager->Replace(node, new_node);
}
}
}
FuncGraphSet all_graph = manager->func_graphs();
for (auto &func_graph : all_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
auto expand_p = ExpandTuplesP(manager, func_graph, func_graph->parameters());
manager->SetParameters(func_graph, expand_p);
}
}
} // namespace opt
} // namespace mindspore

View File

@ -29,15 +29,9 @@
namespace mindspore {
/* namespace to support opt */
namespace opt {
// Remove the class type from graphs
bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager);
bool CleanAfterOptA(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager);
// Remove most uses of tuples from the graph
// tuples that are returned will be kept
void EraseTuple(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager);
} // namespace opt
} // namespace mindspore

View File

@ -34,6 +34,7 @@
#include "frontend/parallel/context.h"
#include "pipeline/jit/parse/resolve.h"
#include "frontend/parallel/step_parallel.h"
#include "utils/tensor_construct_utils.h"
namespace mindspore {
namespace opt {
@ -477,7 +478,7 @@ class PynativeEliminater : public OptimizerCaller {
if (value->isa<tensor::Tensor>()) {
MS_LOG(DEBUG) << "Start FillZero Tensor";
auto tensor = value->cast<tensor::TensorPtr>();
tensor::TensorPtr out_t = std::make_shared<tensor::Tensor>(tensor->Dtype()->type_id(), tensor->shape());
auto out_t = TensorConstructUtils::CreateZerosTensor(tensor->Dtype(), tensor->shape());
char *data = reinterpret_cast<char *>(out_t->data_c());
std::fill(data, data + out_t->data().nbytes(), 0);
out = out_t;

View File

@ -151,11 +151,8 @@ AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr
return BuildNewParameter(pattern, res, top_graph);
} else if (pattern->isa<Imm>()) {
return BuildImmNode(pattern);
} else {
MS_LOG(EXCEPTION) << "Cannot find or build target node, pattern: " + pattern->unique_name() + "\n";
return nullptr;
}
return nullptr;
MS_LOG(EXCEPTION) << "Cannot find or build target node, pattern: " + pattern->unique_name() + "\n";
}
AnfNodePtr ProcessComplexPatternFirstInput(const PatternPtr &pattern, const MatchResultPtr &res,

View File

@ -199,34 +199,5 @@ TEST_F(TestClean, TestEraseClassPartial) {
auto manager = Manage(func_graph);
SimplifyDataStructures(func_graph, manager);
}
TEST_F(TestClean, TestEraseTuple) {
ASSERT_TRUE(nullptr != me_graph);
std::shared_ptr<FuncGraphManager> manager = Manage(me_graph);
int abstract_tuple_count = 0;
for (auto node : manager->all_nodes()) {
auto dt = node->abstract();
if (dyn_cast<AbstractTuple>(dt) != nullptr) {
abstract_tuple_count++;
}
}
ASSERT_EQ(abstract_tuple_count, 4);
// erase tuple in CNode57 and Parameter
EraseTuple(me_graph, manager);
abstract_tuple_count = 0;
for (auto node : manager->all_nodes()) {
auto dt = node->abstract();
if (dyn_cast<AbstractTuple>(dt) != nullptr) {
abstract_tuple_count++;
}
}
ASSERT_EQ(abstract_tuple_count, 3);
}
} // namespace opt
} // namespace mindspore