add flags on function

This commit is contained in:
Wei Luning 2020-06-22 23:21:01 +08:00
parent 8870956954
commit 8f56528f8c
15 changed files with 44 additions and 34 deletions

1
.gitignore vendored
View File

@ -26,6 +26,7 @@ cmake-build-debug
*_pb2.py
*.pb.h
*.pb.cc
*.pb
# Object files
*.o

View File

@ -345,14 +345,14 @@ abstract::AbstractBasePtr Tensor::ToAbstract() {
std::string Tensor::GetShapeAndDataTypeInfo() const {
std::ostringstream buf;
buf << "Tensor \nshape:[" << shape() << "]" << this->Dtype()->ToString();
buf << "Tensor shape:[" << shape() << "]" << this->Dtype()->ToString();
return buf.str();
}
std::string Tensor::ToString() const {
const int small_tensor_size = 30;
std::ostringstream buf;
buf << "Tensor \nshape:[" << shape() << "]" << this->Dtype()->ToString();
buf << "Tensor shape:[" << shape() << "]" << this->Dtype()->ToString();
// only print small tensor
if (DataSize() < small_tensor_size) {
buf << "val:" << std::string(py::str(data()));

View File

@ -234,7 +234,11 @@ FunctionBlockPtr Parser::ParseFunction(const py::object &node, const FunctionBlo
current_fg->debug_info()->set_deco_location(GetLocation(deco_list));
}
bool set_flag = ast_->UpdateFuncGraphFlags(current_fg);
bool set_flag = UpdateFuncGraphFlags(ast_->function(), current_fg);
if (ast_->obj() != ast_->function()) {
set_flag = set_flag && UpdateFuncGraphFlags(ast_->obj(), current_fg);
}
if (!set_flag) {
MS_LOG(ERROR) << "Set flags failed";
return nullptr;
@ -1436,17 +1440,17 @@ bool ParseAst::IsClassMember(const py::object &node) {
return ret.cast<bool>();
}
bool ParseAst::UpdateFuncGraphFlags(const FuncGraphPtr &func_graph) {
bool UpdateFuncGraphFlags(py::object obj, const FuncGraphPtr &func_graph) {
if (func_graph == nullptr) {
MS_LOG(ERROR) << "FuncGraph is null";
return false;
}
if (!py::hasattr(obj_, PYTHON_EXTERN_MINDSPORE_FLAG)) {
if (!py::hasattr(obj, PYTHON_EXTERN_MINDSPORE_FLAG)) {
MS_LOG(DEBUG) << "No flags";
return true;
}
py::dict flags = python_adapter::GetPyObjAttr(obj_, PYTHON_EXTERN_MINDSPORE_FLAG);
py::dict flags = python_adapter::GetPyObjAttr(obj, PYTHON_EXTERN_MINDSPORE_FLAG);
for (auto &item : flags) {
if (!py::isinstance<py::str>(item.first)) {
MS_LOG(ERROR) << "Type error in flags dict convert";
@ -1466,7 +1470,6 @@ bool ParseAst::UpdateFuncGraphFlags(const FuncGraphPtr &func_graph) {
return false;
}
}
return true;
}

View File

@ -327,9 +327,6 @@ class ParseAst {
bool IsClassMember(const py::object &node);
// update the graph flags
bool UpdateFuncGraphFlags(const FuncGraphPtr &func_graph);
private:
// save obj,eg: class instance or function
py::object obj_;
@ -350,6 +347,9 @@ class ParseAst {
int function_line_offset_;
};
// update the graph flags
bool UpdateFuncGraphFlags(py::object obj, const FuncGraphPtr &func_graph);
AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr &param);
} // namespace parse

View File

@ -284,7 +284,6 @@ class ClipByNorm(Cell):
self.reduce_sum = P.ReduceSum(keep_dims=True)
self.select_ = P.Select()
self.greater_ = P.Greater()
self.axis = ()
self.cast = P.Cast()
self.zero = Tensor(np.array([0.0]).astype(np.float32))
self.sqrt = P.Sqrt()
@ -299,7 +298,7 @@ class ClipByNorm(Cell):
def construct(self, x, clip_norm):
"""add ms_function decorator for pynative mode"""
mul_x = F.square(x)
l2sum = self.cast(self.reduce_sum(mul_x, self.axis), mstype.float32)
l2sum = self.cast(self.reduce_sum(mul_x), mstype.float32)
cond = self.greater_(l2sum, self.zero)
ones_ = self.fill(self.dtype(cond), self.shape(cond), 1.0)

View File

@ -234,8 +234,8 @@ class TrainOneStepWithLossScaleCell(Cell):
if scale_update_cell:
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
name="loss_scale")
self.add_flags(has_effect=True)
@C.add_flags(has_effect=True)
def construct(self, data, label, sens=None):
weights = self.weights
loss = self.network(data, label)

View File

@ -30,16 +30,16 @@ from ...common.parameter import Parameter
__all__ = [EnvInstance_, TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_]
def add_flags(fn, **flags):
def add_flags(fn=None, **flags):
"""
An interface to add flag for a function.
An decorator to add flag for a function.
Note:
Only supports bool value.
Args:
fn (Function): Function or cell to add flag.
flags (bool): Flags use kwargs.
fn (Function): Function or cell to add flag. Default: None.
flags (dict): Flags use kwargs. Default: None.
Returns:
Function, the fn added flags.
@ -47,11 +47,17 @@ def add_flags(fn, **flags):
Examples:
>>> add_flags(net, predit=True)
"""
def deco(fn):
# need set the attr and access on c++
if not hasattr(fn, "_mindspore_flags"):
fn._mindspore_flags = {}
fn._mindspore_flags.update({**flags})
return fn
ret = deco
if fn is not None:
ret = deco(fn)
return ret
def core(fn=None, **flags):

View File

@ -277,8 +277,8 @@ class TransformerTrainOneStepWithLossScaleCell(nn.Cell):
if scale_update_cell:
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
name="loss_scale")
self.add_flags(has_effect=True)
@C.add_flags(has_effect=True)
def construct(self,
source_eos_ids,
source_eos_mask,

View File

@ -132,9 +132,9 @@ class GetNextSentenceOutput(nn.Cell):
def __init__(self, config):
super(GetNextSentenceOutput, self).__init__()
self.log_softmax = _selected_ops.LogSoftmax()
self.weight_init = TruncatedNormal(config.initializer_range)
weight_init = TruncatedNormal(config.initializer_range)
self.dense = nn.Dense(config.hidden_size, 2,
weight_init=self.weight_init, has_bias=True).to_float(config.compute_type)
weight_init=weight_init, has_bias=True).to_float(config.compute_type)
self.dtype = config.dtype
self.cast = P.Cast()
@ -321,7 +321,6 @@ class BertTrainOneStepCell(nn.Cell):
if self.reducer_flag:
# apply grad reducer on grads
grads = self.grad_reducer(grads)
succ = self.optimizer(grads)
return F.depend(loss, succ)
@ -380,8 +379,8 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
if scale_update_cell:
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
name="loss_scale")
self.add_flags(has_effect=True)
@C.add_flags(has_effect=True)
def construct(self,
input_ids,
input_mask,

View File

@ -17,6 +17,7 @@
import numpy as np
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.ops.composite import add_flags
from .backbone.resnet_deeplab import _conv_bn_relu, resnet50_dl, _deep_conv_bn_relu, \
DepthwiseConv2dNative, SpaceToBatch, BatchToSpace
@ -121,6 +122,7 @@ class ASPP(nn.Cell):
self.feature_shape = feature_shape
self.concat = P.Concat(axis=1)
@add_flags(loop_can_unroll=True)
def construct(self, x, scale_index=0):
aspp0 = self.aspp0(x)
aspp1 = self.global_poolings[scale_index](x)
@ -276,7 +278,7 @@ class SingleDeepLabV3(nn.Cell):
atrous_rates=atrous_rates,
output_stride=output_stride,
fine_tune_batch_norm=fine_tune_batch_norm)
self.aspp.add_flags(loop_can_unroll=True)
atrous_rates_len = 0
if atrous_rates is not None:
atrous_rates_len = len(atrous_rates)

View File

@ -379,8 +379,8 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
if scale_update_cell:
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
name="loss_scale")
self.add_flags(has_effect=True)
@C.add_flags(has_effect=True)
def construct(self,
input_ids,
input_mask,

View File

@ -133,8 +133,8 @@ def test_keep_order_io_effect_exception_return_dtype():
self.dtype = P.DType()
self.sub = P.Sub()
self.neg = P.Neg()
self.add_flags(has_effect=True)
@C.add_flags(has_effect=True)
def construct(self, x):
init = self.alloc_status()
self.clear_status(init)

View File

@ -268,8 +268,8 @@ class NpuFloatNet(nn.Cell):
self.reduce_sum = P.ReduceSum(keep_dims=True)
self.sub = P.Sub()
self.neg = P.Neg()
self.add_flags(has_effect=True)
@C.add_flags(has_effect=True)
def construct(self, x):
init = self.alloc_status()
self.clear_status(init)

View File

@ -14,13 +14,13 @@
# ============================================================================
""" test_lenet_model """
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore.common.tensor import Tensor
from mindspore.nn import WithGradCell, WithLossCell
from mindspore.nn.optim import Momentum
from mindspore.ops import operations as P
from ....ut_filter import non_graph_engine
class LeNet5(nn.Cell):
@ -47,7 +47,7 @@ class LeNet5(nn.Cell):
return x
@non_graph_engine
@pytest.mark.skip(reason="need ge backend")
def test_lenet_pynative_train_net():
""" test_lenet_pynative_train_net """
data = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32) * 0.01)