Clean code

This commit is contained in:
yujianfeng 2022-03-19 12:17:37 +08:00
parent 164896b7b8
commit e9b54f2466
10 changed files with 75 additions and 51 deletions

View File

@ -26,6 +26,7 @@
"mindspore/mindspore/ccsrc/plugin/device/ascend/hal/device/ascend_memory_manager.cc" "nullPointerArithmeticRedundantCheck"
"mindspore/mindspore/ccsrc/pipeline/jit/static_analysis/auto_monad.cc" "containerOutOfBounds"
"mindspore/mindspore/core/load_mindir/anf_model_parser.cc" "stlIfStrFind"
"mindspore/mindspore/ccsrc/frontend/optimizer/ad/kpynative.cc" "containerOutOfBounds"
# MindData
"mindspore/mindspore/ccsrc/minddata/dataset/engine/dataset_iterator.cc" "useStlAlgorithm"

View File

@ -110,21 +110,31 @@ CNodePtr VmapMatchOutAxis::GenerateFuncGraphInnerSingleElement(
return fg_->NewCNode(out_cnode_inputs);
}
namespace {
AbstractBasePtrList GetOutAxesAbstractElements(const AbstractBasePtr &out_axes_abstract,
size_t inputs_abstract_elements_size, bool is_out_axes_tuple) {
AbstractBasePtrList out_axes_abstract_elements;
if (!is_out_axes_tuple) {
return out_axes_abstract_elements;
}
abstract::AbstractTuplePtr out_axes_abstract_tuple = dyn_cast<abstract::AbstractTuple>(out_axes_abstract);
out_axes_abstract_elements = out_axes_abstract_tuple->elements();
if (out_axes_abstract_elements.size() != inputs_abstract_elements_size) {
MS_LOG(EXCEPTION) << "The length of out_axes and inputs do not match. ";
}
return out_axes_abstract_elements;
}
} // namespace
CNodePtr VmapMatchOutAxis::GenerateFuncGraphInnerAllTuple(const AnfNodePtr &inputs, const AnfNodePtr &out_axis,
const AnfNodePtr &axis_size,
const AbstractBasePtrList &inputs_abstract_elements,
const AbstractBasePtr &out_axes_abstract) const {
bool is_out_axes_tuple = out_axes_abstract->isa<abstract::AbstractTuple>();
abstract::AbstractTuplePtr out_axes_abstract_tuple = nullptr;
AbstractBasePtrList out_axes_abstract_elements;
auto inputs_abstract_elements_size = inputs_abstract_elements.size();
if (is_out_axes_tuple) {
out_axes_abstract_tuple = dyn_cast<abstract::AbstractTuple>(out_axes_abstract);
out_axes_abstract_elements = out_axes_abstract_tuple->elements();
if (out_axes_abstract_elements.size() != inputs_abstract_elements_size) {
MS_LOG(EXCEPTION) << "The length of out_axes and inputs do not match. ";
}
}
AbstractBasePtrList out_axes_abstract_elements =
GetOutAxesAbstractElements(out_axes_abstract, inputs_abstract_elements_size, is_out_axes_tuple);
std::vector<AnfNodePtr> vals_out_tuple_cnode_inputs;
(void)vals_out_tuple_cnode_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
constexpr size_t kEachInputsSize = 2;
@ -206,9 +216,9 @@ FuncGraphPtr VmapMatchOutAxis::GenerateFuncGraph(const AbstractBasePtrList &args
if (args_spec_list_size != kMetaFGInputSize) {
MS_LOG(EXCEPTION) << "The number of inputs to VmapMatchOutAxis should be 3, but got " << args_spec_list_size << ".";
}
auto inputs_abstract = args_spec_list[0];
auto out_axes_abstract = args_spec_list[1];
auto axis_size_abstract = args_spec_list[2];
auto inputs_abstract = args_spec_list[kIndex0];
auto out_axes_abstract = args_spec_list[kIndex1];
auto axis_size_abstract = args_spec_list[kIndex2];
if (!inputs_abstract->isa<abstract::AbstractTuple>()) {
MS_LOG(EXCEPTION) << "The first input to VmapMatchOutAxis is vmap_inputs and should be a tuple but got "
<< inputs_abstract->ToString() << ".";

View File

@ -317,6 +317,8 @@ class KPynativeCellImpl : public KPynativeCell {
void ReplacePrimalParameter(const AnfNodePtrList &weights, bool has_sens_arg);
// Set sens and weights parameter nodes by user input info
void SetSensAndWeights(const AnfNodePtrList &weights, bool has_sens_arg);
AbstractBasePtr GetGradInputsSpec(const std::vector<size_t> &grad_position, bool grad_inputs,
AnfNodePtrList *grad_inputs_list);
// Set return node according to grad flag
void SetOutput(const AnfNodePtrList &weights, const std::vector<size_t> &grad_position, bool grad_inputs,
bool grad_weights);
@ -992,40 +994,46 @@ void KPynativeCellImpl::SetSensAndWeights(const AnfNodePtrList &weights, bool ha
}
}
AbstractBasePtr KPynativeCellImpl::GetGradInputsSpec(const std::vector<size_t> &grad_position, bool grad_inputs,
AnfNodePtrList *grad_inputs_list) {
AbstractBasePtr grad_inputs_spec;
auto pos_size = grad_position.size();
if (!grad_inputs && pos_size <= 1) {
return grad_inputs_spec;
}
AbstractBasePtrList grad_inputs_abs_list;
std::vector<size_t> grad_list;
if (grad_inputs) {
grad_list.resize(cell_inputs_.size());
iota(grad_list.begin(), grad_list.end(), 0);
} else if (pos_size > 1) {
grad_list = grad_position;
}
for (size_t i = 0; i < grad_list.size(); ++i) {
if (grad_list[i] >= cell_inputs_.size()) {
MS_LOG(EXCEPTION) << "Position index " << grad_list[i] << " is exceed input size!";
}
auto input = cell_inputs_[grad_list[i]];
MS_EXCEPTION_IF_NULL(input);
auto input_adjoint_iter = anfnode_to_adjoin_.find(input);
if (input_adjoint_iter == anfnode_to_adjoin_.end()) {
// If input is not used in the network, just return zeros_like() as dout;
MS_LOG(WARNING) << "Input is not used in network, input: " << input->ToString();
auto dout = BuildZerosLikeNode(tape_, input);
grad_inputs_list->push_back(dout);
} else {
grad_inputs_list->push_back(input_adjoint_iter->second->RealDout());
}
grad_inputs_abs_list.push_back(grad_inputs_list->back()->abstract());
}
grad_inputs_spec = std::make_shared<abstract::AbstractTuple>(grad_inputs_abs_list);
return grad_inputs_spec;
}
void KPynativeCellImpl::SetOutput(const AnfNodePtrList &weights, const std::vector<size_t> &grad_position,
bool grad_inputs, bool grad_weights) {
AnfNodePtrList grad_inputs_list{NewValueNode(prim::kPrimMakeTuple)};
AbstractBasePtr grad_inputs_spec;
auto pos_size = grad_position.size();
if (grad_inputs || pos_size > 1) {
AbstractBasePtrList grad_inputs_abs_list;
std::vector<size_t> grad_list;
if (grad_inputs) {
grad_list.resize(cell_inputs_.size());
iota(grad_list.begin(), grad_list.end(), 0);
} else if (pos_size > 1) {
grad_list = grad_position;
}
for (size_t i = 0; i < grad_list.size(); ++i) {
if (grad_list[i] >= cell_inputs_.size()) {
MS_LOG(EXCEPTION) << "Position index " << grad_list[i] << " is exceed input size!";
}
auto input = cell_inputs_[grad_list[i]];
MS_EXCEPTION_IF_NULL(input);
auto input_adjoint_iter = anfnode_to_adjoin_.find(input);
if (input_adjoint_iter == anfnode_to_adjoin_.end()) {
// If input is not used in the network, just return zeros_like() as dout;
MS_LOG(WARNING) << "Input is not used in network, input: " << input->ToString();
auto dout = BuildZerosLikeNode(tape_, input);
grad_inputs_list.push_back(dout);
} else {
grad_inputs_list.push_back(input_adjoint_iter->second->RealDout());
}
grad_inputs_abs_list.push_back(grad_inputs_list.back()->abstract());
}
grad_inputs_spec = std::make_shared<abstract::AbstractTuple>(grad_inputs_abs_list);
}
AbstractBasePtr grad_inputs_spec = GetGradInputsSpec(grad_position, grad_inputs, &grad_inputs_list);
AnfNodePtrList grad_weights_list{NewValueNode(prim::kPrimMakeTuple)};
AbstractBasePtr grad_weights_spec;
if (grad_weights) {
@ -1056,7 +1064,7 @@ void KPynativeCellImpl::SetOutput(const AnfNodePtrList &weights, const std::vect
{NewValueNode(prim::kPrimMakeTuple), tape_->NewCNode(grad_inputs_list), tape_->NewCNode(grad_weights_list)});
tape_output->set_abstract(
std::make_shared<abstract::AbstractTuple>(abstract::AbstractBasePtrList{grad_inputs_spec, grad_weights_spec}));
} else if (grad_inputs || (pos_size > 1)) {
} else if (grad_inputs || (grad_position.size() > 1)) {
tape_output = tape_->NewCNode(grad_inputs_list);
tape_output->set_abstract(grad_inputs_spec);
} else if (grad_weights) {

View File

@ -161,7 +161,7 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
return SizeToLong(i);
}
}
return n_attrs;
return SizeToLong(n_attrs);
}
static CNodePtr NewTupleGetCNode(const AnfNodePtr &cnode, const AnfNodePtr &data_node,

View File

@ -114,6 +114,9 @@ void InsertSliceAllGatherNode(const std::vector<std::pair<std::shared_ptr<AnfNod
return;
}
auto group = groups[0];
if (group.GetDevNum() == 0) {
MS_LOG(EXCEPTION) << "The dev num of group should not be 0.";
}
if (out_shape_element[0] % group.GetDevNum() != 0) {
MS_LOG(WARNING) << "The output_shape first dim:" << out_shape_element[0]
<< " cannot be divisible by the repeated size: " << group.GetDevNum()

View File

@ -1443,7 +1443,7 @@ FunctionBlockPtr Parser::ParseIf(const FunctionBlockPtr &block, const py::object
MS_LOG(DEBUG) << "True_end or false_end will not call after_block, true_block: " << true_block->ToString()
<< ", true_end: " << true_end->ToString() << ", false_block: " << false_block->ToString()
<< ", false_end: " << false_end->ToString() << ", after_block: " << after_block->ToString();
ignored_latter_call_graphs_.insert(after_block);
(void)ignored_latter_call_graphs_.insert(after_block);
}
if (true_branch_graphs.second != nullptr && false_branch_graphs.second != nullptr) {
true_branch_graphs.first = block;

View File

@ -82,9 +82,11 @@ class LoopContext {
}
~LoopContext() {
try {
(void)loops_->pop();
loops_->pop();
} catch (const std::exception &e) {
MS_LOG(ERROR) << "Exception when pop. Error info " << e.what();
} catch (...) {
MS_LOG(ERROR) << "Throw exception when pop.";
}
loops_ = nullptr;
}

View File

@ -627,9 +627,9 @@ AbstractBasePtr ReduceDim(int *axis, const AbstractBasePtr &orig_abs, int *axis_
<< -shape_len << "," << shape_len << ").";
}
*axis = *axis < 0 ? shape_len + *axis : *axis;
auto temp_axes_size = orig_shape[*axis];
auto temp_axes_size = orig_shape[IntToSize(*axis)];
if (*axis_size == -1) {
*axis_size = temp_axes_size;
*axis_size = LongToInt(temp_axes_size);
} else if (*axis_size != temp_axes_size) {
MS_LOG(EXCEPTION) << "The `axes_size` of each argument in the scope of `vmap` should be equal, but got "
<< *axis_size << " and " << temp_axes_size << ".";

View File

@ -80,7 +80,7 @@ AbstractBasePtr InferImplEnvironCreate(const AnalysisEnginePtr &, const Primitiv
const AbstractBasePtrList &args_spec_list) {
// args: None.
CheckArgsSize(primitive->name(), args_spec_list, 0);
static AbstractBasePtr abs_env = std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
static const AbstractBasePtr abs_env = std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
return abs_env;
}

View File

@ -84,7 +84,7 @@ class OrderedSet {
OrderedSet &operator=(OrderedSet &&other) = default;
// insert an element to the OrderedSet after the given position.
std::pair<iterator, bool> insert(iterator pos, const element_type &e) {
std::pair<iterator, bool> insert(const iterator &pos, const element_type &e) {
auto result = map_.emplace(e, ordered_data_.end());
if (result.second) {
result.first->second = ordered_data_.emplace(pos, e);