forked from mindspore-Ecosystem/mindspore
add flags on function
This commit is contained in:
parent
8870956954
commit
8f56528f8c
@@ -26,6 +26,7 @@ cmake-build-debug
 *_pb2.py
 *.pb.h
 *.pb.cc
+*.pb

 # Object files
 *.o
@@ -345,14 +345,14 @@ abstract::AbstractBasePtr Tensor::ToAbstract() {

 std::string Tensor::GetShapeAndDataTypeInfo() const {
   std::ostringstream buf;
-  buf << "Tensor \nshape:[" << shape() << "]" << this->Dtype()->ToString();
+  buf << "Tensor shape:[" << shape() << "]" << this->Dtype()->ToString();
   return buf.str();
 }

 std::string Tensor::ToString() const {
   const int small_tensor_size = 30;
   std::ostringstream buf;
-  buf << "Tensor \nshape:[" << shape() << "]" << this->Dtype()->ToString();
+  buf << "Tensor shape:[" << shape() << "]" << this->Dtype()->ToString();
   // only print small tensor
   if (DataSize() < small_tensor_size) {
     buf << "val:" << std::string(py::str(data()));
@@ -234,7 +234,11 @@ FunctionBlockPtr Parser::ParseFunction(const py::object &node, const FunctionBlo
     current_fg->debug_info()->set_deco_location(GetLocation(deco_list));
   }

-  bool set_flag = ast_->UpdateFuncGraphFlags(current_fg);
+  bool set_flag = UpdateFuncGraphFlags(ast_->function(), current_fg);
+  if (ast_->obj() != ast_->function()) {
+    set_flag = set_flag && UpdateFuncGraphFlags(ast_->obj(), current_fg);
+  }
+
   if (!set_flag) {
     MS_LOG(ERROR) << "Set flags failed";
     return nullptr;
@@ -1436,17 +1440,17 @@ bool ParseAst::IsClassMember(const py::object &node) {
   return ret.cast<bool>();
 }

-bool ParseAst::UpdateFuncGraphFlags(const FuncGraphPtr &func_graph) {
+bool UpdateFuncGraphFlags(py::object obj, const FuncGraphPtr &func_graph) {
   if (func_graph == nullptr) {
     MS_LOG(ERROR) << "FuncGraph is null";
     return false;
   }

-  if (!py::hasattr(obj_, PYTHON_EXTERN_MINDSPORE_FLAG)) {
+  if (!py::hasattr(obj, PYTHON_EXTERN_MINDSPORE_FLAG)) {
     MS_LOG(DEBUG) << "No flags";
     return true;
   }
-  py::dict flags = python_adapter::GetPyObjAttr(obj_, PYTHON_EXTERN_MINDSPORE_FLAG);
+  py::dict flags = python_adapter::GetPyObjAttr(obj, PYTHON_EXTERN_MINDSPORE_FLAG);
   for (auto &item : flags) {
     if (!py::isinstance<py::str>(item.first)) {
       MS_LOG(ERROR) << "Type error in flags dict convert";
@@ -1466,7 +1470,6 @@ bool ParseAst::UpdateFuncGraphFlags(const FuncGraphPtr &func_graph) {
       return false;
     }
   }
-
   return true;
 }

@@ -327,9 +327,6 @@ class ParseAst {

   bool IsClassMember(const py::object &node);

-  // update the graph flags
-  bool UpdateFuncGraphFlags(const FuncGraphPtr &func_graph);
-
  private:
   // save obj,eg: class instance or function
   py::object obj_;
@@ -350,6 +347,9 @@ class ParseAst {
   int function_line_offset_;
 };

+// update the graph flags
+bool UpdateFuncGraphFlags(py::object obj, const FuncGraphPtr &func_graph);
+
 AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr &param);

 }  // namespace parse
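Taken together, the parse hunks above turn UpdateFuncGraphFlags from a ParseAst member that always read the stored obj_ into a free function taking the Python object explicitly, and ParseFunction now merges flags from both the parsed function and its owning object when the two differ. A hedged Python-side sketch of the merge semantics (the helper name is illustrative; the _mindspore_flags attribute matches the add_flags hunk further down):

def merged_graph_flags(obj, function):
    """Sketch of the flag merge ParseFunction now performs in C++."""
    flags = {}
    # flags on the function itself are read first ...
    flags.update(getattr(function, "_mindspore_flags", {}))
    # ... then flags on the owning object, when it is a distinct object
    if obj is not function:
        flags.update(getattr(obj, "_mindspore_flags", {}))
    return flags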
@@ -284,7 +284,6 @@ class ClipByNorm(Cell):
         self.reduce_sum = P.ReduceSum(keep_dims=True)
         self.select_ = P.Select()
         self.greater_ = P.Greater()
-        self.axis = ()
         self.cast = P.Cast()
         self.zero = Tensor(np.array([0.0]).astype(np.float32))
         self.sqrt = P.Sqrt()
@@ -299,7 +298,7 @@ class ClipByNorm(Cell):
     def construct(self, x, clip_norm):
         """add ms_function decorator for pynative mode"""
         mul_x = F.square(x)
-        l2sum = self.cast(self.reduce_sum(mul_x, self.axis), mstype.float32)
+        l2sum = self.cast(self.reduce_sum(mul_x), mstype.float32)
        cond = self.greater_(l2sum, self.zero)
        ones_ = self.fill(self.dtype(cond), self.shape(cond), 1.0)
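The ClipByNorm change appears behavior-preserving: P.ReduceSum reduces over all dimensions when called without an axis (the axis argument defaults to the empty tuple), so passing the stored self.axis = () was redundant and both the attribute and the argument can go. A small sketch of the equivalence (shapes chosen for illustration):

import numpy as np
from mindspore import Tensor
from mindspore.ops import operations as P

reduce_sum = P.ReduceSum(keep_dims=True)
x = Tensor(np.ones((2, 3)).astype(np.float32))
out_explicit = reduce_sum(x, ())  # sum over all axes -> shape (1, 1)
out_default = reduce_sum(x)       # same reduction via the default axis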
@@ -234,8 +234,8 @@ class TrainOneStepWithLossScaleCell(Cell):
         if scale_update_cell:
             self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
                                         name="loss_scale")
-        self.add_flags(has_effect=True)

+    @C.add_flags(has_effect=True)
     def construct(self, data, label, sens=None):
         weights = self.weights
         loss = self.network(data, label)
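This hunk establishes the pattern the Transformer and BERT cells below repeat: has_effect moves from an instance-level call in __init__ (self.add_flags(has_effect=True)) onto construct itself via the @C.add_flags(has_effect=True) decorator, so the flag lives on the function the parser actually compiles. A minimal sketch of where the flag ends up (assuming mindspore.ops.composite is imported as C, as in these files):

from mindspore.ops import composite as C

class Demo:
    @C.add_flags(has_effect=True)
    def construct(self, data, label, sens=None):
        return data

# The decorator stores the flags in a dict on the function object,
# which UpdateFuncGraphFlags reads during parsing:
print(Demo.construct._mindspore_flags)  # {'has_effect': True}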
@@ -30,16 +30,16 @@ from ...common.parameter import Parameter
 __all__ = [EnvInstance_, TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_]


-def add_flags(fn, **flags):
+def add_flags(fn=None, **flags):
     """
-    An interface to add flag for a function.
+    A decorator to add flags for a function.

     Note:
         Only supports bool value.

     Args:
-        fn (Function): Function or cell to add flag.
-        flags (bool): Flags use kwargs.
+        fn (Function): Function or cell to add flag. Default: None.
+        flags (dict): Flags use kwargs. Default: None.

     Returns:
         Function, the fn added flags.
@@ -47,11 +47,17 @@ def add_flags(fn, **flags):
     Examples:
         >>> add_flags(net, predit=True)
     """
-    # need set the attr and access on c++
-    if not hasattr(fn, "_mindspore_flags"):
-        fn._mindspore_flags = {}
-
-    fn._mindspore_flags.update({**flags})
-    return fn
+    def deco(fn):
+        # need set the attr and access on c++
+        if not hasattr(fn, "_mindspore_flags"):
+            fn._mindspore_flags = {}
+
+        fn._mindspore_flags.update({**flags})
+        return fn
+    ret = deco
+    if fn is not None:
+        ret = deco(fn)
+    return ret


 def core(fn=None, **flags):
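With the new fn=None default and inner deco, add_flags keeps its old call style while also working as a decorator factory: called with only keyword flags it returns deco for later application, and called with a function it applies deco immediately. A quick sketch of both forms (net and the flag names are placeholders, as in the docstring example):

from mindspore.ops.composite import add_flags

def net(x):
    return x

# Old style still works: direct call mutates and returns the function.
add_flags(net, predit=True)

# New style: no fn given, so add_flags returns the decorator `deco`.
@add_flags(loop_can_unroll=True)
def construct(x):
    return x

print(net._mindspore_flags)        # {'predit': True}
print(construct._mindspore_flags)  # {'loop_can_unroll': True}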
@@ -277,8 +277,8 @@ class TransformerTrainOneStepWithLossScaleCell(nn.Cell):
         if scale_update_cell:
             self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
                                         name="loss_scale")
-        self.add_flags(has_effect=True)

+    @C.add_flags(has_effect=True)
     def construct(self,
                   source_eos_ids,
                   source_eos_mask,
@@ -132,9 +132,9 @@ class GetNextSentenceOutput(nn.Cell):
     def __init__(self, config):
         super(GetNextSentenceOutput, self).__init__()
         self.log_softmax = _selected_ops.LogSoftmax()
-        self.weight_init = TruncatedNormal(config.initializer_range)
+        weight_init = TruncatedNormal(config.initializer_range)
         self.dense = nn.Dense(config.hidden_size, 2,
-                              weight_init=self.weight_init, has_bias=True).to_float(config.compute_type)
+                              weight_init=weight_init, has_bias=True).to_float(config.compute_type)
         self.dtype = config.dtype
         self.cast = P.Cast()

@@ -321,7 +321,6 @@ class BertTrainOneStepCell(nn.Cell):
         if self.reducer_flag:
             # apply grad reducer on grads
             grads = self.grad_reducer(grads)

         succ = self.optimizer(grads)
         return F.depend(loss, succ)
-
@@ -380,8 +379,8 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
         if scale_update_cell:
             self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
                                         name="loss_scale")
-        self.add_flags(has_effect=True)

+    @C.add_flags(has_effect=True)
     def construct(self,
                   input_ids,
                   input_mask,
@@ -17,6 +17,7 @@
 import numpy as np
 import mindspore.nn as nn
 from mindspore.ops import operations as P
+from mindspore.ops.composite import add_flags
 from .backbone.resnet_deeplab import _conv_bn_relu, resnet50_dl, _deep_conv_bn_relu, \
     DepthwiseConv2dNative, SpaceToBatch, BatchToSpace

@@ -121,6 +122,7 @@ class ASPP(nn.Cell):
         self.feature_shape = feature_shape
         self.concat = P.Concat(axis=1)

+    @add_flags(loop_can_unroll=True)
     def construct(self, x, scale_index=0):
         aspp0 = self.aspp0(x)
         aspp1 = self.global_poolings[scale_index](x)
@@ -276,7 +278,7 @@ class SingleDeepLabV3(nn.Cell):
                          atrous_rates=atrous_rates,
                          output_stride=output_stride,
                          fine_tune_batch_norm=fine_tune_batch_norm)
-        self.aspp.add_flags(loop_can_unroll=True)
+
         atrous_rates_len = 0
         if atrous_rates is not None:
             atrous_rates_len = len(atrous_rates)
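The DeepLab hunks use the decorator form just added to composite: rather than the caller patching the instance after construction with self.aspp.add_flags(loop_can_unroll=True), the flag is declared once on ASPP.construct, so every ASPP instance carries it. A condensed before/after sketch (class reduced to the relevant lines):

import mindspore.nn as nn
from mindspore.ops.composite import add_flags

class ASPPLike(nn.Cell):
    # After: the flag is part of the definition.
    @add_flags(loop_can_unroll=True)
    def construct(self, x, scale_index=0):
        return x

aspp = ASPPLike()
# Before, each call site had to remember:
#     aspp.add_flags(loop_can_unroll=True)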
@@ -379,8 +379,8 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
         if scale_update_cell:
             self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
                                         name="loss_scale")
-        self.add_flags(has_effect=True)

+    @C.add_flags(has_effect=True)
     def construct(self,
                   input_ids,
                   input_mask,
@@ -133,8 +133,8 @@ def test_keep_order_io_effect_exception_return_dtype():
             self.dtype = P.DType()
             self.sub = P.Sub()
             self.neg = P.Neg()
-            self.add_flags(has_effect=True)

+        @C.add_flags(has_effect=True)
         def construct(self, x):
             init = self.alloc_status()
             self.clear_status(init)
@@ -268,8 +268,8 @@ class NpuFloatNet(nn.Cell):
         self.reduce_sum = P.ReduceSum(keep_dims=True)
         self.sub = P.Sub()
         self.neg = P.Neg()
-        self.add_flags(has_effect=True)

+    @C.add_flags(has_effect=True)
     def construct(self, x):
         init = self.alloc_status()
         self.clear_status(init)
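For the two test cells above, the flag matters because the float-status ops have no data dependency on the arithmetic around them: has_effect=True marks the graph as side-effecting so alloc_status/clear_status are not reordered or pruned by the optimizer. A hedged sketch of the pattern (operator names as used in MindSpore's overflow checks; illustrative, not the tests' exact code):

import mindspore.nn as nn
from mindspore.ops import composite as C
from mindspore.ops import operations as P

class OverflowCheck(nn.Cell):
    def __init__(self):
        super(OverflowCheck, self).__init__()
        self.alloc_status = P.NPUAllocFloatStatus()
        self.clear_status = P.NPUClearFloatStatus()
        self.neg = P.Neg()

    @C.add_flags(has_effect=True)  # keep the status ops ordered with the math
    def construct(self, x):
        init = self.alloc_status()
        self.clear_status(init)
        return self.neg(x)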
@@ -14,13 +14,13 @@
 # ============================================================================
 """ test_lenet_model """
 import numpy as np
+import pytest

 import mindspore.nn as nn
 from mindspore.common.tensor import Tensor
 from mindspore.nn import WithGradCell, WithLossCell
 from mindspore.nn.optim import Momentum
 from mindspore.ops import operations as P
 from ....ut_filter import non_graph_engine
-

 class LeNet5(nn.Cell):
@@ -47,7 +47,7 @@ class LeNet5(nn.Cell):
         return x

-
 @non_graph_engine
+@pytest.mark.skip(reason="need ge backend")
 def test_lenet_pynative_train_net():
     """ test_lenet_pynative_train_net """
     data = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32) * 0.01)