forked from mindspore-Ecosystem/mindspore
!21741 clean code for master
Merge pull request !21741 from changzherui/clean_code_ma
This commit is contained in:
commit
115da9f797
|
@ -148,7 +148,7 @@ def check_number(arg_value, value, rel, arg_type=int, arg_name=None, prim_name=N
|
|||
Check argument integer.
|
||||
|
||||
Example:
|
||||
- number = check_int(number, 0, Rel.GE, "number", None) # number >= 0
|
||||
- number = check_number(number, 0, Rel.GE, "number", None) # number >= 0
|
||||
"""
|
||||
rel_fn = Rel.get_fns(rel)
|
||||
prim_name = f'in `{prim_name}`' if prim_name else ''
|
||||
|
|
|
@ -178,6 +178,7 @@ void IrExportBuilder::BuildModelInfo() {
|
|||
}
|
||||
|
||||
void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph, bool save_tensor_data) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
mind_ir::GraphProto *graph_proto = model_.mutable_graph();
|
||||
graph_proto->set_name(func_graph->ToString());
|
||||
graph_proto->set_bprop_hash(func_graph->bprop_hash());
|
||||
|
@ -225,7 +226,10 @@ void IrExportBuilder::BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::Gr
|
|||
|
||||
void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto,
|
||||
bool save_tensor_data) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(graph_proto);
|
||||
for (auto &item : func_graph->parameters()) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
auto param = item->cast<ParameterPtr>();
|
||||
if (param == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Parameter: '" << item->ToString() << "' could not cast to parameter.";
|
||||
|
@ -296,6 +300,7 @@ void IrExportBuilder::SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueIn
|
|||
}
|
||||
if (type->isa<TensorType>() && shape->isa<abstract::Shape>()) {
|
||||
auto tensor = type->cast<TensorTypePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tensor);
|
||||
auto elem_type = tensor->element();
|
||||
const auto &dims = shape->cast<abstract::ShapePtr>()->shape();
|
||||
mind_ir::TensorProto *tensor_proto = value_proto->add_tensor();
|
||||
|
@ -328,6 +333,7 @@ void IrExportBuilder::SetTensorToAttributeProto(const ValuePtr &value, mind_ir::
|
|||
mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
|
||||
tensor_proto->set_name("value0");
|
||||
auto data = value->cast<tensor::TensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
tensor_proto->set_raw_data(data->data_c(), static_cast<size_t>(data->data().nbytes()));
|
||||
auto dtype = data->data_type();
|
||||
auto shape = data->shape_c();
|
||||
|
@ -362,6 +368,7 @@ void IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, mind_ir::GraphP
|
|||
std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude);
|
||||
bool is_only_return = true;
|
||||
for (const AnfNodePtr &node : nodes) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (!node->isa<CNode>()) {
|
||||
MS_LOG(DEBUG) << "Node: '" << node->ToString() << "' is not cnode";
|
||||
continue;
|
||||
|
@ -380,13 +387,16 @@ void IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, mind_ir::GraphP
|
|||
}
|
||||
|
||||
void IrExportBuilder::BuildOutput(const CNodePtr &node, mind_ir::GraphProto *const graph_proto) {
|
||||
if (node->size() != 2) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
const int OutputSize = 2;
|
||||
if (node->size() != OutputSize) {
|
||||
MS_LOG(EXCEPTION) << "Number of inputs of return node is not equal to 2.";
|
||||
}
|
||||
AnfNodePtr arg = node->input(1);
|
||||
mind_ir::ValueInfoProto *output_proto = graph_proto->add_output();
|
||||
std::string output_name = GetUniqueNodeName(node);
|
||||
output_proto->set_name(output_name);
|
||||
MS_EXCEPTION_IF_NULL(last_node_);
|
||||
last_node_->set_output(0, output_name);
|
||||
SetValueInfoProto(arg, output_proto);
|
||||
}
|
||||
|
@ -396,9 +406,11 @@ std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) {
|
|||
std::string type_name = "";
|
||||
if (IsValueNode<Primitive>(node)) {
|
||||
PrimitivePtr prim = GetValueNode<PrimitivePtr>(node);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
type_name = prim->ToString();
|
||||
} else if (IsValueNode<FuncGraph>(node)) {
|
||||
FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(node);
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
todo_.push_back(fg);
|
||||
type_name = "REF::" + fg->ToString();
|
||||
} else if (node->isa<CNode>() || node->isa<Parameter>()) {
|
||||
|
@ -416,10 +428,9 @@ std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) {
|
|||
|
||||
void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape,
|
||||
mind_ir::AttributeProto *const attr_proto, std::string *const seq_string) {
|
||||
if (seq_string == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "seq_string is nullptr.";
|
||||
}
|
||||
|
||||
MS_EXCEPTION_IF_NULL(type);
|
||||
MS_EXCEPTION_IF_NULL(shape);
|
||||
MS_EXCEPTION_IF_NULL(seq_string);
|
||||
if (type->isa<Tuple>()) {
|
||||
*seq_string += "Tuple[";
|
||||
auto elements = type->cast<TuplePtr>()->elements();
|
||||
|
@ -560,8 +571,9 @@ std::string IrExportBuilder::GetUniqueNodeName(const AnfNodePtr &node) {
|
|||
}
|
||||
|
||||
std::string IrExportBuilder::GetNodeName(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
std::string node_name = "";
|
||||
if ((node != nullptr) && (node->func_graph() != nullptr)) {
|
||||
if (node->func_graph() != nullptr) {
|
||||
node_name = node->func_graph()->ToString() + ":";
|
||||
}
|
||||
if (node->isa<ValueNode>()) {
|
||||
|
@ -578,7 +590,9 @@ void IrExportBuilder::SetAttributeProto(const AnfNodePtr &node, mind_ir::NodePro
|
|||
if (node == nullptr || node_proto == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "AnfNode or NodeProto is null!";
|
||||
}
|
||||
auto value = node->cast<ValueNodePtr>()->value();
|
||||
auto value_node = node->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
auto value = value_node->value();
|
||||
node_proto->set_op_type("Constant");
|
||||
mind_ir::AttributeProto *attr_proto = node_proto->add_attribute();
|
||||
attr_proto->set_name("value");
|
||||
|
@ -663,6 +677,9 @@ void IrExportBuilder::SetValueToAttributeProto(const ValuePtr &value, mind_ir::A
|
|||
}
|
||||
|
||||
void IrExportBuilder::SetScalarToAttributeProto_ir(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
|
||||
if (value == nullptr || attr_proto == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
|
||||
}
|
||||
attr_proto->set_ref_attr_name("scalar:value0");
|
||||
if (value->isa<StringImm>()) {
|
||||
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_STRING);
|
||||
|
@ -709,6 +726,9 @@ void IrExportBuilder::SetScalarToAttributeProto_ir(const ValuePtr &value, mind_i
|
|||
}
|
||||
|
||||
void IrExportBuilder::SetScalarToAttributeProto_irs(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
|
||||
if (value == nullptr || attr_proto == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
|
||||
}
|
||||
if (value->isa<Int>()) {
|
||||
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
|
||||
mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
|
||||
|
@ -803,6 +823,7 @@ void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value,
|
|||
return;
|
||||
}
|
||||
for (const auto &item : list_value->value()) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
if (item->isa<ValueList>()) {
|
||||
SetSequenceToAttributeProto(item->cast<ValueListPtr>(), attr_proto, seq_string);
|
||||
} else {
|
||||
|
|
|
@ -55,6 +55,7 @@ using Data = ge::op::Data;
|
|||
|
||||
namespace {
|
||||
std::vector<AnfNodePtr> GetOrderedCNodes(const FuncGraphPtr fg) {
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
auto BelongSameGraph = std::bind(IncludeBelongGraph, fg, std::placeholders::_1);
|
||||
auto succ_include_fv = [&fg](const AnfNodePtr &node) -> std::vector<AnfNodePtr> {
|
||||
std::vector<AnfNodePtr> vecs;
|
||||
|
@ -132,6 +133,7 @@ OpAdapterPtr DfGraphConvertor::FindAdapter(const AnfNodePtr node, bool train) {
|
|||
}
|
||||
|
||||
void DfGraphConvertor::InitLoopVar(std::vector<ge::Operator> *init_input) {
|
||||
MS_EXCEPTION_IF_NULL(init_input);
|
||||
if (this->training_) {
|
||||
GeTensorDesc desc(GeShape(), ge::FORMAT_NCHW, ge::DT_INT64);
|
||||
auto var_iter_num = std::make_shared<Variable>("npu_runconfig/iterations_per_loop");
|
||||
|
@ -237,6 +239,7 @@ void DfGraphConvertor::SetupParamInitSubGraph(const TensorOrderMap &tensors, std
|
|||
std::vector<AnfNodePtr> nodes = GetOrderedCNodes(anf_graph_);
|
||||
|
||||
for (auto &it : nodes) {
|
||||
MS_EXCEPTION_IF_NULL(it);
|
||||
if (it->isa<ValueNode>()) {
|
||||
if (IsValueNode<SymbolicKeyInstance>(it)) {
|
||||
auto symbolic = GetValueNode<SymbolicKeyInstancePtr>(it);
|
||||
|
@ -251,6 +254,7 @@ void DfGraphConvertor::SetupParamInitSubGraph(const TensorOrderMap &tensors, std
|
|||
}
|
||||
} else if (IsValueNode<RefKey>(it)) {
|
||||
auto refkey = GetValueNode<RefKeyPtr>(it);
|
||||
MS_EXCEPTION_IF_NULL(refkey);
|
||||
auto name = refkey->tag();
|
||||
auto iter = vars_.find(name); // get corresponding variable op
|
||||
if (iter != vars_.end()) {
|
||||
|
@ -771,9 +775,10 @@ void DfGraphConvertor::GetCaseNodeInput(const CNodePtr node, const CNodePtr inpu
|
|||
case_inputs.emplace_back(node->input(i));
|
||||
}
|
||||
auto bnode = input_node->input(2)->cast<CNodePtr>();
|
||||
|
||||
MS_EXCEPTION_IF_NULL(bnode);
|
||||
for (size_t i = 1; i < bnode->inputs().size(); i++) {
|
||||
auto branch_node = bnode->input(i)->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(branch_node);
|
||||
for (size_t j = 2; j < branch_node->inputs().size(); j++) {
|
||||
if (std::find(case_inputs.begin(), case_inputs.end(), branch_node->input(j)) == case_inputs.end()) {
|
||||
case_inputs.emplace_back(branch_node->input(j));
|
||||
|
@ -1073,7 +1078,9 @@ void DfGraphConvertor::AddEdgeForLoad(const AnfNodePtr &node) {
|
|||
}
|
||||
auto manager = func_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
|
||||
if (manager->node_users().find(node) == manager->node_users().end()) {
|
||||
MS_LOG(EXCEPTION) << "Can't find node in nodes_users.";
|
||||
}
|
||||
auto &users = manager->node_users()[node];
|
||||
std::shared_ptr<std::vector<AnfNodePtr>> src_node_list = std::make_shared<std::vector<AnfNodePtr>>();
|
||||
std::shared_ptr<std::vector<AnfNodePtr>> dst_node_list = std::make_shared<std::vector<AnfNodePtr>>();
|
||||
|
@ -1101,6 +1108,7 @@ void DfGraphConvertor::AddEdgeForLoad(const AnfNodePtr &node) {
|
|||
|
||||
void DfGraphConvertor::FindDestOps(const AnfNodePtr &node, const std::shared_ptr<std::vector<AnfNodePtr>> &node_list,
|
||||
bool top) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto func_graph = node->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
auto mng = func_graph->manager();
|
||||
|
@ -1356,6 +1364,7 @@ void DfGraphConvertor::ProcessSubgraph(AnfNodePtr node, const std::vector<AnfNod
|
|||
return;
|
||||
}
|
||||
auto graph_node = node->cast<CNodePtr>()->input(1)->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(graph_node);
|
||||
FuncGraphPtr anf_graph = graph_node->value()->cast<FuncGraphPtr>();
|
||||
DfGraphConvertor converter(anf_graph);
|
||||
converter.use_inputs_ = true;
|
||||
|
@ -1449,13 +1458,16 @@ void DfGraphConvertor::ConvertMakeTuple(const CNodePtr node) {
|
|||
}
|
||||
|
||||
void DfGraphConvertor::ConvertTopK(const CNodePtr node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_LOG(INFO) << "Convert TopK second input's type from int64 to int32.";
|
||||
auto value_ptr = node->input(2)->cast<ValueNodePtr>();
|
||||
std::ostringstream ss;
|
||||
ss << "op" << value_ptr.get();
|
||||
op_draw_name_[value_ptr.get()] = ss.str();
|
||||
compute_sout_ << ss.str() << "[label= \"" << value_ptr->value()->ToString() << "\" shape=ellipse]" << endl;
|
||||
auto int64_value = value_ptr->value()->cast<Int64ImmPtr>()->value();
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
auto input_value = value_ptr->value();
|
||||
auto int64_value = GetValue<int64_t>(input_value);
|
||||
OpAdapterPtr adpt = FindAdapter(value_ptr, training_);
|
||||
auto op = adpt->generate(value_ptr);
|
||||
adpt->setAttr(op, "value", static_cast<int32_t>(int64_value));
|
||||
|
|
|
@ -105,7 +105,7 @@ class Parameter(Tensor_):
|
|||
>>> x = Tensor(np.ones((2, 1)), mindspore.float32)
|
||||
>>> print(net(x))
|
||||
[[2.]]
|
||||
>>> _ = net.weight.set_data(Tensor(np.zeros((1, 2)), mindspore.float32))
|
||||
>>> net.weight.set_data(Tensor(np.zeros((1, 2)), mindspore.float32))
|
||||
>>> print(net(x))
|
||||
[[0.]]
|
||||
"""
|
||||
|
|
|
@ -1232,7 +1232,7 @@ class Tensor(Tensor_):
|
|||
raise ValueError(msg)
|
||||
|
||||
class seed_context:
|
||||
'''set and restore seed'''
|
||||
"""Set and restore seed."""
|
||||
|
||||
def __init__(self, init):
|
||||
self.init = init
|
||||
|
|
|
@ -493,6 +493,7 @@ bool MSANFModelParser::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim,
|
|||
shape.push_back(attr_tensor.dims(i));
|
||||
}
|
||||
tensor::TensorPtr tensor_info = std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor_type], shape);
|
||||
MS_EXCEPTION_IF_NULL(tensor_info);
|
||||
const std::string &tensor_buf = attr_tensor.raw_data();
|
||||
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c());
|
||||
auto ret = memcpy_s(tensor_data_buf, tensor_info->data().nbytes(), tensor_buf.data(), tensor_buf.size());
|
||||
|
@ -570,6 +571,7 @@ bool MSANFModelParser::ObtainValueNodeInTensorForm(const std::string &value_node
|
|||
shape.push_back(attr_tensor.dims(i));
|
||||
}
|
||||
tensor::TensorPtr tensor_info = std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor_type], shape);
|
||||
MS_EXCEPTION_IF_NULL(tensor_info);
|
||||
const std::string &tensor_buf = attr_tensor.raw_data();
|
||||
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c());
|
||||
auto ret = memcpy_s(tensor_data_buf, tensor_info->data().nbytes(), tensor_buf.data(), tensor_buf.size());
|
||||
|
@ -794,9 +796,11 @@ AnfNodePtr MSANFModelParser::BuildOperatorNode(const mind_ir::NodeProto &node_pr
|
|||
if (node_type.compare(0, strlen(kDoSignaturePrimitivePrefix), kDoSignaturePrimitivePrefix) == 0) {
|
||||
auto op_name = node_type.substr(strlen(kDoSignaturePrimitivePrefix));
|
||||
prim = std::make_shared<prim::DoSignaturePrimitive>(op_name, std::make_shared<Primitive>(op_name));
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
prim->set_instance_name(op_name);
|
||||
} else {
|
||||
prim = std::make_shared<Primitive>(node_type);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
prim->set_instance_name(node_type);
|
||||
}
|
||||
}
|
||||
|
@ -940,10 +944,15 @@ bool MSANFModelParser::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGra
|
|||
for (int out_size = 0; out_size < importProto.output_size(); ++out_size) {
|
||||
const mind_ir::ValueInfoProto &output_node = importProto.output(out_size);
|
||||
const std::string &out_tuple = output_node.name();
|
||||
if (anfnode_build_map_.find(out_tuple) == anfnode_build_map_.end()) {
|
||||
MS_LOG(ERROR) << "Can't find out_tuple in anfnode_build_map_";
|
||||
return false;
|
||||
}
|
||||
inputs.push_back(anfnode_build_map_[out_tuple]);
|
||||
elem.push_back(anfnode_build_map_[out_tuple]->abstract());
|
||||
}
|
||||
auto maketuple_ptr = outputFuncGraph->NewCNode(inputs);
|
||||
MS_EXCEPTION_IF_NULL(maketuple_ptr);
|
||||
maketuple_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
|
||||
inputs.clear();
|
||||
inputs.push_back(NewValueNode(prim::kPrimReturn));
|
||||
|
@ -992,8 +1001,7 @@ bool MSANFModelParser::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph,
|
|||
}
|
||||
}
|
||||
|
||||
BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr);
|
||||
return true;
|
||||
return BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr);
|
||||
}
|
||||
|
||||
bool MSANFModelParser::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto) {
|
||||
|
|
|
@ -422,7 +422,6 @@ def _get_stack_info(frame):
|
|||
Returns:
|
||||
str, the string of the stack information.
|
||||
"""
|
||||
sinfo = None
|
||||
stack_prefix = 'Stack (most recent call last):\n'
|
||||
sinfo = stack_prefix + "".join(traceback.format_stack(frame))
|
||||
return sinfo
|
||||
|
|
|
@ -339,7 +339,7 @@ class Cell(Cell_):
|
|||
|
||||
def run_construct(self, cast_inputs, kwargs):
|
||||
if self.enable_hook:
|
||||
output = self._hook_construct(*cast_inputs, **kwargs)
|
||||
output = self._hook_construct(*cast_inputs)
|
||||
else:
|
||||
output = self.construct(*cast_inputs, **kwargs)
|
||||
return output
|
||||
|
@ -1209,7 +1209,7 @@ class Cell(Cell_):
|
|||
self.add_flags(auto_parallel=True)
|
||||
self._get_construct_inputs_number_and_name()
|
||||
|
||||
def _hook_construct(self, *inputs, **kwargs):
|
||||
def _hook_construct(self, *inputs):
|
||||
"""Hook construct method to replace original construct method when hook function enabled."""
|
||||
inputs = self._backward_hook(*inputs)
|
||||
inputs = self.construct(inputs)
|
||||
|
|
|
@ -81,9 +81,7 @@ def _column_or_1d(y):
|
|||
Ravel column or 1d numpy array, otherwise raise an error.
|
||||
"""
|
||||
shape = np.shape(y)
|
||||
if len(shape) == 1:
|
||||
return np.ravel(y)
|
||||
if len(shape) == 2 and shape[1] == 1:
|
||||
if len(shape) == 1 or(len(shape) == 2 and shape[1] == 1):
|
||||
return np.ravel(y)
|
||||
|
||||
raise ValueError("Bad input shape {0}.".format(shape))
|
||||
|
|
|
@ -113,13 +113,10 @@ def _check_param(momentum, frequency, lr, cls_name):
|
|||
|
||||
|
||||
def caculate_device_shape(matrix_dim, channel, is_a):
|
||||
ll = (0)
|
||||
if is_a:
|
||||
if channel // C0 == 0:
|
||||
matrix_dim = (matrix_dim / channel) * C0
|
||||
ll = (int(matrix_dim // C0), int(matrix_dim // C0), C0, C0), int(matrix_dim)
|
||||
else:
|
||||
ll = (int(matrix_dim // C0), int(matrix_dim // C0), C0, C0), int(matrix_dim)
|
||||
ll = (int(matrix_dim // C0), int(matrix_dim // C0), C0, C0), int(matrix_dim)
|
||||
return ll
|
||||
|
||||
|
||||
|
|
|
@ -200,7 +200,7 @@ def rollaxis(x, axis, start=0):
|
|||
|
||||
axis = _check_axes_range(axis, ndim)
|
||||
start = _check_start_normalize(start, ndim)
|
||||
if start - axis >= 0 and start - axis <= 1:
|
||||
if 0 <= start - axis <= 1:
|
||||
return x
|
||||
perm = F.make_range(0, ndim)
|
||||
new_perm = None
|
||||
|
|
|
@ -43,6 +43,12 @@ class LossMonitor(Callback):
|
|||
self._per_print_times = per_print_times
|
||||
|
||||
def step_end(self, run_context):
|
||||
"""
|
||||
Print training loss at the end of step.
|
||||
|
||||
Args:
|
||||
run_context (RunContext): Context of the train running.
|
||||
"""
|
||||
cb_params = run_context.original_args()
|
||||
loss = cb_params.net_outputs
|
||||
|
||||
|
|
|
@ -51,7 +51,6 @@ class LearningRateScheduler(Callback):
|
|||
>>> dataset = create_custom_dataset("custom_dataset_path")
|
||||
>>> model.train(1, dataset, callbacks=[LearningRateScheduler(learning_rate_function)],
|
||||
... dataset_sink_mode=False)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, learning_rate_function):
|
||||
|
@ -59,6 +58,12 @@ class LearningRateScheduler(Callback):
|
|||
self.learning_rate_function = learning_rate_function
|
||||
|
||||
def step_end(self, run_context):
|
||||
"""
|
||||
Change the learning_rate at the end of step.
|
||||
|
||||
Args:
|
||||
run_context (RunContext): Context of the train running.
|
||||
"""
|
||||
cb_params = run_context.original_args()
|
||||
arr_lr = cb_params.optimizer.learning_rate.asnumpy()
|
||||
lr = float(np.array2string(arr_lr))
|
||||
|
|
|
@ -38,9 +38,21 @@ class TimeMonitor(Callback):
|
|||
self.epoch_time = time.time()
|
||||
|
||||
def epoch_begin(self, run_context):
|
||||
"""
|
||||
Record time at the begin of epoch.
|
||||
|
||||
Args:
|
||||
run_context (RunContext): Context of the process running.
|
||||
"""
|
||||
self.epoch_time = time.time()
|
||||
|
||||
def epoch_end(self, run_context):
|
||||
"""
|
||||
Print process cost time at the end of epoch.
|
||||
|
||||
Args:
|
||||
run_context (RunContext): Context of the process running.
|
||||
"""
|
||||
epoch_seconds = (time.time() - self.epoch_time) * 1000
|
||||
step_size = self.data_size
|
||||
cb_params = run_context.original_args()
|
||||
|
|
|
@ -824,7 +824,6 @@ def _save_mindir(net, file_name, *inputs, **kwargs):
|
|||
if os.path.exists(data_path):
|
||||
shutil.rmtree(data_path)
|
||||
os.makedirs(data_path, exist_ok=True)
|
||||
os.chmod(data_path, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)
|
||||
index = 0
|
||||
graphproto = graph_proto()
|
||||
data_size = 0
|
||||
|
|
|
@ -20,7 +20,7 @@ import mindspore.common.dtype as mstype
|
|||
from mindspore import context
|
||||
|
||||
|
||||
class ConvertNetUtils():
|
||||
class ConvertNetUtils:
|
||||
"""
|
||||
Convert net to thor layer net
|
||||
"""
|
||||
|
@ -29,7 +29,6 @@ class ConvertNetUtils():
|
|||
nn.Embedding: ConvertNetUtils._convert_embedding,
|
||||
nn.Conv2d: ConvertNetUtils._convert_conv2d}
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _convert_dense(subcell):
|
||||
"""
|
||||
|
@ -64,7 +63,6 @@ class ConvertNetUtils():
|
|||
new_subcell.bias = subcell.bias
|
||||
return new_subcell
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _convert_embedding(subcell):
|
||||
"""
|
||||
|
@ -76,7 +74,6 @@ class ConvertNetUtils():
|
|||
new_subcell.embedding_table = subcell.embedding_table
|
||||
return new_subcell
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _convert_conv2d(subcell):
|
||||
"""
|
||||
|
@ -95,7 +92,6 @@ class ConvertNetUtils():
|
|||
has_bias=has_bias, weight_init=weight)
|
||||
return new_subcell
|
||||
|
||||
|
||||
def _convert_to_thor_net(self, net):
|
||||
"""
|
||||
Convert net to thor net
|
||||
|
@ -114,9 +110,6 @@ class ConvertNetUtils():
|
|||
elif isinstance(subcell, (nn.Embedding, nn.Dense, nn.Conv2d)):
|
||||
prefix = subcell.param_prefix
|
||||
new_subcell = self._convert_method_map[type(subcell)](subcell)
|
||||
print("subcell name: ", name, "prefix is", prefix, flush=True)
|
||||
if isinstance(new_subcell, (nn.DenseThor, nn.EmbeddingThor, nn.Conv2dThor)):
|
||||
print("convert to thor layer success.", flush=True)
|
||||
new_subcell.update_parameters_name(prefix + '.')
|
||||
net.insert_child_to_cell(name, new_subcell)
|
||||
change = True
|
||||
|
@ -124,10 +117,8 @@ class ConvertNetUtils():
|
|||
self._convert_to_thor_net(subcell)
|
||||
|
||||
if isinstance(net, nn.SequentialCell) and change:
|
||||
print("is nn.SequentialCell and change")
|
||||
net.cell_list = list(net.cells())
|
||||
|
||||
|
||||
def convert_to_thor_net(self, net):
|
||||
"""
|
||||
This interface is used to convert a network to thor layer network, in order to calculate and store the
|
||||
|
@ -152,7 +143,7 @@ class ConvertNetUtils():
|
|||
net.update_cell_type("second-order")
|
||||
|
||||
|
||||
class ConvertModelUtils():
|
||||
class ConvertModelUtils:
|
||||
"""
|
||||
Convert model to thor model.
|
||||
"""
|
||||
|
@ -203,7 +194,7 @@ class ConvertModelUtils():
|
|||
... frequency=100)
|
||||
>>> model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_manager, metrics={"acc"},
|
||||
... amp_level="O2", keep_batchnorm_fp32=False)
|
||||
>>> model = ConvertModelUtils().convert_to_thor_model(model=model, network=net, loss_fn=loss, optimizer=opt,
|
||||
>>> model = ConvertModelUtils.convert_to_thor_model(model=model, network=net, loss_fn=loss, optimizer=opt,
|
||||
... metrics={'acc'}, amp_level="O2",
|
||||
... loss_scale_manager=loss_manager,
|
||||
... keep_batchnorm_fp32=False)
|
||||
|
|
Loading…
Reference in New Issue