!26914 fix some error report and remove some useless code
Merge pull request !26914 from lianliguang/master
This commit is contained in:
commit
b9d56b9cce
|
@ -165,7 +165,7 @@ def check_number(arg_value, value, rel, arg_type=int, arg_name=None, prim_name=N
|
|||
type_except = TypeError if type_mismatch else ValueError
|
||||
if type_mismatch or not rel_fn(arg_value, value):
|
||||
rel_str = Rel.get_strs(rel).format(value)
|
||||
raise type_except(f'{arg_name} {prim_name} should be an {arg_type.__name__} and must {rel_str}, '
|
||||
raise type_except(f'{arg_name} {prim_name} should be {arg_type.__name__} and must {rel_str}, '
|
||||
f'but got `{arg_value}` with type `{type(arg_value).__name__}`.')
|
||||
|
||||
return arg_value
|
||||
|
@ -589,8 +589,8 @@ class Validator:
|
|||
num_types = len(valid_types)
|
||||
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
|
||||
raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {"one of " if num_types > 1 else ""}'
|
||||
f'{type_names if num_types > 1 else type_names[0]}, '
|
||||
f'but got {arg_value} with type {type(arg_value).__name__}.')
|
||||
f'\'{type_names if num_types > 1 else type_names[0]}\', '
|
||||
f'but got \'{arg_value}\' with type \'{type(arg_value).__name__}\'.')
|
||||
|
||||
# Notice: bool is subclass of int, so `check_value_type('x', True, [int])` will check fail, and
|
||||
# `check_value_type('x', True, [bool, int])` will check pass
|
||||
|
|
|
@ -547,153 +547,5 @@ bool CleanAfterOptA(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager
|
|||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
// expand tuples in graph parameters
|
||||
static std::vector<AnfNodePtr> ExpandTuplesP(const FuncGraphManagerPtr &mng, const FuncGraphPtr &func_graph,
|
||||
const std::vector<AnfNodePtr> ¶ms) {
|
||||
MS_EXCEPTION_IF_NULL(mng);
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
|
||||
std::vector<AnfNodePtr> new_params;
|
||||
for (const auto ¶m : params) {
|
||||
MS_EXCEPTION_IF_NULL(param);
|
||||
auto param_abs = param->abstract();
|
||||
MS_EXCEPTION_IF_NULL(param_abs);
|
||||
|
||||
if (param_abs->isa<AbstractJTagged>()) {
|
||||
MS_LOG(EXCEPTION) << "Not Implemented Error NodeInfo: " << trace::GetDebugInfo(param->debug_info());
|
||||
}
|
||||
|
||||
if (!param_abs->isa<AbstractTuple>()) {
|
||||
new_params.emplace_back(param);
|
||||
continue;
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> new_param;
|
||||
std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple)};
|
||||
auto abs_tuple = dyn_cast<AbstractTuple>(param_abs);
|
||||
for (auto &elem : abs_tuple->elements()) {
|
||||
auto np = std::make_shared<Parameter>(func_graph);
|
||||
np->set_abstract(elem);
|
||||
new_param.emplace_back(np);
|
||||
}
|
||||
(void)inputs.insert(inputs.end(), new_param.begin(), new_param.end());
|
||||
auto new_tuple = func_graph->NewCNode(inputs);
|
||||
(void)mng->Replace(param, new_tuple);
|
||||
|
||||
auto expand_param = ExpandTuplesP(mng, func_graph, new_param);
|
||||
(void)new_params.insert(new_params.end(), expand_param.begin(), expand_param.end());
|
||||
}
|
||||
return new_params;
|
||||
}
|
||||
|
||||
// expand tuples in graph applies
|
||||
static std::vector<AnfNodePtr> ExpandTuplesC(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &inputs) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
|
||||
std::vector<AnfNodePtr> new_inputs;
|
||||
for (const auto &input : inputs) {
|
||||
MS_EXCEPTION_IF_NULL(input);
|
||||
|
||||
auto input_abs = input->abstract();
|
||||
MS_EXCEPTION_IF_NULL(input_abs);
|
||||
|
||||
if (input_abs->isa<AbstractJTagged>()) {
|
||||
auto abstract_tag = dyn_cast<AbstractJTagged>(input_abs);
|
||||
if (abstract_tag->element()->isa<AbstractTuple>()) {
|
||||
MS_LOG(EXCEPTION) << "Not Implemented Error JTagged NodeInfo: " << trace::GetDebugInfo(input->debug_info());
|
||||
}
|
||||
}
|
||||
|
||||
if (!input_abs->isa<AbstractTuple>()) {
|
||||
new_inputs.emplace_back(input);
|
||||
continue;
|
||||
}
|
||||
|
||||
int64_t idx = 0;
|
||||
std::vector<AnfNodePtr> new_input;
|
||||
auto abs_tuple = dyn_cast<AbstractTuple>(input_abs);
|
||||
for (auto &elem : abs_tuple->elements()) {
|
||||
auto c_node = graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), input, NewValueNode(idx)});
|
||||
AbstractBasePtr aptr = std::make_shared<AbstractScalar>(std::make_shared<Int64Imm>(idx));
|
||||
constexpr size_t scalar_index = 2;
|
||||
c_node->input(scalar_index)->set_abstract(aptr);
|
||||
c_node->set_abstract(elem);
|
||||
new_input.emplace_back(c_node);
|
||||
idx++;
|
||||
}
|
||||
|
||||
auto expand_tuple = ExpandTuplesC(graph, new_input);
|
||||
(void)new_inputs.insert(new_inputs.end(), expand_tuple.begin(), expand_tuple.end());
|
||||
}
|
||||
|
||||
return new_inputs;
|
||||
}
|
||||
|
||||
// remove most uses of tuples from the graph parameters & apply inputs
|
||||
// tuples that are returned will be kept
|
||||
// tuples in CNode's inputs: AbstractTuple (a, b ,c) -->
|
||||
// CNode("tuple_getitem", (a,b,c), 0)
|
||||
// CNode("tuple_getitem", (a,b,c), 1)
|
||||
// CNode("tuple_getitem", (a,b,c), 2)
|
||||
// tuples in Graph's parameters: AbstractTuple (a, b, c) -->
|
||||
// CNode("make_tuple", Parameter(a), Parameter(b), Parameter(c))
|
||||
// cppcheck-suppress unusedFunction
|
||||
void EraseTuple(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
manager->AddFuncGraph(root);
|
||||
|
||||
// NOTICE: since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var
|
||||
AnfNodeSet all_node = manager->all_nodes();
|
||||
for (auto &node : all_node) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (cnode == nullptr) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto &inputs = cnode->inputs();
|
||||
|
||||
// Bypass the first input in inputs as it's fn.
|
||||
if (!IsValueNode<Primitive>(inputs[0])) {
|
||||
std::vector<AnfNodePtr> expand_inputs;
|
||||
(void)expand_inputs.insert(expand_inputs.end(), inputs.begin() + 1, inputs.end());
|
||||
|
||||
auto new_inputs = ExpandTuplesC(cnode->func_graph(), expand_inputs);
|
||||
if (new_inputs != expand_inputs) {
|
||||
std::vector<AnfNodePtr> cnode_inputs{inputs[0]};
|
||||
(void)cnode_inputs.insert(cnode_inputs.end(), new_inputs.begin(), new_inputs.end());
|
||||
|
||||
MS_EXCEPTION_IF_NULL(node->func_graph());
|
||||
auto new_node = node->func_graph()->NewCNode(cnode_inputs);
|
||||
new_node->set_abstract(node->abstract());
|
||||
|
||||
(void)manager->Replace(node, new_node);
|
||||
}
|
||||
// Bypass the first 2 inputs in inputs as it's [partial, fn].
|
||||
} else if (cnode->IsApply(prim::kPrimPartial) && !IsValueNode<Primitive>(inputs[1])) {
|
||||
std::vector<AnfNodePtr> expand_inputs;
|
||||
(void)expand_inputs.insert(expand_inputs.end(), inputs.begin() + 2, inputs.end());
|
||||
|
||||
auto new_inputs = ExpandTuplesC(cnode->func_graph(), expand_inputs);
|
||||
if (new_inputs != expand_inputs) {
|
||||
std::vector<AnfNodePtr> cnode_inputs{inputs[0], inputs[1]};
|
||||
(void)cnode_inputs.insert(cnode_inputs.end(), new_inputs.begin(), new_inputs.end());
|
||||
|
||||
MS_EXCEPTION_IF_NULL(cnode->func_graph());
|
||||
auto new_node = cnode->func_graph()->NewCNode(cnode_inputs);
|
||||
new_node->set_abstract(cnode->abstract());
|
||||
|
||||
(void)manager->Replace(node, new_node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
FuncGraphSet all_graph = manager->func_graphs();
|
||||
for (auto &func_graph : all_graph) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
auto expand_p = ExpandTuplesP(manager, func_graph, func_graph->parameters());
|
||||
manager->SetParameters(func_graph, expand_p);
|
||||
}
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -29,15 +29,9 @@
|
|||
namespace mindspore {
|
||||
/* namespace to support opt */
|
||||
namespace opt {
|
||||
|
||||
// Remove the class type from graphs
|
||||
bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager);
|
||||
bool CleanAfterOptA(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager);
|
||||
|
||||
// Remove most uses of tuples from the graph
|
||||
// tuples that are returned will be kept
|
||||
void EraseTuple(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager);
|
||||
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -34,6 +34,7 @@
|
|||
#include "frontend/parallel/context.h"
|
||||
#include "pipeline/jit/parse/resolve.h"
|
||||
#include "frontend/parallel/step_parallel.h"
|
||||
#include "utils/tensor_construct_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
@ -477,7 +478,7 @@ class PynativeEliminater : public OptimizerCaller {
|
|||
if (value->isa<tensor::Tensor>()) {
|
||||
MS_LOG(DEBUG) << "Start FillZero Tensor";
|
||||
auto tensor = value->cast<tensor::TensorPtr>();
|
||||
tensor::TensorPtr out_t = std::make_shared<tensor::Tensor>(tensor->Dtype()->type_id(), tensor->shape());
|
||||
auto out_t = TensorConstructUtils::CreateZerosTensor(tensor->Dtype(), tensor->shape());
|
||||
char *data = reinterpret_cast<char *>(out_t->data_c());
|
||||
std::fill(data, data + out_t->data().nbytes(), 0);
|
||||
out = out_t;
|
||||
|
|
|
@ -151,11 +151,8 @@ AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr
|
|||
return BuildNewParameter(pattern, res, top_graph);
|
||||
} else if (pattern->isa<Imm>()) {
|
||||
return BuildImmNode(pattern);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Cannot find or build target node, pattern: " + pattern->unique_name() + "\n";
|
||||
return nullptr;
|
||||
}
|
||||
return nullptr;
|
||||
MS_LOG(EXCEPTION) << "Cannot find or build target node, pattern: " + pattern->unique_name() + "\n";
|
||||
}
|
||||
|
||||
AnfNodePtr ProcessComplexPatternFirstInput(const PatternPtr &pattern, const MatchResultPtr &res,
|
||||
|
|
|
@ -199,34 +199,5 @@ TEST_F(TestClean, TestEraseClassPartial) {
|
|||
auto manager = Manage(func_graph);
|
||||
SimplifyDataStructures(func_graph, manager);
|
||||
}
|
||||
|
||||
TEST_F(TestClean, TestEraseTuple) {
|
||||
ASSERT_TRUE(nullptr != me_graph);
|
||||
std::shared_ptr<FuncGraphManager> manager = Manage(me_graph);
|
||||
|
||||
int abstract_tuple_count = 0;
|
||||
|
||||
for (auto node : manager->all_nodes()) {
|
||||
auto dt = node->abstract();
|
||||
if (dyn_cast<AbstractTuple>(dt) != nullptr) {
|
||||
abstract_tuple_count++;
|
||||
}
|
||||
}
|
||||
ASSERT_EQ(abstract_tuple_count, 4);
|
||||
|
||||
// erase tuple in CNode57 and Parameter
|
||||
EraseTuple(me_graph, manager);
|
||||
|
||||
abstract_tuple_count = 0;
|
||||
for (auto node : manager->all_nodes()) {
|
||||
auto dt = node->abstract();
|
||||
if (dyn_cast<AbstractTuple>(dt) != nullptr) {
|
||||
abstract_tuple_count++;
|
||||
}
|
||||
}
|
||||
|
||||
ASSERT_EQ(abstract_tuple_count, 3);
|
||||
}
|
||||
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue