clean code

This commit is contained in:
liutongtong 2022-07-26 16:55:51 +08:00
parent 89e3a499b1
commit 6e6950f51e
3 changed files with 10 additions and 9 deletions

View File

@ -665,8 +665,8 @@ std::vector<size_t> TraceLoopToControlMap(const FuncGraphPtr &control_subgraph)
auto switch_node = FindLoopSwitchNode(control_subgraph); auto switch_node = FindLoopSwitchNode(control_subgraph);
auto loop_partial_node = GetNodeInput<CNode>(switch_node, kTwoNum); auto loop_partial_node = GetNodeInput<CNode>(switch_node, kTwoNum);
const auto &control_params = control_subgraph->parameters(); const auto &control_params = control_subgraph->parameters();
size_t auxiliary_inputs_num = 2; int64_t auxiliary_inputs_num = 2;
for (size_t i = auxiliary_inputs_num; i < loop_partial_node->inputs().size(); ++i) { for (size_t i = static_cast<unsigned int>(auxiliary_inputs_num); i < loop_partial_node->inputs().size(); ++i) {
auto loop_param = GetNodeInput<Parameter>(loop_partial_node, i); auto loop_param = GetNodeInput<Parameter>(loop_partial_node, i);
auto control_param_pos = auto control_param_pos =
std::find(control_params.begin(), control_params.end(), loop_param) - control_params.begin(); std::find(control_params.begin(), control_params.end(), loop_param) - control_params.begin();
@ -683,8 +683,8 @@ std::vector<size_t> TraceAfterToLoopMap(const FuncGraphPtr &control_subgraph) {
auto loop_partial_node = GetNodeInput<CNode>(switch_node, kTwoNum); auto loop_partial_node = GetNodeInput<CNode>(switch_node, kTwoNum);
auto after_partial_node = GetNodeInput<CNode>(switch_node, kThreeNum); auto after_partial_node = GetNodeInput<CNode>(switch_node, kThreeNum);
const auto &loop_params = loop_partial_node->inputs(); const auto &loop_params = loop_partial_node->inputs();
size_t auxiliary_inputs_num = 2; int64_t auxiliary_inputs_num = 2;
for (size_t i = auxiliary_inputs_num; i < after_partial_node->inputs().size(); ++i) { for (size_t i = static_cast<unsigned int>(auxiliary_inputs_num); i < after_partial_node->inputs().size(); ++i) {
auto after_param = GetNodeInput<Parameter>(after_partial_node, i); auto after_param = GetNodeInput<Parameter>(after_partial_node, i);
auto after_param_pos = std::find(loop_params.begin(), loop_params.end(), after_param) - loop_params.begin(); auto after_param_pos = std::find(loop_params.begin(), loop_params.end(), after_param) - loop_params.begin();
result.push_back(after_param_pos - auxiliary_inputs_num); result.push_back(after_param_pos - auxiliary_inputs_num);
@ -1675,7 +1675,7 @@ void OnnxExporter::ExportPrimStridedSlice(const FuncGraphPtr &, const CNodePtr &
const auto &x_shape = dyn_cast<abstract::Shape>(node->input(kOneNum)->Shape())->shape(); const auto &x_shape = dyn_cast<abstract::Shape>(node->input(kOneNum)->Shape())->shape();
auto end_ignore_mask = GetOpAttribute<int64_t>(node, "end_mask"); auto end_ignore_mask = GetOpAttribute<int64_t>(node, "end_mask");
for (size_t i = 0; i < end_value.size(); ++i) { for (size_t i = 0; i < end_value.size(); ++i) {
if ((end_ignore_mask & (1 << i)) != 0) { if (((unsigned int)end_ignore_mask & (1 << i)) != 0) {
end_value[i] = x_shape[i]; end_value[i] = x_shape[i];
} }
} }

View File

@ -865,7 +865,7 @@ std::vector<Operator> DfGraphConvertor::GetWhileBodyOutputs() {
} }
if (j->isa<Parameter>()) { if (j->isa<Parameter>()) {
size_t idx = find(inputs_.begin(), inputs_.end(), j) - inputs_.begin(); int64_t idx = find(inputs_.begin(), inputs_.end(), j) - inputs_.begin();
auto idx_cond = body_cond_map_[idx]; auto idx_cond = body_cond_map_[idx];
if (while_used_input_index_.find(idx_cond) == while_used_input_index_.end() || if (while_used_input_index_.find(idx_cond) == while_used_input_index_.end() ||
while_const_input_index_.find(idx_cond) != while_const_input_index_.end()) { while_const_input_index_.find(idx_cond) != while_const_input_index_.end()) {
@ -1140,7 +1140,7 @@ void DfGraphConvertor::SetParamIndexMap(const std::vector<AnfNodePtr> &graphs) {
for (size_t i = 0; i < body_params.size(); i++) { for (size_t i = 0; i < body_params.size(); i++) {
auto p = body_params[i]; auto p = body_params[i];
size_t idx = find(cond_params.begin(), cond_params.end(), p) - cond_params.begin(); int64_t idx = find(cond_params.begin(), cond_params.end(), p) - cond_params.begin();
body_cond_map_[i] = idx; body_cond_map_[i] = idx;
MS_LOG(DEBUG) << "body_cond_map_'s key: " << i << " value: " << idx; MS_LOG(DEBUG) << "body_cond_map_'s key: " << i << " value: " << idx;
} }
@ -1157,7 +1157,7 @@ void DfGraphConvertor::SetParamIndexMap(const std::vector<AnfNodePtr> &graphs) {
for (size_t i = 0; i < after_params.size(); i++) { for (size_t i = 0; i < after_params.size(); i++) {
auto p = after_params[i]; auto p = after_params[i];
size_t idx = find(cond_params.begin(), cond_params.end(), p) - cond_params.begin(); int64_t idx = find(cond_params.begin(), cond_params.end(), p) - cond_params.begin();
after_cond_map_[i] = idx; after_cond_map_[i] = idx;
MS_LOG(DEBUG) << "after_cond_map_'s key: " << i << " value: " << idx; MS_LOG(DEBUG) << "after_cond_map_'s key: " << i << " value: " << idx;
} }

View File

@ -47,6 +47,7 @@ from . import amp
from ..common.api import _pynative_executor from ..common.api import _pynative_executor
from ..dataset.engine.datasets import _set_training_dataset, _reset_training_dataset from ..dataset.engine.datasets import _set_training_dataset, _reset_training_dataset
def _transfer_tensor_to_tuple(inputs): def _transfer_tensor_to_tuple(inputs):
""" """
If the input is a tensor, convert it to a tuple. If not, the output is unchanged. If the input is a tensor, convert it to a tuple. If not, the output is unchanged.
@ -71,7 +72,7 @@ def _save_final_ckpt(func):
def wrapper(self, *args, **kwargs): def wrapper(self, *args, **kwargs):
obj = None obj = None
if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'), ModelCheckpoint): if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'), ModelCheckpoint):
obj = kwargs['callbacks'] obj = kwargs.get('callbacks')
if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'), list): if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'), list):
for item in kwargs.get('callbacks'): for item in kwargs.get('callbacks'):
if isinstance(item, ModelCheckpoint): if isinstance(item, ModelCheckpoint):