forked from mindspore-Ecosystem/mindspore
Clean code
This commit is contained in:
parent
164896b7b8
commit
e9b54f2466
|
@ -26,6 +26,7 @@
|
|||
"mindspore/mindspore/ccsrc/plugin/device/ascend/hal/device/ascend_memory_manager.cc" "nullPointerArithmeticRedundantCheck"
|
||||
"mindspore/mindspore/ccsrc/pipeline/jit/static_analysis/auto_monad.cc" "containerOutOfBounds"
|
||||
"mindspore/mindspore/core/load_mindir/anf_model_parser.cc" "stlIfStrFind"
|
||||
"mindspore/mindspore/ccsrc/frontend/optimizer/ad/kpynative.cc" "containerOutOfBounds"
|
||||
|
||||
# MindData
|
||||
"mindspore/mindspore/ccsrc/minddata/dataset/engine/dataset_iterator.cc" "useStlAlgorithm"
|
||||
|
|
|
@ -110,21 +110,31 @@ CNodePtr VmapMatchOutAxis::GenerateFuncGraphInnerSingleElement(
|
|||
return fg_->NewCNode(out_cnode_inputs);
|
||||
}
|
||||
|
||||
namespace {
|
||||
AbstractBasePtrList GetOutAxesAbstractElements(const AbstractBasePtr &out_axes_abstract,
|
||||
size_t inputs_abstract_elements_size, bool is_out_axes_tuple) {
|
||||
AbstractBasePtrList out_axes_abstract_elements;
|
||||
if (!is_out_axes_tuple) {
|
||||
return out_axes_abstract_elements;
|
||||
}
|
||||
abstract::AbstractTuplePtr out_axes_abstract_tuple = dyn_cast<abstract::AbstractTuple>(out_axes_abstract);
|
||||
out_axes_abstract_elements = out_axes_abstract_tuple->elements();
|
||||
if (out_axes_abstract_elements.size() != inputs_abstract_elements_size) {
|
||||
MS_LOG(EXCEPTION) << "The length of out_axes and inputs do not match. ";
|
||||
}
|
||||
return out_axes_abstract_elements;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
CNodePtr VmapMatchOutAxis::GenerateFuncGraphInnerAllTuple(const AnfNodePtr &inputs, const AnfNodePtr &out_axis,
|
||||
const AnfNodePtr &axis_size,
|
||||
const AbstractBasePtrList &inputs_abstract_elements,
|
||||
const AbstractBasePtr &out_axes_abstract) const {
|
||||
bool is_out_axes_tuple = out_axes_abstract->isa<abstract::AbstractTuple>();
|
||||
abstract::AbstractTuplePtr out_axes_abstract_tuple = nullptr;
|
||||
AbstractBasePtrList out_axes_abstract_elements;
|
||||
auto inputs_abstract_elements_size = inputs_abstract_elements.size();
|
||||
if (is_out_axes_tuple) {
|
||||
out_axes_abstract_tuple = dyn_cast<abstract::AbstractTuple>(out_axes_abstract);
|
||||
out_axes_abstract_elements = out_axes_abstract_tuple->elements();
|
||||
if (out_axes_abstract_elements.size() != inputs_abstract_elements_size) {
|
||||
MS_LOG(EXCEPTION) << "The length of out_axes and inputs do not match. ";
|
||||
}
|
||||
}
|
||||
AbstractBasePtrList out_axes_abstract_elements =
|
||||
GetOutAxesAbstractElements(out_axes_abstract, inputs_abstract_elements_size, is_out_axes_tuple);
|
||||
|
||||
std::vector<AnfNodePtr> vals_out_tuple_cnode_inputs;
|
||||
(void)vals_out_tuple_cnode_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
|
||||
constexpr size_t kEachInputsSize = 2;
|
||||
|
@ -206,9 +216,9 @@ FuncGraphPtr VmapMatchOutAxis::GenerateFuncGraph(const AbstractBasePtrList &args
|
|||
if (args_spec_list_size != kMetaFGInputSize) {
|
||||
MS_LOG(EXCEPTION) << "The number of inputs to VmapMatchOutAxis should be 3, but got " << args_spec_list_size << ".";
|
||||
}
|
||||
auto inputs_abstract = args_spec_list[0];
|
||||
auto out_axes_abstract = args_spec_list[1];
|
||||
auto axis_size_abstract = args_spec_list[2];
|
||||
auto inputs_abstract = args_spec_list[kIndex0];
|
||||
auto out_axes_abstract = args_spec_list[kIndex1];
|
||||
auto axis_size_abstract = args_spec_list[kIndex2];
|
||||
if (!inputs_abstract->isa<abstract::AbstractTuple>()) {
|
||||
MS_LOG(EXCEPTION) << "The first input to VmapMatchOutAxis is vmap_inputs and should be a tuple but got "
|
||||
<< inputs_abstract->ToString() << ".";
|
||||
|
|
|
@ -317,6 +317,8 @@ class KPynativeCellImpl : public KPynativeCell {
|
|||
void ReplacePrimalParameter(const AnfNodePtrList &weights, bool has_sens_arg);
|
||||
// Set sens and weights parameter nodes by user input info
|
||||
void SetSensAndWeights(const AnfNodePtrList &weights, bool has_sens_arg);
|
||||
AbstractBasePtr GetGradInputsSpec(const std::vector<size_t> &grad_position, bool grad_inputs,
|
||||
AnfNodePtrList *grad_inputs_list);
|
||||
// Set return node according to grad flag
|
||||
void SetOutput(const AnfNodePtrList &weights, const std::vector<size_t> &grad_position, bool grad_inputs,
|
||||
bool grad_weights);
|
||||
|
@ -992,40 +994,46 @@ void KPynativeCellImpl::SetSensAndWeights(const AnfNodePtrList &weights, bool ha
|
|||
}
|
||||
}
|
||||
|
||||
AbstractBasePtr KPynativeCellImpl::GetGradInputsSpec(const std::vector<size_t> &grad_position, bool grad_inputs,
|
||||
AnfNodePtrList *grad_inputs_list) {
|
||||
AbstractBasePtr grad_inputs_spec;
|
||||
auto pos_size = grad_position.size();
|
||||
if (!grad_inputs && pos_size <= 1) {
|
||||
return grad_inputs_spec;
|
||||
}
|
||||
AbstractBasePtrList grad_inputs_abs_list;
|
||||
std::vector<size_t> grad_list;
|
||||
if (grad_inputs) {
|
||||
grad_list.resize(cell_inputs_.size());
|
||||
iota(grad_list.begin(), grad_list.end(), 0);
|
||||
} else if (pos_size > 1) {
|
||||
grad_list = grad_position;
|
||||
}
|
||||
for (size_t i = 0; i < grad_list.size(); ++i) {
|
||||
if (grad_list[i] >= cell_inputs_.size()) {
|
||||
MS_LOG(EXCEPTION) << "Position index " << grad_list[i] << " is exceed input size!";
|
||||
}
|
||||
auto input = cell_inputs_[grad_list[i]];
|
||||
MS_EXCEPTION_IF_NULL(input);
|
||||
auto input_adjoint_iter = anfnode_to_adjoin_.find(input);
|
||||
if (input_adjoint_iter == anfnode_to_adjoin_.end()) {
|
||||
// If input is not used in the network, just return zeros_like() as dout;
|
||||
MS_LOG(WARNING) << "Input is not used in network, input: " << input->ToString();
|
||||
auto dout = BuildZerosLikeNode(tape_, input);
|
||||
grad_inputs_list->push_back(dout);
|
||||
} else {
|
||||
grad_inputs_list->push_back(input_adjoint_iter->second->RealDout());
|
||||
}
|
||||
grad_inputs_abs_list.push_back(grad_inputs_list->back()->abstract());
|
||||
}
|
||||
grad_inputs_spec = std::make_shared<abstract::AbstractTuple>(grad_inputs_abs_list);
|
||||
return grad_inputs_spec;
|
||||
}
|
||||
|
||||
void KPynativeCellImpl::SetOutput(const AnfNodePtrList &weights, const std::vector<size_t> &grad_position,
|
||||
bool grad_inputs, bool grad_weights) {
|
||||
AnfNodePtrList grad_inputs_list{NewValueNode(prim::kPrimMakeTuple)};
|
||||
AbstractBasePtr grad_inputs_spec;
|
||||
auto pos_size = grad_position.size();
|
||||
if (grad_inputs || pos_size > 1) {
|
||||
AbstractBasePtrList grad_inputs_abs_list;
|
||||
std::vector<size_t> grad_list;
|
||||
if (grad_inputs) {
|
||||
grad_list.resize(cell_inputs_.size());
|
||||
iota(grad_list.begin(), grad_list.end(), 0);
|
||||
} else if (pos_size > 1) {
|
||||
grad_list = grad_position;
|
||||
}
|
||||
for (size_t i = 0; i < grad_list.size(); ++i) {
|
||||
if (grad_list[i] >= cell_inputs_.size()) {
|
||||
MS_LOG(EXCEPTION) << "Position index " << grad_list[i] << " is exceed input size!";
|
||||
}
|
||||
auto input = cell_inputs_[grad_list[i]];
|
||||
MS_EXCEPTION_IF_NULL(input);
|
||||
auto input_adjoint_iter = anfnode_to_adjoin_.find(input);
|
||||
if (input_adjoint_iter == anfnode_to_adjoin_.end()) {
|
||||
// If input is not used in the network, just return zeros_like() as dout;
|
||||
MS_LOG(WARNING) << "Input is not used in network, input: " << input->ToString();
|
||||
auto dout = BuildZerosLikeNode(tape_, input);
|
||||
grad_inputs_list.push_back(dout);
|
||||
} else {
|
||||
grad_inputs_list.push_back(input_adjoint_iter->second->RealDout());
|
||||
}
|
||||
grad_inputs_abs_list.push_back(grad_inputs_list.back()->abstract());
|
||||
}
|
||||
grad_inputs_spec = std::make_shared<abstract::AbstractTuple>(grad_inputs_abs_list);
|
||||
}
|
||||
|
||||
AbstractBasePtr grad_inputs_spec = GetGradInputsSpec(grad_position, grad_inputs, &grad_inputs_list);
|
||||
AnfNodePtrList grad_weights_list{NewValueNode(prim::kPrimMakeTuple)};
|
||||
AbstractBasePtr grad_weights_spec;
|
||||
if (grad_weights) {
|
||||
|
@ -1056,7 +1064,7 @@ void KPynativeCellImpl::SetOutput(const AnfNodePtrList &weights, const std::vect
|
|||
{NewValueNode(prim::kPrimMakeTuple), tape_->NewCNode(grad_inputs_list), tape_->NewCNode(grad_weights_list)});
|
||||
tape_output->set_abstract(
|
||||
std::make_shared<abstract::AbstractTuple>(abstract::AbstractBasePtrList{grad_inputs_spec, grad_weights_spec}));
|
||||
} else if (grad_inputs || (pos_size > 1)) {
|
||||
} else if (grad_inputs || (grad_position.size() > 1)) {
|
||||
tape_output = tape_->NewCNode(grad_inputs_list);
|
||||
tape_output->set_abstract(grad_inputs_spec);
|
||||
} else if (grad_weights) {
|
||||
|
|
|
@ -161,7 +161,7 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
|
|||
return SizeToLong(i);
|
||||
}
|
||||
}
|
||||
return n_attrs;
|
||||
return SizeToLong(n_attrs);
|
||||
}
|
||||
|
||||
static CNodePtr NewTupleGetCNode(const AnfNodePtr &cnode, const AnfNodePtr &data_node,
|
||||
|
|
|
@ -114,6 +114,9 @@ void InsertSliceAllGatherNode(const std::vector<std::pair<std::shared_ptr<AnfNod
|
|||
return;
|
||||
}
|
||||
auto group = groups[0];
|
||||
if (group.GetDevNum() == 0) {
|
||||
MS_LOG(EXCEPTION) << "The dev num of group should not be 0.";
|
||||
}
|
||||
if (out_shape_element[0] % group.GetDevNum() != 0) {
|
||||
MS_LOG(WARNING) << "The output_shape first dim:" << out_shape_element[0]
|
||||
<< " cannot be divisible by the repeated size: " << group.GetDevNum()
|
||||
|
|
|
@ -1443,7 +1443,7 @@ FunctionBlockPtr Parser::ParseIf(const FunctionBlockPtr &block, const py::object
|
|||
MS_LOG(DEBUG) << "True_end or false_end will not call after_block, true_block: " << true_block->ToString()
|
||||
<< ", true_end: " << true_end->ToString() << ", false_block: " << false_block->ToString()
|
||||
<< ", false_end: " << false_end->ToString() << ", after_block: " << after_block->ToString();
|
||||
ignored_latter_call_graphs_.insert(after_block);
|
||||
(void)ignored_latter_call_graphs_.insert(after_block);
|
||||
}
|
||||
if (true_branch_graphs.second != nullptr && false_branch_graphs.second != nullptr) {
|
||||
true_branch_graphs.first = block;
|
||||
|
|
|
@ -82,9 +82,11 @@ class LoopContext {
|
|||
}
|
||||
~LoopContext() {
|
||||
try {
|
||||
(void)loops_->pop();
|
||||
loops_->pop();
|
||||
} catch (const std::exception &e) {
|
||||
MS_LOG(ERROR) << "Exception when pop. Error info " << e.what();
|
||||
} catch (...) {
|
||||
MS_LOG(ERROR) << "Throw exception when pop.";
|
||||
}
|
||||
loops_ = nullptr;
|
||||
}
|
||||
|
|
|
@ -627,9 +627,9 @@ AbstractBasePtr ReduceDim(int *axis, const AbstractBasePtr &orig_abs, int *axis_
|
|||
<< -shape_len << "," << shape_len << ").";
|
||||
}
|
||||
*axis = *axis < 0 ? shape_len + *axis : *axis;
|
||||
auto temp_axes_size = orig_shape[*axis];
|
||||
auto temp_axes_size = orig_shape[IntToSize(*axis)];
|
||||
if (*axis_size == -1) {
|
||||
*axis_size = temp_axes_size;
|
||||
*axis_size = LongToInt(temp_axes_size);
|
||||
} else if (*axis_size != temp_axes_size) {
|
||||
MS_LOG(EXCEPTION) << "The `axes_size` of each argument in the scope of `vmap` should be equal, but got "
|
||||
<< *axis_size << " and " << temp_axes_size << ".";
|
||||
|
|
|
@ -80,7 +80,7 @@ AbstractBasePtr InferImplEnvironCreate(const AnalysisEnginePtr &, const Primitiv
|
|||
const AbstractBasePtrList &args_spec_list) {
|
||||
// args: None.
|
||||
CheckArgsSize(primitive->name(), args_spec_list, 0);
|
||||
static AbstractBasePtr abs_env = std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
|
||||
static const AbstractBasePtr abs_env = std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
|
||||
return abs_env;
|
||||
}
|
||||
|
||||
|
|
|
@ -84,7 +84,7 @@ class OrderedSet {
|
|||
OrderedSet &operator=(OrderedSet &&other) = default;
|
||||
|
||||
// insert an element to the OrderedSet after the given position.
|
||||
std::pair<iterator, bool> insert(iterator pos, const element_type &e) {
|
||||
std::pair<iterator, bool> insert(const iterator &pos, const element_type &e) {
|
||||
auto result = map_.emplace(e, ordered_data_.end());
|
||||
if (result.second) {
|
||||
result.first->second = ordered_data_.emplace(pos, e);
|
||||
|
|
Loading…
Reference in New Issue