add flags on function

This commit is contained in:
Wei Luning 2020-06-22 23:21:01 +08:00
parent 8870956954
commit 8f56528f8c
15 changed files with 44 additions and 34 deletions

1
.gitignore vendored
View File

@ -26,6 +26,7 @@ cmake-build-debug
*_pb2.py *_pb2.py
*.pb.h *.pb.h
*.pb.cc *.pb.cc
*.pb
# Object files # Object files
*.o *.o

View File

@ -86,7 +86,7 @@ void PrimitivePy::AddPyAttr(const py::str &name, const py::object &obj) {
} }
bool converted = parse::ConvertData(obj, &converted_ret); bool converted = parse::ConvertData(obj, &converted_ret);
if (!converted) { if (!converted) {
MS_LOG(EXCEPTION) << "Attribute convert error with type:" << std::string(py::str(obj)); MS_LOG(EXCEPTION) << "Attribute convert error with type: " << std::string(py::str(obj));
} }
(void)this->AddAttr(attr_name, converted_ret); (void)this->AddAttr(attr_name, converted_ret);
} }

View File

@ -345,14 +345,14 @@ abstract::AbstractBasePtr Tensor::ToAbstract() {
std::string Tensor::GetShapeAndDataTypeInfo() const { std::string Tensor::GetShapeAndDataTypeInfo() const {
std::ostringstream buf; std::ostringstream buf;
buf << "Tensor \nshape:[" << shape() << "]" << this->Dtype()->ToString(); buf << "Tensor shape:[" << shape() << "]" << this->Dtype()->ToString();
return buf.str(); return buf.str();
} }
std::string Tensor::ToString() const { std::string Tensor::ToString() const {
const int small_tensor_size = 30; const int small_tensor_size = 30;
std::ostringstream buf; std::ostringstream buf;
buf << "Tensor \nshape:[" << shape() << "]" << this->Dtype()->ToString(); buf << "Tensor shape:[" << shape() << "]" << this->Dtype()->ToString();
// only print small tensor // only print small tensor
if (DataSize() < small_tensor_size) { if (DataSize() < small_tensor_size) {
buf << "val:" << std::string(py::str(data())); buf << "val:" << std::string(py::str(data()));

View File

@ -234,7 +234,11 @@ FunctionBlockPtr Parser::ParseFunction(const py::object &node, const FunctionBlo
current_fg->debug_info()->set_deco_location(GetLocation(deco_list)); current_fg->debug_info()->set_deco_location(GetLocation(deco_list));
} }
bool set_flag = ast_->UpdateFuncGraphFlags(current_fg); bool set_flag = UpdateFuncGraphFlags(ast_->function(), current_fg);
if (ast_->obj() != ast_->function()) {
set_flag = set_flag && UpdateFuncGraphFlags(ast_->obj(), current_fg);
}
if (!set_flag) { if (!set_flag) {
MS_LOG(ERROR) << "Set flags failed"; MS_LOG(ERROR) << "Set flags failed";
return nullptr; return nullptr;
@ -1436,17 +1440,17 @@ bool ParseAst::IsClassMember(const py::object &node) {
return ret.cast<bool>(); return ret.cast<bool>();
} }
bool ParseAst::UpdateFuncGraphFlags(const FuncGraphPtr &func_graph) { bool UpdateFuncGraphFlags(py::object obj, const FuncGraphPtr &func_graph) {
if (func_graph == nullptr) { if (func_graph == nullptr) {
MS_LOG(ERROR) << "FuncGraph is null"; MS_LOG(ERROR) << "FuncGraph is null";
return false; return false;
} }
if (!py::hasattr(obj_, PYTHON_EXTERN_MINDSPORE_FLAG)) { if (!py::hasattr(obj, PYTHON_EXTERN_MINDSPORE_FLAG)) {
MS_LOG(DEBUG) << "No flags"; MS_LOG(DEBUG) << "No flags";
return true; return true;
} }
py::dict flags = python_adapter::GetPyObjAttr(obj_, PYTHON_EXTERN_MINDSPORE_FLAG); py::dict flags = python_adapter::GetPyObjAttr(obj, PYTHON_EXTERN_MINDSPORE_FLAG);
for (auto &item : flags) { for (auto &item : flags) {
if (!py::isinstance<py::str>(item.first)) { if (!py::isinstance<py::str>(item.first)) {
MS_LOG(ERROR) << "Type error in flags dict convert"; MS_LOG(ERROR) << "Type error in flags dict convert";
@ -1466,7 +1470,6 @@ bool ParseAst::UpdateFuncGraphFlags(const FuncGraphPtr &func_graph) {
return false; return false;
} }
} }
return true; return true;
} }

View File

@ -327,9 +327,6 @@ class ParseAst {
bool IsClassMember(const py::object &node); bool IsClassMember(const py::object &node);
// update the graph flags
bool UpdateFuncGraphFlags(const FuncGraphPtr &func_graph);
private: private:
// save obj,eg: class instance or function // save obj,eg: class instance or function
py::object obj_; py::object obj_;
@ -350,6 +347,9 @@ class ParseAst {
int function_line_offset_; int function_line_offset_;
}; };
// update the graph flags
bool UpdateFuncGraphFlags(py::object obj, const FuncGraphPtr &func_graph);
AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr &param); AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr &param);
} // namespace parse } // namespace parse

View File

@ -284,7 +284,6 @@ class ClipByNorm(Cell):
self.reduce_sum = P.ReduceSum(keep_dims=True) self.reduce_sum = P.ReduceSum(keep_dims=True)
self.select_ = P.Select() self.select_ = P.Select()
self.greater_ = P.Greater() self.greater_ = P.Greater()
self.axis = ()
self.cast = P.Cast() self.cast = P.Cast()
self.zero = Tensor(np.array([0.0]).astype(np.float32)) self.zero = Tensor(np.array([0.0]).astype(np.float32))
self.sqrt = P.Sqrt() self.sqrt = P.Sqrt()
@ -299,7 +298,7 @@ class ClipByNorm(Cell):
def construct(self, x, clip_norm): def construct(self, x, clip_norm):
"""add ms_function decorator for pynative mode""" """add ms_function decorator for pynative mode"""
mul_x = F.square(x) mul_x = F.square(x)
l2sum = self.cast(self.reduce_sum(mul_x, self.axis), mstype.float32) l2sum = self.cast(self.reduce_sum(mul_x), mstype.float32)
cond = self.greater_(l2sum, self.zero) cond = self.greater_(l2sum, self.zero)
ones_ = self.fill(self.dtype(cond), self.shape(cond), 1.0) ones_ = self.fill(self.dtype(cond), self.shape(cond), 1.0)

View File

@ -234,8 +234,8 @@ class TrainOneStepWithLossScaleCell(Cell):
if scale_update_cell: if scale_update_cell:
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32), self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
name="loss_scale") name="loss_scale")
self.add_flags(has_effect=True)
@C.add_flags(has_effect=True)
def construct(self, data, label, sens=None): def construct(self, data, label, sens=None):
weights = self.weights weights = self.weights
loss = self.network(data, label) loss = self.network(data, label)

View File

@ -30,16 +30,16 @@ from ...common.parameter import Parameter
__all__ = [EnvInstance_, TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_] __all__ = [EnvInstance_, TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_]
def add_flags(fn, **flags): def add_flags(fn=None, **flags):
""" """
An interface to add flag for a function. An decorator to add flag for a function.
Note: Note:
Only supports bool value. Only supports bool value.
Args: Args:
fn (Function): Function or cell to add flag. fn (Function): Function or cell to add flag. Default: None.
flags (bool): Flags use kwargs. flags (dict): Flags use kwargs. Default: None.
Returns: Returns:
Function, the fn added flags. Function, the fn added flags.
@ -47,11 +47,17 @@ def add_flags(fn, **flags):
Examples: Examples:
>>> add_flags(net, predit=True) >>> add_flags(net, predit=True)
""" """
# need set the attr and access on c++ def deco(fn):
if not hasattr(fn, "_mindspore_flags"): # need set the attr and access on c++
fn._mindspore_flags = {} if not hasattr(fn, "_mindspore_flags"):
fn._mindspore_flags.update({**flags}) fn._mindspore_flags = {}
return fn
fn._mindspore_flags.update({**flags})
return fn
ret = deco
if fn is not None:
ret = deco(fn)
return ret
def core(fn=None, **flags): def core(fn=None, **flags):

View File

@ -277,8 +277,8 @@ class TransformerTrainOneStepWithLossScaleCell(nn.Cell):
if scale_update_cell: if scale_update_cell:
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32), self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
name="loss_scale") name="loss_scale")
self.add_flags(has_effect=True)
@C.add_flags(has_effect=True)
def construct(self, def construct(self,
source_eos_ids, source_eos_ids,
source_eos_mask, source_eos_mask,

View File

@ -132,9 +132,9 @@ class GetNextSentenceOutput(nn.Cell):
def __init__(self, config): def __init__(self, config):
super(GetNextSentenceOutput, self).__init__() super(GetNextSentenceOutput, self).__init__()
self.log_softmax = _selected_ops.LogSoftmax() self.log_softmax = _selected_ops.LogSoftmax()
self.weight_init = TruncatedNormal(config.initializer_range) weight_init = TruncatedNormal(config.initializer_range)
self.dense = nn.Dense(config.hidden_size, 2, self.dense = nn.Dense(config.hidden_size, 2,
weight_init=self.weight_init, has_bias=True).to_float(config.compute_type) weight_init=weight_init, has_bias=True).to_float(config.compute_type)
self.dtype = config.dtype self.dtype = config.dtype
self.cast = P.Cast() self.cast = P.Cast()
@ -321,7 +321,6 @@ class BertTrainOneStepCell(nn.Cell):
if self.reducer_flag: if self.reducer_flag:
# apply grad reducer on grads # apply grad reducer on grads
grads = self.grad_reducer(grads) grads = self.grad_reducer(grads)
succ = self.optimizer(grads) succ = self.optimizer(grads)
return F.depend(loss, succ) return F.depend(loss, succ)
@ -380,8 +379,8 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
if scale_update_cell: if scale_update_cell:
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32), self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
name="loss_scale") name="loss_scale")
self.add_flags(has_effect=True)
@C.add_flags(has_effect=True)
def construct(self, def construct(self,
input_ids, input_ids,
input_mask, input_mask,

View File

@ -17,6 +17,7 @@
import numpy as np import numpy as np
import mindspore.nn as nn import mindspore.nn as nn
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops.composite import add_flags
from .backbone.resnet_deeplab import _conv_bn_relu, resnet50_dl, _deep_conv_bn_relu, \ from .backbone.resnet_deeplab import _conv_bn_relu, resnet50_dl, _deep_conv_bn_relu, \
DepthwiseConv2dNative, SpaceToBatch, BatchToSpace DepthwiseConv2dNative, SpaceToBatch, BatchToSpace
@ -121,6 +122,7 @@ class ASPP(nn.Cell):
self.feature_shape = feature_shape self.feature_shape = feature_shape
self.concat = P.Concat(axis=1) self.concat = P.Concat(axis=1)
@add_flags(loop_can_unroll=True)
def construct(self, x, scale_index=0): def construct(self, x, scale_index=0):
aspp0 = self.aspp0(x) aspp0 = self.aspp0(x)
aspp1 = self.global_poolings[scale_index](x) aspp1 = self.global_poolings[scale_index](x)
@ -276,7 +278,7 @@ class SingleDeepLabV3(nn.Cell):
atrous_rates=atrous_rates, atrous_rates=atrous_rates,
output_stride=output_stride, output_stride=output_stride,
fine_tune_batch_norm=fine_tune_batch_norm) fine_tune_batch_norm=fine_tune_batch_norm)
self.aspp.add_flags(loop_can_unroll=True)
atrous_rates_len = 0 atrous_rates_len = 0
if atrous_rates is not None: if atrous_rates is not None:
atrous_rates_len = len(atrous_rates) atrous_rates_len = len(atrous_rates)

View File

@ -379,8 +379,8 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
if scale_update_cell: if scale_update_cell:
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32), self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
name="loss_scale") name="loss_scale")
self.add_flags(has_effect=True)
@C.add_flags(has_effect=True)
def construct(self, def construct(self,
input_ids, input_ids,
input_mask, input_mask,

View File

@ -133,8 +133,8 @@ def test_keep_order_io_effect_exception_return_dtype():
self.dtype = P.DType() self.dtype = P.DType()
self.sub = P.Sub() self.sub = P.Sub()
self.neg = P.Neg() self.neg = P.Neg()
self.add_flags(has_effect=True)
@C.add_flags(has_effect=True)
def construct(self, x): def construct(self, x):
init = self.alloc_status() init = self.alloc_status()
self.clear_status(init) self.clear_status(init)

View File

@ -268,8 +268,8 @@ class NpuFloatNet(nn.Cell):
self.reduce_sum = P.ReduceSum(keep_dims=True) self.reduce_sum = P.ReduceSum(keep_dims=True)
self.sub = P.Sub() self.sub = P.Sub()
self.neg = P.Neg() self.neg = P.Neg()
self.add_flags(has_effect=True)
@C.add_flags(has_effect=True)
def construct(self, x): def construct(self, x):
init = self.alloc_status() init = self.alloc_status()
self.clear_status(init) self.clear_status(init)

View File

@ -14,13 +14,13 @@
# ============================================================================ # ============================================================================
""" test_lenet_model """ """ test_lenet_model """
import numpy as np import numpy as np
import pytest
import mindspore.nn as nn import mindspore.nn as nn
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
from mindspore.nn import WithGradCell, WithLossCell from mindspore.nn import WithGradCell, WithLossCell
from mindspore.nn.optim import Momentum from mindspore.nn.optim import Momentum
from mindspore.ops import operations as P from mindspore.ops import operations as P
from ....ut_filter import non_graph_engine
class LeNet5(nn.Cell): class LeNet5(nn.Cell):
@ -47,7 +47,7 @@ class LeNet5(nn.Cell):
return x return x
@non_graph_engine @pytest.mark.skip(reason="need ge backend")
def test_lenet_pynative_train_net(): def test_lenet_pynative_train_net():
""" test_lenet_pynative_train_net """ """ test_lenet_pynative_train_net """
data = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32) * 0.01) data = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32) * 0.01)