!21741 clean code for master

Merge pull request !21741 from changzherui/clean_code_ma
i-robot 2021-08-19 02:21:35 +00:00 committed by Gitee
commit 115da9f797
16 changed files with 88 additions and 40 deletions


@@ -148,7 +148,7 @@ def check_number(arg_value, value, rel, arg_type=int, arg_name=None, prim_name=N
Check argument integer.
Example:
- number = check_int(number, 0, Rel.GE, "number", None) # number >= 0
- number = check_number(number, 0, Rel.GE, "number", None) # number >= 0
"""
rel_fn = Rel.get_fns(rel)
prim_name = f'in `{prim_name}`' if prim_name else ''
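A minimal, self-contained sketch of the range check the docstring example above performs, assuming Rel.GE maps to operator.ge; check_number_sketch and _REL_FNS are illustrative stand-ins, not the real mindspore._checkparam helpers:

import operator

_REL_FNS = {"GE": operator.ge}  # hypothetical stand-in for Rel.get_fns(rel)

def check_number_sketch(arg_value, value, rel, arg_name):
    if not _REL_FNS[rel](arg_value, value):
        raise ValueError(f"'{arg_name}' must be {rel} {value}, but got {arg_value}.")
    return arg_value

number = check_number_sketch(3, 0, "GE", "number")  # number >= 0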


@@ -178,6 +178,7 @@ void IrExportBuilder::BuildModelInfo() {
}
void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph, bool save_tensor_data) {
MS_EXCEPTION_IF_NULL(func_graph);
mind_ir::GraphProto *graph_proto = model_.mutable_graph();
graph_proto->set_name(func_graph->ToString());
graph_proto->set_bprop_hash(func_graph->bprop_hash());
@@ -225,7 +226,10 @@ void IrExportBuilder::BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::Gr
void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto,
bool save_tensor_data) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(graph_proto);
for (auto &item : func_graph->parameters()) {
MS_EXCEPTION_IF_NULL(item);
auto param = item->cast<ParameterPtr>();
if (param == nullptr) {
MS_LOG(EXCEPTION) << "Parameter: '" << item->ToString() << "' could not cast to parameter.";
@@ -296,6 +300,7 @@ void IrExportBuilder::SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueIn
}
if (type->isa<TensorType>() && shape->isa<abstract::Shape>()) {
auto tensor = type->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor);
auto elem_type = tensor->element();
const auto &dims = shape->cast<abstract::ShapePtr>()->shape();
mind_ir::TensorProto *tensor_proto = value_proto->add_tensor();
@@ -328,6 +333,7 @@ void IrExportBuilder::SetTensorToAttributeProto(const ValuePtr &value, mind_ir::
mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
tensor_proto->set_name("value0");
auto data = value->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(data);
tensor_proto->set_raw_data(data->data_c(), static_cast<size_t>(data->data().nbytes()));
auto dtype = data->data_type();
auto shape = data->shape_c();
@@ -362,6 +368,7 @@ void IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, mind_ir::GraphP
std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude);
bool is_only_return = true;
for (const AnfNodePtr &node : nodes) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {
MS_LOG(DEBUG) << "Node: '" << node->ToString() << "' is not cnode";
continue;
@@ -380,13 +387,16 @@ void IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, mind_ir::GraphP
}
void IrExportBuilder::BuildOutput(const CNodePtr &node, mind_ir::GraphProto *const graph_proto) {
if (node->size() != 2) {
MS_EXCEPTION_IF_NULL(node);
const int OutputSize = 2;
if (node->size() != OutputSize) {
MS_LOG(EXCEPTION) << "Number of inputs of return node is not equal to 2.";
}
AnfNodePtr arg = node->input(1);
mind_ir::ValueInfoProto *output_proto = graph_proto->add_output();
std::string output_name = GetUniqueNodeName(node);
output_proto->set_name(output_name);
MS_EXCEPTION_IF_NULL(last_node_);
last_node_->set_output(0, output_name);
SetValueInfoProto(arg, output_proto);
}
@@ -396,9 +406,11 @@ std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) {
std::string type_name = "";
if (IsValueNode<Primitive>(node)) {
PrimitivePtr prim = GetValueNode<PrimitivePtr>(node);
MS_EXCEPTION_IF_NULL(prim);
type_name = prim->ToString();
} else if (IsValueNode<FuncGraph>(node)) {
FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(node);
MS_EXCEPTION_IF_NULL(fg);
todo_.push_back(fg);
type_name = "REF::" + fg->ToString();
} else if (node->isa<CNode>() || node->isa<Parameter>()) {
@@ -416,10 +428,9 @@ std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) {
void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape,
mind_ir::AttributeProto *const attr_proto, std::string *const seq_string) {
if (seq_string == nullptr) {
MS_LOG(EXCEPTION) << "seq_string is nullptr.";
}
MS_EXCEPTION_IF_NULL(type);
MS_EXCEPTION_IF_NULL(shape);
MS_EXCEPTION_IF_NULL(seq_string);
if (type->isa<Tuple>()) {
*seq_string += "Tuple[";
auto elements = type->cast<TuplePtr>()->elements();
@@ -560,8 +571,9 @@ std::string IrExportBuilder::GetUniqueNodeName(const AnfNodePtr &node) {
}
std::string IrExportBuilder::GetNodeName(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
std::string node_name = "";
if ((node != nullptr) && (node->func_graph() != nullptr)) {
if (node->func_graph() != nullptr) {
node_name = node->func_graph()->ToString() + ":";
}
if (node->isa<ValueNode>()) {
@@ -578,7 +590,9 @@ void IrExportBuilder::SetAttributeProto(const AnfNodePtr &node, mind_ir::NodePro
if (node == nullptr || node_proto == nullptr) {
MS_LOG(EXCEPTION) << "AnfNode or NodeProto is null!";
}
auto value = node->cast<ValueNodePtr>()->value();
auto value_node = node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
auto value = value_node->value();
node_proto->set_op_type("Constant");
mind_ir::AttributeProto *attr_proto = node_proto->add_attribute();
attr_proto->set_name("value");
@@ -663,6 +677,9 @@ void IrExportBuilder::SetValueToAttributeProto(const ValuePtr &value, mind_ir::A
}
void IrExportBuilder::SetScalarToAttributeProto_ir(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
if (value == nullptr || attr_proto == nullptr) {
MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
}
attr_proto->set_ref_attr_name("scalar:value0");
if (value->isa<StringImm>()) {
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_STRING);
@@ -709,6 +726,9 @@ void IrExportBuilder::SetScalarToAttributeProto_ir(const ValuePtr &value, mind_i
}
void IrExportBuilder::SetScalarToAttributeProto_irs(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
if (value == nullptr || attr_proto == nullptr) {
MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
}
if (value->isa<Int>()) {
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
@@ -803,6 +823,7 @@ void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value,
return;
}
for (const auto &item : list_value->value()) {
MS_EXCEPTION_IF_NULL(item);
if (item->isa<ValueList>()) {
SetSequenceToAttributeProto(item->cast<ValueListPtr>(), attr_proto, seq_string);
} else {


@@ -55,6 +55,7 @@ using Data = ge::op::Data;
namespace {
std::vector<AnfNodePtr> GetOrderedCNodes(const FuncGraphPtr fg) {
MS_EXCEPTION_IF_NULL(fg);
auto BelongSameGraph = std::bind(IncludeBelongGraph, fg, std::placeholders::_1);
auto succ_include_fv = [&fg](const AnfNodePtr &node) -> std::vector<AnfNodePtr> {
std::vector<AnfNodePtr> vecs;
@@ -132,6 +133,7 @@ OpAdapterPtr DfGraphConvertor::FindAdapter(const AnfNodePtr node, bool train) {
}
void DfGraphConvertor::InitLoopVar(std::vector<ge::Operator> *init_input) {
MS_EXCEPTION_IF_NULL(init_input);
if (this->training_) {
GeTensorDesc desc(GeShape(), ge::FORMAT_NCHW, ge::DT_INT64);
auto var_iter_num = std::make_shared<Variable>("npu_runconfig/iterations_per_loop");
@@ -237,6 +239,7 @@ void DfGraphConvertor::SetupParamInitSubGraph(const TensorOrderMap &tensors, std
std::vector<AnfNodePtr> nodes = GetOrderedCNodes(anf_graph_);
for (auto &it : nodes) {
MS_EXCEPTION_IF_NULL(it);
if (it->isa<ValueNode>()) {
if (IsValueNode<SymbolicKeyInstance>(it)) {
auto symbolic = GetValueNode<SymbolicKeyInstancePtr>(it);
@@ -251,6 +254,7 @@ void DfGraphConvertor::SetupParamInitSubGraph(const TensorOrderMap &tensors, std
}
} else if (IsValueNode<RefKey>(it)) {
auto refkey = GetValueNode<RefKeyPtr>(it);
MS_EXCEPTION_IF_NULL(refkey);
auto name = refkey->tag();
auto iter = vars_.find(name); // get corresponding variable op
if (iter != vars_.end()) {
@@ -771,9 +775,10 @@ void DfGraphConvertor::GetCaseNodeInput(const CNodePtr node, const CNodePtr inpu
case_inputs.emplace_back(node->input(i));
}
auto bnode = input_node->input(2)->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(bnode);
for (size_t i = 1; i < bnode->inputs().size(); i++) {
auto branch_node = bnode->input(i)->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(branch_node);
for (size_t j = 2; j < branch_node->inputs().size(); j++) {
if (std::find(case_inputs.begin(), case_inputs.end(), branch_node->input(j)) == case_inputs.end()) {
case_inputs.emplace_back(branch_node->input(j));
@@ -1073,7 +1078,9 @@ void DfGraphConvertor::AddEdgeForLoad(const AnfNodePtr &node) {
}
auto manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
if (manager->node_users().find(node) == manager->node_users().end()) {
MS_LOG(EXCEPTION) << "Can't find node in nodes_users.";
}
auto &users = manager->node_users()[node];
std::shared_ptr<std::vector<AnfNodePtr>> src_node_list = std::make_shared<std::vector<AnfNodePtr>>();
std::shared_ptr<std::vector<AnfNodePtr>> dst_node_list = std::make_shared<std::vector<AnfNodePtr>>();
@@ -1101,6 +1108,7 @@ void DfGraphConvertor::AddEdgeForLoad(const AnfNodePtr &node) {
void DfGraphConvertor::FindDestOps(const AnfNodePtr &node, const std::shared_ptr<std::vector<AnfNodePtr>> &node_list,
bool top) {
MS_EXCEPTION_IF_NULL(node);
auto func_graph = node->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
auto mng = func_graph->manager();
@@ -1356,6 +1364,7 @@ void DfGraphConvertor::ProcessSubgraph(AnfNodePtr node, const std::vector<AnfNod
return;
}
auto graph_node = node->cast<CNodePtr>()->input(1)->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(graph_node);
FuncGraphPtr anf_graph = graph_node->value()->cast<FuncGraphPtr>();
DfGraphConvertor converter(anf_graph);
converter.use_inputs_ = true;
@@ -1449,13 +1458,16 @@ void DfGraphConvertor::ConvertMakeTuple(const CNodePtr node) {
}
void DfGraphConvertor::ConvertTopK(const CNodePtr node) {
MS_EXCEPTION_IF_NULL(node);
MS_LOG(INFO) << "Convert TopK second input's type from int64 to int32.";
auto value_ptr = node->input(2)->cast<ValueNodePtr>();
std::ostringstream ss;
ss << "op" << value_ptr.get();
op_draw_name_[value_ptr.get()] = ss.str();
compute_sout_ << ss.str() << "[label= \"" << value_ptr->value()->ToString() << "\" shape=ellipse]" << endl;
auto int64_value = value_ptr->value()->cast<Int64ImmPtr>()->value();
MS_EXCEPTION_IF_NULL(value_ptr);
auto input_value = value_ptr->value();
auto int64_value = GetValue<int64_t>(input_value);
OpAdapterPtr adpt = FindAdapter(value_ptr, training_);
auto op = adpt->generate(value_ptr);
adpt->setAttr(op, "value", static_cast<int32_t>(int64_value));


@@ -105,7 +105,7 @@ class Parameter(Tensor_):
>>> x = Tensor(np.ones((2, 1)), mindspore.float32)
>>> print(net(x))
[[2.]]
>>> _ = net.weight.set_data(Tensor(np.zeros((1, 2)), mindspore.float32))
>>> net.weight.set_data(Tensor(np.zeros((1, 2)), mindspore.float32))
>>> print(net(x))
[[0.]]
"""


@@ -1232,7 +1232,7 @@ class Tensor(Tensor_):
raise ValueError(msg)
class seed_context:
'''set and restore seed'''
"""Set and restore seed."""
def __init__(self, init):
self.init = init


@@ -493,6 +493,7 @@ bool MSANFModelParser::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim,
shape.push_back(attr_tensor.dims(i));
}
tensor::TensorPtr tensor_info = std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor_type], shape);
MS_EXCEPTION_IF_NULL(tensor_info);
const std::string &tensor_buf = attr_tensor.raw_data();
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c());
auto ret = memcpy_s(tensor_data_buf, tensor_info->data().nbytes(), tensor_buf.data(), tensor_buf.size());
@@ -570,6 +571,7 @@ bool MSANFModelParser::ObtainValueNodeInTensorForm(const std::string &value_node
shape.push_back(attr_tensor.dims(i));
}
tensor::TensorPtr tensor_info = std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor_type], shape);
MS_EXCEPTION_IF_NULL(tensor_info);
const std::string &tensor_buf = attr_tensor.raw_data();
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c());
auto ret = memcpy_s(tensor_data_buf, tensor_info->data().nbytes(), tensor_buf.data(), tensor_buf.size());
@@ -794,9 +796,11 @@ AnfNodePtr MSANFModelParser::BuildOperatorNode(const mind_ir::NodeProto &node_pr
if (node_type.compare(0, strlen(kDoSignaturePrimitivePrefix), kDoSignaturePrimitivePrefix) == 0) {
auto op_name = node_type.substr(strlen(kDoSignaturePrimitivePrefix));
prim = std::make_shared<prim::DoSignaturePrimitive>(op_name, std::make_shared<Primitive>(op_name));
MS_EXCEPTION_IF_NULL(prim);
prim->set_instance_name(op_name);
} else {
prim = std::make_shared<Primitive>(node_type);
MS_EXCEPTION_IF_NULL(prim);
prim->set_instance_name(node_type);
}
}
@@ -940,10 +944,15 @@ bool MSANFModelParser::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGra
for (int out_size = 0; out_size < importProto.output_size(); ++out_size) {
const mind_ir::ValueInfoProto &output_node = importProto.output(out_size);
const std::string &out_tuple = output_node.name();
if (anfnode_build_map_.find(out_tuple) == anfnode_build_map_.end()) {
MS_LOG(ERROR) << "Can't find out_tuple in anfnode_build_map_";
return false;
}
inputs.push_back(anfnode_build_map_[out_tuple]);
elem.push_back(anfnode_build_map_[out_tuple]->abstract());
}
auto maketuple_ptr = outputFuncGraph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(maketuple_ptr);
maketuple_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
inputs.clear();
inputs.push_back(NewValueNode(prim::kPrimReturn));
@@ -992,8 +1001,7 @@ bool MSANFModelParser::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph,
}
}
BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr);
return true;
return BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr);
}
bool MSANFModelParser::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto) {


@@ -422,7 +422,6 @@ def _get_stack_info(frame):
Returns:
str, the string of the stack information.
"""
sinfo = None
stack_prefix = 'Stack (most recent call last):\n'
sinfo = stack_prefix + "".join(traceback.format_stack(frame))
return sinfo


@@ -339,7 +339,7 @@ class Cell(Cell_):
def run_construct(self, cast_inputs, kwargs):
if self.enable_hook:
output = self._hook_construct(*cast_inputs, **kwargs)
output = self._hook_construct(*cast_inputs)
else:
output = self.construct(*cast_inputs, **kwargs)
return output
@@ -1209,7 +1209,7 @@ class Cell(Cell_):
self.add_flags(auto_parallel=True)
self._get_construct_inputs_number_and_name()
def _hook_construct(self, *inputs, **kwargs):
def _hook_construct(self, *inputs):
"""Hook construct method to replace original construct method when hook function enabled."""
inputs = self._backward_hook(*inputs)
inputs = self.construct(inputs)
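A framework-free sketch of the dispatch shown above, only to illustrate that the hook path now receives positional inputs while the plain path still forwards keyword arguments; CellSketch is an illustrative stand-in, not the MindSpore Cell class:

class CellSketch:
    def __init__(self, enable_hook=False):
        self.enable_hook = enable_hook

    def construct(self, *inputs, **kwargs):
        return ("construct", inputs, kwargs)

    def _backward_hook(self, *inputs):
        return inputs  # placeholder for the registered backward hook

    def _hook_construct(self, *inputs):
        inputs = self._backward_hook(*inputs)
        return self.construct(inputs)

    def run_construct(self, cast_inputs, kwargs):
        if self.enable_hook:
            output = self._hook_construct(*cast_inputs)  # kwargs no longer forwarded here
        else:
            output = self.construct(*cast_inputs, **kwargs)
        return output

print(CellSketch().run_construct((1, 2), {"scale": 3}))  # ('construct', (1, 2), {'scale': 3})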


@@ -81,9 +81,7 @@ def _column_or_1d(y):
Ravel column or 1d numpy array, otherwise raise an error.
"""
shape = np.shape(y)
if len(shape) == 1:
return np.ravel(y)
if len(shape) == 2 and shape[1] == 1:
if len(shape) == 1 or (len(shape) == 2 and shape[1] == 1):
return np.ravel(y)
raise ValueError("Bad input shape {0}.".format(shape))
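The merged condition keeps the original behavior: 1-D input and single-column 2-D input are flattened, anything else is rejected. A self-contained copy of that logic, runnable outside MindSpore (the real helper is the private _column_or_1d above):

import numpy as np

def column_or_1d(y):
    shape = np.shape(y)
    if len(shape) == 1 or (len(shape) == 2 and shape[1] == 1):
        return np.ravel(y)
    raise ValueError("Bad input shape {0}.".format(shape))

print(column_or_1d(np.array([[1], [2], [3]])))  # [1 2 3]
# column_or_1d(np.ones((2, 3)))  # would raise ValueError: Bad input shape (2, 3).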


@@ -113,13 +113,10 @@ def _check_param(momentum, frequency, lr, cls_name):
def caculate_device_shape(matrix_dim, channel, is_a):
ll = (0)
if is_a:
if channel // C0 == 0:
matrix_dim = (matrix_dim / channel) * C0
ll = (int(matrix_dim // C0), int(matrix_dim // C0), C0, C0), int(matrix_dim)
else:
ll = (int(matrix_dim // C0), int(matrix_dim // C0), C0, C0), int(matrix_dim)
ll = (int(matrix_dim // C0), int(matrix_dim // C0), C0, C0), int(matrix_dim)
return ll
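Since both branches now build the same tuple, the logic above condenses to the sketch below; C0 = 16 is an assumed cube-tile constant (its value is not shown in this diff), and the function name simply mirrors the diff:

C0 = 16  # assumption: tiling constant used by the real module

def caculate_device_shape(matrix_dim, channel, is_a):
    if is_a and channel // C0 == 0:
        matrix_dim = (matrix_dim / channel) * C0
    return (int(matrix_dim // C0), int(matrix_dim // C0), C0, C0), int(matrix_dim)

print(caculate_device_shape(64, 8, True))  # ((8, 8, 16, 16), 128)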


@@ -200,7 +200,7 @@ def rollaxis(x, axis, start=0):
axis = _check_axes_range(axis, ndim)
start = _check_start_normalize(start, ndim)
if start - axis >= 0 and start - axis <= 1:
if 0 <= start - axis <= 1:
return x
perm = F.make_range(0, ndim)
new_perm = None
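The chained comparison is only a more readable spelling of the two-sided bound it replaces; a quick check that the two forms agree:

for axis in range(4):
    for start in range(4):
        assert (0 <= start - axis <= 1) == (start - axis >= 0 and start - axis <= 1)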


@@ -43,6 +43,12 @@ class LossMonitor(Callback):
self._per_print_times = per_print_times
def step_end(self, run_context):
"""
Print training loss at the end of step.
Args:
run_context (RunContext): Context of the train running.
"""
cb_params = run_context.original_args()
loss = cb_params.net_outputs
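A hedged usage sketch of the callback: LossMonitor is handed to Model.train(), which then invokes step_end() with the RunContext described in the new docstring. The one-layer network and in-memory dataset below are illustrative stand-ins, not objects from this change set:

import numpy as np
import mindspore.nn as nn
import mindspore.dataset as ds
from mindspore import Model
from mindspore.train.callback import LossMonitor

net = nn.Dense(4, 1)
loss = nn.MSELoss()
opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)
data = ds.NumpySlicesDataset({"x": np.random.rand(32, 4).astype(np.float32),
                              "y": np.random.rand(32, 1).astype(np.float32)},
                             shuffle=False).batch(8)
model = Model(net, loss_fn=loss, optimizer=opt)
# prints the loss after every step because per_print_times=1
model.train(1, data, callbacks=[LossMonitor(per_print_times=1)], dataset_sink_mode=False)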


@@ -51,7 +51,6 @@ class LearningRateScheduler(Callback):
>>> dataset = create_custom_dataset("custom_dataset_path")
>>> model.train(1, dataset, callbacks=[LearningRateScheduler(learning_rate_function)],
... dataset_sink_mode=False)
"""
def __init__(self, learning_rate_function):
@@ -59,6 +58,12 @@ class LearningRateScheduler(Callback):
self.learning_rate_function = learning_rate_function
def step_end(self, run_context):
"""
Change the learning_rate at the end of step.
Args:
run_context (RunContext): Context of the train running.
"""
cb_params = run_context.original_args()
arr_lr = cb_params.optimizer.learning_rate.asnumpy()
lr = float(np.array2string(arr_lr))
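For reference, a sketch of the learning_rate_function referenced in the docstring example above; the assumed contract (not fully shown in this hunk) is that the callback passes the current learning rate and step number and applies the returned value:

def learning_rate_function(lr, cur_step_num):
    if cur_step_num % 1000 == 0:
        lr = lr * 0.1  # decay by 10x every 1000 steps
    return lr

print(learning_rate_function(0.1, 1000))  # ~0.01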


@@ -38,9 +38,21 @@ class TimeMonitor(Callback):
self.epoch_time = time.time()
def epoch_begin(self, run_context):
"""
Record the time at the beginning of the epoch.
Args:
run_context (RunContext): Context of the process running.
"""
self.epoch_time = time.time()
def epoch_end(self, run_context):
"""
Print the time cost of the epoch at the end of the epoch.
Args:
run_context (RunContext): Context of the process running.
"""
epoch_seconds = (time.time() - self.epoch_time) * 1000
step_size = self.data_size
cb_params = run_context.original_args()
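The two values computed above feed the per-step figure that epoch_end() reports; a standalone sketch of that arithmetic (the exact log format is not shown in this diff):

import time

epoch_time = time.time()
# ... one training epoch would run here ...
epoch_seconds = (time.time() - epoch_time) * 1000  # epoch wall time in milliseconds
step_size = 1875                                    # e.g. steps per epoch (data_size)
per_step_mseconds = epoch_seconds / step_size
print(f"epoch time: {epoch_seconds:.3f} ms, per step time: {per_step_mseconds:.3f} ms")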


@@ -824,7 +824,6 @@ def _save_mindir(net, file_name, *inputs, **kwargs):
if os.path.exists(data_path):
shutil.rmtree(data_path)
os.makedirs(data_path, exist_ok=True)
os.chmod(data_path, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)
index = 0
graphproto = graph_proto()
data_size = 0
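The permission mask added above restricts the external data directory to its owner; the three stat flags combine to mode 0o700:

import stat

assert stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR == 0o700  # owner read/write/execute only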


@@ -20,7 +20,7 @@ import mindspore.common.dtype as mstype
from mindspore import context
class ConvertNetUtils():
class ConvertNetUtils:
"""
Convert net to thor layer net
"""
@@ -29,7 +29,6 @@ class ConvertNetUtils():
nn.Embedding: ConvertNetUtils._convert_embedding,
nn.Conv2d: ConvertNetUtils._convert_conv2d}
@staticmethod
def _convert_dense(subcell):
"""
@@ -64,7 +63,6 @@ class ConvertNetUtils():
new_subcell.bias = subcell.bias
return new_subcell
@staticmethod
def _convert_embedding(subcell):
"""
@@ -76,7 +74,6 @@ class ConvertNetUtils():
new_subcell.embedding_table = subcell.embedding_table
return new_subcell
@staticmethod
def _convert_conv2d(subcell):
"""
@@ -95,7 +92,6 @@ class ConvertNetUtils():
has_bias=has_bias, weight_init=weight)
return new_subcell
def _convert_to_thor_net(self, net):
"""
Convert net to thor net
@@ -114,9 +110,6 @@ class ConvertNetUtils():
elif isinstance(subcell, (nn.Embedding, nn.Dense, nn.Conv2d)):
prefix = subcell.param_prefix
new_subcell = self._convert_method_map[type(subcell)](subcell)
print("subcell name: ", name, "prefix is", prefix, flush=True)
if isinstance(new_subcell, (nn.DenseThor, nn.EmbeddingThor, nn.Conv2dThor)):
print("convert to thor layer success.", flush=True)
new_subcell.update_parameters_name(prefix + '.')
net.insert_child_to_cell(name, new_subcell)
change = True
@@ -124,10 +117,8 @@ class ConvertNetUtils():
self._convert_to_thor_net(subcell)
if isinstance(net, nn.SequentialCell) and change:
print("is nn.SequentialCell and change")
net.cell_list = list(net.cells())
def convert_to_thor_net(self, net):
"""
This interface is used to convert a network to thor layer network, in order to calculate and store the
@@ -152,7 +143,7 @@ class ConvertNetUtils():
net.update_cell_type("second-order")
class ConvertModelUtils():
class ConvertModelUtils:
"""
Convert model to thor model.
"""
@@ -203,7 +194,7 @@ class ConvertModelUtils():
... frequency=100)
>>> model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_manager, metrics={"acc"},
... amp_level="O2", keep_batchnorm_fp32=False)
>>> model = ConvertModelUtils().convert_to_thor_model(model=model, network=net, loss_fn=loss, optimizer=opt,
>>> model = ConvertModelUtils.convert_to_thor_model(model=model, network=net, loss_fn=loss, optimizer=opt,
... metrics={'acc'}, amp_level="O2",
... loss_scale_manager=loss_manager,
... keep_batchnorm_fp32=False)