From: @Margaret_wangrui
Reviewed-by: @zh_qh,@ginfung
Signed-off-by: @zh_qh
This commit is contained in:
mindspore-ci-bot 2021-05-26 15:20:10 +08:00 committed by Gitee
commit e4c065b53f
10 changed files with 30 additions and 19 deletions

View File

@ -812,7 +812,8 @@ std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchInputs(const CNodePtr &cno
auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input); auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input);
auto switch_cnode = cnode_input->cast<CNodePtr>(); auto switch_cnode = cnode_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(switch_cnode); MS_EXCEPTION_IF_NULL(switch_cnode);
if (cnode->inputs().size() < 2) { constexpr size_t cnode_size = 2;
if (cnode->inputs().size() < cnode_size) {
cnode_inputs = switch_cnode->inputs(); cnode_inputs = switch_cnode->inputs();
return cnode_inputs; return cnode_inputs;
} }

View File

@ -951,7 +951,8 @@ FuncGraphPtr TupleSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_
// slice a tuple // slice a tuple
// args: tuple, start index, end index, step // args: tuple, start index, end index, step
const std::string op_name("TupleSlice"); const std::string op_name("TupleSlice");
abstract::CheckArgsSize(op_name, args_spec_list, 2); constexpr size_t arg_size = 2;
abstract::CheckArgsSize(op_name, args_spec_list, arg_size);
AbstractTuplePtr tuple = abstract::CheckArg<AbstractTuple>(op_name, args_spec_list, 0); AbstractTuplePtr tuple = abstract::CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
AbstractSlicePtr slice = abstract::CheckArg<AbstractSlice>(op_name, args_spec_list, 1); AbstractSlicePtr slice = abstract::CheckArg<AbstractSlice>(op_name, args_spec_list, 1);

View File

@ -342,7 +342,8 @@ AbstractBasePtr InferImplReduceShape(const AnalysisEnginePtr &, const PrimitiveP
const AbstractBasePtrList &args_spec_list) { const AbstractBasePtrList &args_spec_list) {
// Inputs: x_shape, axis // Inputs: x_shape, axis
const std::string op_name = primitive->name(); const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 2); constexpr size_t arg_size = 2;
CheckArgsSize(op_name, args_spec_list, arg_size);
AbstractTuplePtr shape_x = CheckArg<AbstractTuple>(op_name, args_spec_list, 0); AbstractTuplePtr shape_x = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
MS_EXCEPTION_IF_NULL(args_spec_list[1]); MS_EXCEPTION_IF_NULL(args_spec_list[1]);
@ -381,7 +382,8 @@ AbstractBasePtr InferImplTupleDiv(const AnalysisEnginePtr &, const PrimitivePtr
const AbstractBasePtrList &args_spec_list) { const AbstractBasePtrList &args_spec_list) {
// Inputs: two tuples. // Inputs: two tuples.
const std::string op_name = primitive->name(); const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 2); constexpr size_t arg_size = 2;
CheckArgsSize(op_name, args_spec_list, arg_size);
AbstractTuplePtr shape_x = CheckArg<AbstractTuple>(op_name, args_spec_list, 0); AbstractTuplePtr shape_x = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
AbstractTuplePtr div_shp = CheckArg<AbstractTuple>(op_name, args_spec_list, 1); AbstractTuplePtr div_shp = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
MS_LOG(INFO) << "DivShape input:" << shape_x->ToString() << ", div:" << div_shp->ToString(); MS_LOG(INFO) << "DivShape input:" << shape_x->ToString() << ", div:" << div_shp->ToString();

View File

@ -98,9 +98,10 @@ AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr &node) {
const auto &inputs = node->inputs(); const auto &inputs = node->inputs();
// Inputs should be [getattr, data, attribute] // Inputs should be [getattr, data, attribute]
MS_ASSERT(inputs.size() == 3 && "GetAttr should have three inputs."); MS_ASSERT(inputs.size() == 3 && "GetAttr should have three inputs.");
constexpr size_t data_index = 1;
AnfNodePtr data = inputs[1]; constexpr size_t attribute_index = 2;
AnfNodePtr cons = inputs[2]; AnfNodePtr data = inputs[data_index];
AnfNodePtr cons = inputs[attribute_index];
MS_EXCEPTION_IF_NULL(data); MS_EXCEPTION_IF_NULL(data);
MS_EXCEPTION_IF_NULL(cons); MS_EXCEPTION_IF_NULL(cons);
@ -140,9 +141,10 @@ AnfNodePtr ConvertDictGetItemToTupleGetItem(const CNodePtr &node) {
// Inputs should be [dict_getitem, dict, item] // Inputs should be [dict_getitem, dict, item]
const auto &inputs = node->inputs(); const auto &inputs = node->inputs();
MS_ASSERT(inputs.size() == 3 && "DictGetItem should have three inputs."); MS_ASSERT(inputs.size() == 3 && "DictGetItem should have three inputs.");
constexpr size_t data_index = 1;
AnfNodePtr data = inputs[1]; constexpr size_t cons_index = 2;
AnfNodePtr cons = inputs[2]; AnfNodePtr data = inputs[data_index];
AnfNodePtr cons = inputs[cons_index];
MS_EXCEPTION_IF_NULL(data); MS_EXCEPTION_IF_NULL(data);
MS_EXCEPTION_IF_NULL(cons); MS_EXCEPTION_IF_NULL(cons);
@ -334,7 +336,8 @@ AnfNodePtr EraseExtractKeywordArg(const CNodePtr &node) {
const auto &inputs = node->inputs(); const auto &inputs = node->inputs();
// Inputs should be [extract_keyword_arg, arg, key] // Inputs should be [extract_keyword_arg, arg, key]
MS_ASSERT(inputs.size() == 3 && "ExtractKeyword should have three inputs"); MS_ASSERT(inputs.size() == 3 && "ExtractKeyword should have three inputs");
return inputs[2]; constexpr size_t key_index = 2;
return inputs[key_index];
} }
ValueTuplePtr ConvertValueListToValueTuple(const ValueListPtr &value_list, int64_t depth) { ValueTuplePtr ConvertValueListToValueTuple(const ValueListPtr &value_list, int64_t depth) {
@ -575,7 +578,8 @@ static std::vector<AnfNodePtr> ExpandTuplesC(const FuncGraphPtr &graph, const st
for (auto &elem : abs_tuple->elements()) { for (auto &elem : abs_tuple->elements()) {
auto c_node = graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), input, NewValueNode(idx)}); auto c_node = graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), input, NewValueNode(idx)});
AbstractBasePtr aptr = std::make_shared<AbstractScalar>(std::make_shared<Int64Imm>(idx)); AbstractBasePtr aptr = std::make_shared<AbstractScalar>(std::make_shared<Int64Imm>(idx));
c_node->input(2)->set_abstract(aptr); constexpr size_t scalar_index = 2;
c_node->input(scalar_index)->set_abstract(aptr);
c_node->set_abstract(elem); c_node->set_abstract(elem);
new_input.emplace_back(c_node); new_input.emplace_back(c_node);
idx++; idx++;

View File

@ -99,7 +99,8 @@ void TwoCastEliminater::Visit(const AnfNodePtr &node) {
if (IsPrimitiveCNode(node, prim::kPrimCast)) { if (IsPrimitiveCNode(node, prim::kPrimCast)) {
auto cnode = node->cast<CNodePtr>(); auto cnode = node->cast<CNodePtr>();
// {prim::kPrimCast, X, Y} // {prim::kPrimCast, X, Y}
if (cnode->size() != 3) { constexpr size_t cast_size = 3;
if (cnode->size() != cast_size) {
return; return;
} }
x_ = cnode->input(1); x_ = cnode->input(1);

View File

@ -261,7 +261,6 @@ AnfNodePtr FunctionBlock::SearchReplaceNode(const std::string &var, const Parame
// 1. when this function is called, not all usage of this phi node had bound to the // 1. when this function is called, not all usage of this phi node had bound to the
// graph of this function block, some may stay in vars_ in other blocks. // graph of this function block, some may stay in vars_ in other blocks.
// 2. it's costly to iterate the graph to replace the phi for each phi. // 2. it's costly to iterate the graph to replace the phi for each phi.
// Args :
// phi : This parameter node is functioning as a phi node. // phi : This parameter node is functioning as a phi node.
bool FunctionBlock::CollectRemovablePhi(const ParameterPtr &phi) { bool FunctionBlock::CollectRemovablePhi(const ParameterPtr &phi) {
MS_EXCEPTION_IF_NULL(phi); MS_EXCEPTION_IF_NULL(phi);

View File

@ -411,7 +411,8 @@ FunctionBlockPtr Parser::ParseExpr(const FunctionBlockPtr &block, const py::obje
LocationPtr Parser::GetLocation(const py::object &node) const { LocationPtr Parser::GetLocation(const py::object &node) const {
MS_EXCEPTION_IF_NULL(ast_); MS_EXCEPTION_IF_NULL(ast_);
py::list ret = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node); py::list ret = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node);
if (ret.size() < 5) { constexpr size_t list_size = 5;
if (ret.size() < list_size) {
MS_LOG(EXCEPTION) << "List size should not be less than 5."; MS_LOG(EXCEPTION) << "List size should not be less than 5.";
} }
// Refer to Location::Location() for each member of ret: line, column, line_end, column_end. // Refer to Location::Location() for each member of ret: line, column, line_end, column_end.

View File

@ -189,7 +189,8 @@ py::tuple GenerateKey(const std::string &name, const std::unordered_map<std::str
MS_LOG(INFO) << "Start new args and compile key:" << key; MS_LOG(INFO) << "Start new args and compile key:" << key;
g_args_cache[args_spec] = key++; g_args_cache[args_spec] = key++;
} }
auto argSpec = py::tuple(2); constexpr size_t arg_size = 2;
auto argSpec = py::tuple(arg_size);
argSpec[0] = name; argSpec[0] = name;
argSpec[1] = g_args_cache[args_spec]; argSpec[1] = g_args_cache[args_spec];
return argSpec; return argSpec;

View File

@ -260,7 +260,9 @@ EvalResultPtr MixedPrecisionCastEvaluator::Run(AnalysisEnginePtr engine, const C
ScopeGuard scope_guard(scope); ScopeGuard scope_guard(scope);
FuncGraphPtr func_graph = out_conf->node()->func_graph(); FuncGraphPtr func_graph = out_conf->node()->func_graph();
AnfNodePtr new_node = MixedPrecisionCastHelper(out_node_inputs[2], args_spec_list[1], out_node_inputs[1], func_graph); constexpr size_t source_node_index = 2;
AnfNodePtr new_node =
MixedPrecisionCastHelper(out_node_inputs[source_node_index], args_spec_list[1], out_node_inputs[1], func_graph);
AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_node, out_conf->context()); AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_node, out_conf->context());
if (new_node->isa<CNode>()) { if (new_node->isa<CNode>()) {

View File

@ -402,8 +402,7 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AnfNodePtr &nod
return BuildValueNode(real_func->prim(), abs); return BuildValueNode(real_func->prim(), abs);
} }
EvaluatorPtr eval; EvaluatorPtr eval = engine_->GetEvaluatorFor(func);
eval = engine_->GetEvaluatorFor(func);
MS_EXCEPTION_IF_NULL(eval); MS_EXCEPTION_IF_NULL(eval);
AbstractBasePtrList argvals = eval->NormalizeArgs(args); AbstractBasePtrList argvals = eval->NormalizeArgs(args);