clean code
This commit is contained in:
parent
89e3a499b1
commit
6e6950f51e
|
@ -665,8 +665,8 @@ std::vector<size_t> TraceLoopToControlMap(const FuncGraphPtr &control_subgraph)
|
||||||
auto switch_node = FindLoopSwitchNode(control_subgraph);
|
auto switch_node = FindLoopSwitchNode(control_subgraph);
|
||||||
auto loop_partial_node = GetNodeInput<CNode>(switch_node, kTwoNum);
|
auto loop_partial_node = GetNodeInput<CNode>(switch_node, kTwoNum);
|
||||||
const auto &control_params = control_subgraph->parameters();
|
const auto &control_params = control_subgraph->parameters();
|
||||||
size_t auxiliary_inputs_num = 2;
|
int64_t auxiliary_inputs_num = 2;
|
||||||
for (size_t i = auxiliary_inputs_num; i < loop_partial_node->inputs().size(); ++i) {
|
for (size_t i = static_cast<unsigned int>(auxiliary_inputs_num); i < loop_partial_node->inputs().size(); ++i) {
|
||||||
auto loop_param = GetNodeInput<Parameter>(loop_partial_node, i);
|
auto loop_param = GetNodeInput<Parameter>(loop_partial_node, i);
|
||||||
auto control_param_pos =
|
auto control_param_pos =
|
||||||
std::find(control_params.begin(), control_params.end(), loop_param) - control_params.begin();
|
std::find(control_params.begin(), control_params.end(), loop_param) - control_params.begin();
|
||||||
|
@ -683,8 +683,8 @@ std::vector<size_t> TraceAfterToLoopMap(const FuncGraphPtr &control_subgraph) {
|
||||||
auto loop_partial_node = GetNodeInput<CNode>(switch_node, kTwoNum);
|
auto loop_partial_node = GetNodeInput<CNode>(switch_node, kTwoNum);
|
||||||
auto after_partial_node = GetNodeInput<CNode>(switch_node, kThreeNum);
|
auto after_partial_node = GetNodeInput<CNode>(switch_node, kThreeNum);
|
||||||
const auto &loop_params = loop_partial_node->inputs();
|
const auto &loop_params = loop_partial_node->inputs();
|
||||||
size_t auxiliary_inputs_num = 2;
|
int64_t auxiliary_inputs_num = 2;
|
||||||
for (size_t i = auxiliary_inputs_num; i < after_partial_node->inputs().size(); ++i) {
|
for (size_t i = static_cast<unsigned int>(auxiliary_inputs_num); i < after_partial_node->inputs().size(); ++i) {
|
||||||
auto after_param = GetNodeInput<Parameter>(after_partial_node, i);
|
auto after_param = GetNodeInput<Parameter>(after_partial_node, i);
|
||||||
auto after_param_pos = std::find(loop_params.begin(), loop_params.end(), after_param) - loop_params.begin();
|
auto after_param_pos = std::find(loop_params.begin(), loop_params.end(), after_param) - loop_params.begin();
|
||||||
result.push_back(after_param_pos - auxiliary_inputs_num);
|
result.push_back(after_param_pos - auxiliary_inputs_num);
|
||||||
|
@ -1675,7 +1675,7 @@ void OnnxExporter::ExportPrimStridedSlice(const FuncGraphPtr &, const CNodePtr &
|
||||||
const auto &x_shape = dyn_cast<abstract::Shape>(node->input(kOneNum)->Shape())->shape();
|
const auto &x_shape = dyn_cast<abstract::Shape>(node->input(kOneNum)->Shape())->shape();
|
||||||
auto end_ignore_mask = GetOpAttribute<int64_t>(node, "end_mask");
|
auto end_ignore_mask = GetOpAttribute<int64_t>(node, "end_mask");
|
||||||
for (size_t i = 0; i < end_value.size(); ++i) {
|
for (size_t i = 0; i < end_value.size(); ++i) {
|
||||||
if ((end_ignore_mask & (1 << i)) != 0) {
|
if (((unsigned int)end_ignore_mask & (1 << i)) != 0) {
|
||||||
end_value[i] = x_shape[i];
|
end_value[i] = x_shape[i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -865,7 +865,7 @@ std::vector<Operator> DfGraphConvertor::GetWhileBodyOutputs() {
|
||||||
}
|
}
|
||||||
|
|
||||||
if (j->isa<Parameter>()) {
|
if (j->isa<Parameter>()) {
|
||||||
size_t idx = find(inputs_.begin(), inputs_.end(), j) - inputs_.begin();
|
int64_t idx = find(inputs_.begin(), inputs_.end(), j) - inputs_.begin();
|
||||||
auto idx_cond = body_cond_map_[idx];
|
auto idx_cond = body_cond_map_[idx];
|
||||||
if (while_used_input_index_.find(idx_cond) == while_used_input_index_.end() ||
|
if (while_used_input_index_.find(idx_cond) == while_used_input_index_.end() ||
|
||||||
while_const_input_index_.find(idx_cond) != while_const_input_index_.end()) {
|
while_const_input_index_.find(idx_cond) != while_const_input_index_.end()) {
|
||||||
|
@ -1140,7 +1140,7 @@ void DfGraphConvertor::SetParamIndexMap(const std::vector<AnfNodePtr> &graphs) {
|
||||||
|
|
||||||
for (size_t i = 0; i < body_params.size(); i++) {
|
for (size_t i = 0; i < body_params.size(); i++) {
|
||||||
auto p = body_params[i];
|
auto p = body_params[i];
|
||||||
size_t idx = find(cond_params.begin(), cond_params.end(), p) - cond_params.begin();
|
int64_t idx = find(cond_params.begin(), cond_params.end(), p) - cond_params.begin();
|
||||||
body_cond_map_[i] = idx;
|
body_cond_map_[i] = idx;
|
||||||
MS_LOG(DEBUG) << "body_cond_map_'s key: " << i << " value: " << idx;
|
MS_LOG(DEBUG) << "body_cond_map_'s key: " << i << " value: " << idx;
|
||||||
}
|
}
|
||||||
|
@ -1157,7 +1157,7 @@ void DfGraphConvertor::SetParamIndexMap(const std::vector<AnfNodePtr> &graphs) {
|
||||||
|
|
||||||
for (size_t i = 0; i < after_params.size(); i++) {
|
for (size_t i = 0; i < after_params.size(); i++) {
|
||||||
auto p = after_params[i];
|
auto p = after_params[i];
|
||||||
size_t idx = find(cond_params.begin(), cond_params.end(), p) - cond_params.begin();
|
int64_t idx = find(cond_params.begin(), cond_params.end(), p) - cond_params.begin();
|
||||||
after_cond_map_[i] = idx;
|
after_cond_map_[i] = idx;
|
||||||
MS_LOG(DEBUG) << "after_cond_map_'s key: " << i << " value: " << idx;
|
MS_LOG(DEBUG) << "after_cond_map_'s key: " << i << " value: " << idx;
|
||||||
}
|
}
|
||||||
|
|
|
@ -47,6 +47,7 @@ from . import amp
|
||||||
from ..common.api import _pynative_executor
|
from ..common.api import _pynative_executor
|
||||||
from ..dataset.engine.datasets import _set_training_dataset, _reset_training_dataset
|
from ..dataset.engine.datasets import _set_training_dataset, _reset_training_dataset
|
||||||
|
|
||||||
|
|
||||||
def _transfer_tensor_to_tuple(inputs):
|
def _transfer_tensor_to_tuple(inputs):
|
||||||
"""
|
"""
|
||||||
If the input is a tensor, convert it to a tuple. If not, the output is unchanged.
|
If the input is a tensor, convert it to a tuple. If not, the output is unchanged.
|
||||||
|
@ -71,7 +72,7 @@ def _save_final_ckpt(func):
|
||||||
def wrapper(self, *args, **kwargs):
|
def wrapper(self, *args, **kwargs):
|
||||||
obj = None
|
obj = None
|
||||||
if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'), ModelCheckpoint):
|
if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'), ModelCheckpoint):
|
||||||
obj = kwargs['callbacks']
|
obj = kwargs.get('callbacks')
|
||||||
if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'), list):
|
if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'), list):
|
||||||
for item in kwargs.get('callbacks'):
|
for item in kwargs.get('callbacks'):
|
||||||
if isinstance(item, ModelCheckpoint):
|
if isinstance(item, ModelCheckpoint):
|
||||||
|
|
Loading…
Reference in New Issue