clean code

This commit is contained in:
liutongtong 2022-07-26 16:55:51 +08:00
parent 89e3a499b1
commit 6e6950f51e
3 changed files with 10 additions and 9 deletions

View File

@ -665,8 +665,8 @@ std::vector<size_t> TraceLoopToControlMap(const FuncGraphPtr &control_subgraph)
auto switch_node = FindLoopSwitchNode(control_subgraph);
auto loop_partial_node = GetNodeInput<CNode>(switch_node, kTwoNum);
const auto &control_params = control_subgraph->parameters();
size_t auxiliary_inputs_num = 2;
for (size_t i = auxiliary_inputs_num; i < loop_partial_node->inputs().size(); ++i) {
int64_t auxiliary_inputs_num = 2;
for (size_t i = static_cast<unsigned int>(auxiliary_inputs_num); i < loop_partial_node->inputs().size(); ++i) {
auto loop_param = GetNodeInput<Parameter>(loop_partial_node, i);
auto control_param_pos =
std::find(control_params.begin(), control_params.end(), loop_param) - control_params.begin();
@ -683,8 +683,8 @@ std::vector<size_t> TraceAfterToLoopMap(const FuncGraphPtr &control_subgraph) {
auto loop_partial_node = GetNodeInput<CNode>(switch_node, kTwoNum);
auto after_partial_node = GetNodeInput<CNode>(switch_node, kThreeNum);
const auto &loop_params = loop_partial_node->inputs();
size_t auxiliary_inputs_num = 2;
for (size_t i = auxiliary_inputs_num; i < after_partial_node->inputs().size(); ++i) {
int64_t auxiliary_inputs_num = 2;
for (size_t i = static_cast<unsigned int>(auxiliary_inputs_num); i < after_partial_node->inputs().size(); ++i) {
auto after_param = GetNodeInput<Parameter>(after_partial_node, i);
auto after_param_pos = std::find(loop_params.begin(), loop_params.end(), after_param) - loop_params.begin();
result.push_back(after_param_pos - auxiliary_inputs_num);
@ -1675,7 +1675,7 @@ void OnnxExporter::ExportPrimStridedSlice(const FuncGraphPtr &, const CNodePtr &
const auto &x_shape = dyn_cast<abstract::Shape>(node->input(kOneNum)->Shape())->shape();
auto end_ignore_mask = GetOpAttribute<int64_t>(node, "end_mask");
for (size_t i = 0; i < end_value.size(); ++i) {
if ((end_ignore_mask & (1 << i)) != 0) {
if (((unsigned int)end_ignore_mask & (1 << i)) != 0) {
end_value[i] = x_shape[i];
}
}

View File

@ -865,7 +865,7 @@ std::vector<Operator> DfGraphConvertor::GetWhileBodyOutputs() {
}
if (j->isa<Parameter>()) {
size_t idx = find(inputs_.begin(), inputs_.end(), j) - inputs_.begin();
int64_t idx = find(inputs_.begin(), inputs_.end(), j) - inputs_.begin();
auto idx_cond = body_cond_map_[idx];
if (while_used_input_index_.find(idx_cond) == while_used_input_index_.end() ||
while_const_input_index_.find(idx_cond) != while_const_input_index_.end()) {
@ -1140,7 +1140,7 @@ void DfGraphConvertor::SetParamIndexMap(const std::vector<AnfNodePtr> &graphs) {
for (size_t i = 0; i < body_params.size(); i++) {
auto p = body_params[i];
size_t idx = find(cond_params.begin(), cond_params.end(), p) - cond_params.begin();
int64_t idx = find(cond_params.begin(), cond_params.end(), p) - cond_params.begin();
body_cond_map_[i] = idx;
MS_LOG(DEBUG) << "body_cond_map_'s key: " << i << " value: " << idx;
}
@ -1157,7 +1157,7 @@ void DfGraphConvertor::SetParamIndexMap(const std::vector<AnfNodePtr> &graphs) {
for (size_t i = 0; i < after_params.size(); i++) {
auto p = after_params[i];
size_t idx = find(cond_params.begin(), cond_params.end(), p) - cond_params.begin();
int64_t idx = find(cond_params.begin(), cond_params.end(), p) - cond_params.begin();
after_cond_map_[i] = idx;
MS_LOG(DEBUG) << "after_cond_map_'s key: " << i << " value: " << idx;
}

View File

@ -47,6 +47,7 @@ from . import amp
from ..common.api import _pynative_executor
from ..dataset.engine.datasets import _set_training_dataset, _reset_training_dataset
def _transfer_tensor_to_tuple(inputs):
"""
If the input is a tensor, convert it to a tuple. If not, the output is unchanged.
@ -71,7 +72,7 @@ def _save_final_ckpt(func):
def wrapper(self, *args, **kwargs):
obj = None
if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'), ModelCheckpoint):
obj = kwargs['callbacks']
obj = kwargs.get('callbacks')
if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'), list):
for item in kwargs.get('callbacks'):
if isinstance(item, ModelCheckpoint):