diff --git a/.gitignore b/.gitignore
index 77ff222a1a..057169ec42 100644
--- a/.gitignore
+++ b/.gitignore
@@ -26,6 +26,7 @@ cmake-build-debug
 *_pb2.py
 *.pb.h
 *.pb.cc
+*.pb
 
 # Object files
 *.o
diff --git a/mindspore/ccsrc/ir/primitive.cc b/mindspore/ccsrc/ir/primitive.cc
index 59497affd5..6ec27c2567 100644
--- a/mindspore/ccsrc/ir/primitive.cc
+++ b/mindspore/ccsrc/ir/primitive.cc
@@ -86,7 +86,7 @@ void PrimitivePy::AddPyAttr(const py::str &name, const py::object &obj) {
   }
   bool converted = parse::ConvertData(obj, &converted_ret);
   if (!converted) {
-    MS_LOG(EXCEPTION) << "Attribute convert error with type:" << std::string(py::str(obj));
+    MS_LOG(EXCEPTION) << "Attribute convert error with type: " << std::string(py::str(obj));
   }
   (void)this->AddAttr(attr_name, converted_ret);
 }
diff --git a/mindspore/ccsrc/ir/tensor.cc b/mindspore/ccsrc/ir/tensor.cc
index e5212e922d..4b02fdf2a7 100644
--- a/mindspore/ccsrc/ir/tensor.cc
+++ b/mindspore/ccsrc/ir/tensor.cc
@@ -345,14 +345,14 @@ abstract::AbstractBasePtr Tensor::ToAbstract() {
 
 std::string Tensor::GetShapeAndDataTypeInfo() const {
   std::ostringstream buf;
-  buf << "Tensor \nshape:[" << shape() << "]" << this->Dtype()->ToString();
+  buf << "Tensor shape:[" << shape() << "]" << this->Dtype()->ToString();
   return buf.str();
 }
 
 std::string Tensor::ToString() const {
   const int small_tensor_size = 30;
   std::ostringstream buf;
-  buf << "Tensor \nshape:[" << shape() << "]" << this->Dtype()->ToString();
+  buf << "Tensor shape:[" << shape() << "]" << this->Dtype()->ToString();
   // only print small tensor
   if (DataSize() < small_tensor_size) {
     buf << "val:" << std::string(py::str(data()));
diff --git a/mindspore/ccsrc/pipeline/parse/parse.cc b/mindspore/ccsrc/pipeline/parse/parse.cc
index 6d5c28c98c..66908240cb 100644
--- a/mindspore/ccsrc/pipeline/parse/parse.cc
+++ b/mindspore/ccsrc/pipeline/parse/parse.cc
@@ -234,7 +234,11 @@ FunctionBlockPtr Parser::ParseFunction(const py::object &node, const FunctionBlo
     current_fg->debug_info()->set_deco_location(GetLocation(deco_list));
   }
 
-  bool set_flag = ast_->UpdateFuncGraphFlags(current_fg);
+  bool set_flag = UpdateFuncGraphFlags(ast_->function(), current_fg);
+  if (ast_->obj() != ast_->function()) {
+    set_flag = set_flag && UpdateFuncGraphFlags(ast_->obj(), current_fg);
+  }
+
   if (!set_flag) {
     MS_LOG(ERROR) << "Set flags failed";
     return nullptr;
@@ -1436,17 +1440,17 @@ bool ParseAst::IsClassMember(const py::object &node) {
   return ret.cast<bool>();
 }
 
-bool ParseAst::UpdateFuncGraphFlags(const FuncGraphPtr &func_graph) {
+bool UpdateFuncGraphFlags(py::object obj, const FuncGraphPtr &func_graph) {
   if (func_graph == nullptr) {
     MS_LOG(ERROR) << "FuncGraph is null";
     return false;
   }
 
-  if (!py::hasattr(obj_, PYTHON_EXTERN_MINDSPORE_FLAG)) {
+  if (!py::hasattr(obj, PYTHON_EXTERN_MINDSPORE_FLAG)) {
     MS_LOG(DEBUG) << "No flags";
     return true;
   }
-  py::dict flags = python_adapter::GetPyObjAttr(obj_, PYTHON_EXTERN_MINDSPORE_FLAG);
+  py::dict flags = python_adapter::GetPyObjAttr(obj, PYTHON_EXTERN_MINDSPORE_FLAG);
   for (auto &item : flags) {
     if (!py::isinstance<py::str>(item.first)) {
       MS_LOG(ERROR) << "Type error in flags dict convert";
@@ -1466,7 +1470,6 @@ bool ParseAst::UpdateFuncGraphFlags(const FuncGraphPtr &func_graph) {
       return false;
     }
   }
-
   return true;
 }
 
diff --git a/mindspore/ccsrc/pipeline/parse/parse.h b/mindspore/ccsrc/pipeline/parse/parse.h
index 0a56ccaed9..19c503c6d0 100644
--- a/mindspore/ccsrc/pipeline/parse/parse.h
+++ b/mindspore/ccsrc/pipeline/parse/parse.h
@@ -327,9 +327,6 @@ class ParseAst {
 
   bool IsClassMember(const py::object &node);
 
-  // update the graph flags
-  bool UpdateFuncGraphFlags(const FuncGraphPtr &func_graph);
-
  private:
   // save obj,eg: class instance or function
   py::object obj_;
@@ -350,6 +347,9 @@ class ParseAst {
   int function_line_offset_;
 };
 
+// update the graph flags
+bool UpdateFuncGraphFlags(py::object obj, const FuncGraphPtr &func_graph);
+
 AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr &param);
 
 }  // namespace parse
diff --git a/mindspore/nn/layer/basic.py b/mindspore/nn/layer/basic.py
index b1d5af48c9..d0e6904ec5 100644
--- a/mindspore/nn/layer/basic.py
+++ b/mindspore/nn/layer/basic.py
@@ -284,7 +284,6 @@ class ClipByNorm(Cell):
         self.reduce_sum = P.ReduceSum(keep_dims=True)
         self.select_ = P.Select()
         self.greater_ = P.Greater()
-        self.axis = ()
         self.cast = P.Cast()
         self.zero = Tensor(np.array([0.0]).astype(np.float32))
         self.sqrt = P.Sqrt()
@@ -299,7 +298,7 @@ class ClipByNorm(Cell):
     def construct(self, x, clip_norm):
         """add ms_function decorator for pynative mode"""
         mul_x = F.square(x)
-        l2sum = self.cast(self.reduce_sum(mul_x, self.axis), mstype.float32)
+        l2sum = self.cast(self.reduce_sum(mul_x), mstype.float32)
         cond = self.greater_(l2sum, self.zero)
         ones_ = self.fill(self.dtype(cond), self.shape(cond), 1.0)
diff --git a/mindspore/nn/wrap/loss_scale.py b/mindspore/nn/wrap/loss_scale.py
index 2bae6bbc5c..a9aa4d781b 100644
--- a/mindspore/nn/wrap/loss_scale.py
+++ b/mindspore/nn/wrap/loss_scale.py
@@ -234,8 +234,8 @@ class TrainOneStepWithLossScaleCell(Cell):
         if scale_update_cell:
             self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
                                         name="loss_scale")
-        self.add_flags(has_effect=True)
 
+    @C.add_flags(has_effect=True)
     def construct(self, data, label, sens=None):
         weights = self.weights
         loss = self.network(data, label)
diff --git a/mindspore/ops/composite/base.py b/mindspore/ops/composite/base.py
index e283867684..b0f16d82bf 100644
--- a/mindspore/ops/composite/base.py
+++ b/mindspore/ops/composite/base.py
@@ -30,16 +30,16 @@ from ...common.parameter import Parameter
 __all__ = [EnvInstance_, TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_]
 
 
-def add_flags(fn, **flags):
+def add_flags(fn=None, **flags):
     """
-    An interface to add flag for a function.
+    An decorator to add flag for a function.
 
     Note:
         Only supports bool value.
 
     Args:
-        fn (Function): Function or cell to add flag.
-        flags (bool): Flags use kwargs.
+        fn (Function): Function or cell to add flag. Default: None.
+        flags (dict): Flags use kwargs. Default: None.
 
     Returns:
         Function, the fn added flags.
@@ -47,11 +47,17 @@ def add_flags(fn, **flags):
     Examples:
         >>> add_flags(net, predit=True)
     """
-    # need set the attr and access on c++
-    if not hasattr(fn, "_mindspore_flags"):
-        fn._mindspore_flags = {}
-    fn._mindspore_flags.update({**flags})
-    return fn
+    def deco(fn):
+        # need set the attr and access on c++
+        if not hasattr(fn, "_mindspore_flags"):
+            fn._mindspore_flags = {}
+
+        fn._mindspore_flags.update({**flags})
+        return fn
+    ret = deco
+    if fn is not None:
+        ret = deco(fn)
+    return ret
 
 
 def core(fn=None, **flags):
diff --git a/model_zoo/Transformer/src/transformer_for_train.py b/model_zoo/Transformer/src/transformer_for_train.py
index 758ac65ab5..76237bee96 100644
--- a/model_zoo/Transformer/src/transformer_for_train.py
+++ b/model_zoo/Transformer/src/transformer_for_train.py
@@ -277,8 +277,8 @@ class TransformerTrainOneStepWithLossScaleCell(nn.Cell):
         if scale_update_cell:
             self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
                                         name="loss_scale")
-        self.add_flags(has_effect=True)
 
+    @C.add_flags(has_effect=True)
     def construct(self,
                   source_eos_ids,
                   source_eos_mask,
diff --git a/model_zoo/bert/src/bert_for_pre_training.py b/model_zoo/bert/src/bert_for_pre_training.py
index 5e014f02ba..802391ee86 100644
--- a/model_zoo/bert/src/bert_for_pre_training.py
+++ b/model_zoo/bert/src/bert_for_pre_training.py
@@ -132,9 +132,9 @@ class GetNextSentenceOutput(nn.Cell):
     def __init__(self, config):
         super(GetNextSentenceOutput, self).__init__()
         self.log_softmax = _selected_ops.LogSoftmax()
-        self.weight_init = TruncatedNormal(config.initializer_range)
+        weight_init = TruncatedNormal(config.initializer_range)
         self.dense = nn.Dense(config.hidden_size, 2,
-                              weight_init=self.weight_init, has_bias=True).to_float(config.compute_type)
+                              weight_init=weight_init, has_bias=True).to_float(config.compute_type)
         self.dtype = config.dtype
         self.cast = P.Cast()
 
@@ -321,7 +321,6 @@ class BertTrainOneStepCell(nn.Cell):
         if self.reducer_flag:
             # apply grad reducer on grads
             grads = self.grad_reducer(grads)
-
         succ = self.optimizer(grads)
         return F.depend(loss, succ)
 
@@ -380,8 +379,8 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
         if scale_update_cell:
             self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
                                         name="loss_scale")
-        self.add_flags(has_effect=True)
 
+    @C.add_flags(has_effect=True)
     def construct(self,
                   input_ids,
                   input_mask,
diff --git a/model_zoo/deeplabv3/src/deeplabv3.py b/model_zoo/deeplabv3/src/deeplabv3.py
index 906a207302..03bb03ad14 100644
--- a/model_zoo/deeplabv3/src/deeplabv3.py
+++ b/model_zoo/deeplabv3/src/deeplabv3.py
@@ -17,6 +17,7 @@ import numpy as np
 
 import mindspore.nn as nn
 from mindspore.ops import operations as P
+from mindspore.ops.composite import add_flags
 from .backbone.resnet_deeplab import _conv_bn_relu, resnet50_dl, _deep_conv_bn_relu, \
     DepthwiseConv2dNative, SpaceToBatch, BatchToSpace
@@ -121,6 +122,7 @@ class ASPP(nn.Cell):
         self.feature_shape = feature_shape
         self.concat = P.Concat(axis=1)
 
+    @add_flags(loop_can_unroll=True)
     def construct(self, x, scale_index=0):
         aspp0 = self.aspp0(x)
         aspp1 = self.global_poolings[scale_index](x)
@@ -276,7 +278,7 @@ class SingleDeepLabV3(nn.Cell):
                          atrous_rates=atrous_rates,
                          output_stride=output_stride,
                          fine_tune_batch_norm=fine_tune_batch_norm)
-        self.aspp.add_flags(loop_can_unroll=True)
+
         atrous_rates_len = 0
         if atrous_rates is not None:
             atrous_rates_len = len(atrous_rates)
diff --git a/tests/st/networks/models/bert/src/bert_for_pre_training.py b/tests/st/networks/models/bert/src/bert_for_pre_training.py
index 976f1a3c43..7c557a49c9 100644
--- a/tests/st/networks/models/bert/src/bert_for_pre_training.py
+++ b/tests/st/networks/models/bert/src/bert_for_pre_training.py
@@ -379,8 +379,8 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
         if scale_update_cell:
             self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
                                         name="loss_scale")
-        self.add_flags(has_effect=True)
 
+    @C.add_flags(has_effect=True)
     def construct(self,
                   input_ids,
                   input_mask,
diff --git a/tests/ut/python/keep_order/test_keep_order.py b/tests/ut/python/keep_order/test_keep_order.py
index 1cf2b8e19a..fa0df6dd5d 100644
--- a/tests/ut/python/keep_order/test_keep_order.py
+++ b/tests/ut/python/keep_order/test_keep_order.py
@@ -133,8 +133,8 @@ def test_keep_order_io_effect_exception_return_dtype():
             self.dtype = P.DType()
            self.sub = P.Sub()
             self.neg = P.Neg()
-            self.add_flags(has_effect=True)
 
+        @C.add_flags(has_effect=True)
         def construct(self, x):
             init = self.alloc_status()
             self.clear_status(init)
diff --git a/tests/ut/python/ops/test_math_ops.py b/tests/ut/python/ops/test_math_ops.py
index 09f113204e..21a3a3c9e1 100755
--- a/tests/ut/python/ops/test_math_ops.py
+++ b/tests/ut/python/ops/test_math_ops.py
@@ -268,8 +268,8 @@ class NpuFloatNet(nn.Cell):
         self.reduce_sum = P.ReduceSum(keep_dims=True)
         self.sub = P.Sub()
         self.neg = P.Neg()
-        self.add_flags(has_effect=True)
 
+    @C.add_flags(has_effect=True)
     def construct(self, x):
         init = self.alloc_status()
         self.clear_status(init)
diff --git a/tests/ut/python/pynative_mode/ge/model/test_lenet_model.py b/tests/ut/python/pynative_mode/ge/model/test_lenet_model.py
index f684a175c8..23999c398e 100644
--- a/tests/ut/python/pynative_mode/ge/model/test_lenet_model.py
+++ b/tests/ut/python/pynative_mode/ge/model/test_lenet_model.py
@@ -14,13 +14,13 @@
 # ============================================================================
 """ test_lenet_model """
 import numpy as np
+import pytest
 
 import mindspore.nn as nn
 from mindspore.common.tensor import Tensor
 from mindspore.nn import WithGradCell, WithLossCell
 from mindspore.nn.optim import Momentum
 from mindspore.ops import operations as P
-from ....ut_filter import non_graph_engine
 
 
 class LeNet5(nn.Cell):
@@ -47,7 +47,7 @@ class LeNet5(nn.Cell):
         return x
 
 
-@non_graph_engine
+@pytest.mark.skip(reason="need ge backend")
 def test_lenet_pynative_train_net():
     """ test_lenet_pynative_train_net """
     data = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32) * 0.01)
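Note on the add_flags rework above: after this patch, mindspore.ops.composite.add_flags accepts being used either as a direct call or as a decorator factory, which is what the @C.add_flags(has_effect=True) and @add_flags(loop_can_unroll=True) call sites rely on. A minimal usage sketch under that assumption (the construct and run functions are illustrative only; the predit flag name is taken from the docstring example):

    from mindspore.ops import composite as C

    # Decorator form: add_flags(**flags) returns a decorator that tags the function
    # with a _mindspore_flags dict when it is defined.
    @C.add_flags(has_effect=True)
    def construct(x):
        return x

    # Direct-call form: passing fn directly keeps the pre-patch behaviour.
    def run(x):
        return x
    run = C.add_flags(run, predit=True)

    assert construct._mindspore_flags == {"has_effect": True}
    assert run._mindspore_flags == {"predit": True}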