forked from mindspore-Ecosystem/mindspore

add flags on function

parent 8870956954
commit 8f56528f8c
@@ -26,6 +26,7 @@ cmake-build-debug
 *_pb2.py
 *.pb.h
 *.pb.cc
+*.pb

 # Object files
 *.o
@@ -86,7 +86,7 @@ void PrimitivePy::AddPyAttr(const py::str &name, const py::object &obj) {
   }
   bool converted = parse::ConvertData(obj, &converted_ret);
   if (!converted) {
-    MS_LOG(EXCEPTION) << "Attribute convert error with type:" << std::string(py::str(obj));
+    MS_LOG(EXCEPTION) << "Attribute convert error with type: " << std::string(py::str(obj));
   }
   (void)this->AddAttr(attr_name, converted_ret);
 }
@@ -345,14 +345,14 @@ abstract::AbstractBasePtr Tensor::ToAbstract() {

 std::string Tensor::GetShapeAndDataTypeInfo() const {
   std::ostringstream buf;
-  buf << "Tensor \nshape:[" << shape() << "]" << this->Dtype()->ToString();
+  buf << "Tensor shape:[" << shape() << "]" << this->Dtype()->ToString();
   return buf.str();
 }

 std::string Tensor::ToString() const {
   const int small_tensor_size = 30;
   std::ostringstream buf;
-  buf << "Tensor \nshape:[" << shape() << "]" << this->Dtype()->ToString();
+  buf << "Tensor shape:[" << shape() << "]" << this->Dtype()->ToString();
   // only print small tensor
   if (DataSize() < small_tensor_size) {
     buf << "val:" << std::string(py::str(data()));
@@ -234,7 +234,11 @@ FunctionBlockPtr Parser::ParseFunction(const py::object &node, const FunctionBlo
     current_fg->debug_info()->set_deco_location(GetLocation(deco_list));
   }

-  bool set_flag = ast_->UpdateFuncGraphFlags(current_fg);
+  bool set_flag = UpdateFuncGraphFlags(ast_->function(), current_fg);
+  if (ast_->obj() != ast_->function()) {
+    set_flag = set_flag && UpdateFuncGraphFlags(ast_->obj(), current_fg);
+  }
+
   if (!set_flag) {
     MS_LOG(ERROR) << "Set flags failed";
     return nullptr;
@@ -1436,17 +1440,17 @@ bool ParseAst::IsClassMember(const py::object &node) {
   return ret.cast<bool>();
 }

-bool ParseAst::UpdateFuncGraphFlags(const FuncGraphPtr &func_graph) {
+bool UpdateFuncGraphFlags(py::object obj, const FuncGraphPtr &func_graph) {
   if (func_graph == nullptr) {
     MS_LOG(ERROR) << "FuncGraph is null";
     return false;
   }

-  if (!py::hasattr(obj_, PYTHON_EXTERN_MINDSPORE_FLAG)) {
+  if (!py::hasattr(obj, PYTHON_EXTERN_MINDSPORE_FLAG)) {
     MS_LOG(DEBUG) << "No flags";
     return true;
   }
-  py::dict flags = python_adapter::GetPyObjAttr(obj_, PYTHON_EXTERN_MINDSPORE_FLAG);
+  py::dict flags = python_adapter::GetPyObjAttr(obj, PYTHON_EXTERN_MINDSPORE_FLAG);
   for (auto &item : flags) {
     if (!py::isinstance<py::str>(item.first)) {
       MS_LOG(ERROR) << "Type error in flags dict convert";
@@ -1466,7 +1470,6 @@ bool ParseAst::UpdateFuncGraphFlags(const FuncGraphPtr &func_graph) {
       return false;
     }
   }

   return true;
 }
@@ -327,9 +327,6 @@ class ParseAst {

   bool IsClassMember(const py::object &node);

-  // update the graph flags
-  bool UpdateFuncGraphFlags(const FuncGraphPtr &func_graph);
-
  private:
   // save obj,eg: class instance or function
   py::object obj_;
@@ -350,6 +347,9 @@ class ParseAst {
   int function_line_offset_;
 };

+// update the graph flags
+bool UpdateFuncGraphFlags(py::object obj, const FuncGraphPtr &func_graph);
+
 AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr &param);

 } // namespace parse
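Note on the refactor above: the free UpdateFuncGraphFlags now reads PYTHON_EXTERN_MINDSPORE_FLAG from whichever Python object is passed in, instead of the ParseAst member obj_. A minimal Python-side sketch of what that lookup would amount to, assuming the C++ constant maps to the _mindspore_flags attribute written by add_flags; the helper name below is hypothetical:

def read_mindspore_flags(obj):
    """Hypothetical mirror of the parser's flag lookup: collect bool flags from obj."""
    flags = getattr(obj, "_mindspore_flags", {})  # assumed attribute name
    # Mirror the parser's checks: keys must be str and values bool.
    return {key: value for key, value in flags.items()
            if isinstance(key, str) and isinstance(value, bool)}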
@@ -284,7 +284,6 @@ class ClipByNorm(Cell):
         self.reduce_sum = P.ReduceSum(keep_dims=True)
         self.select_ = P.Select()
         self.greater_ = P.Greater()
-        self.axis = ()
         self.cast = P.Cast()
         self.zero = Tensor(np.array([0.0]).astype(np.float32))
         self.sqrt = P.Sqrt()
@@ -299,7 +298,7 @@ class ClipByNorm(Cell):
     def construct(self, x, clip_norm):
         """add ms_function decorator for pynative mode"""
         mul_x = F.square(x)
-        l2sum = self.cast(self.reduce_sum(mul_x, self.axis), mstype.float32)
+        l2sum = self.cast(self.reduce_sum(mul_x), mstype.float32)
         cond = self.greater_(l2sum, self.zero)
         ones_ = self.fill(self.dtype(cond), self.shape(cond), 1.0)

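For reference, dropping self.axis should not change behaviour here: the removed attribute was the empty tuple, which matches ReduceSum's default axis (reduce over all dimensions). A small hedged check, assuming P.ReduceSum's default axis is ():

import numpy as np
import mindspore as ms
from mindspore.ops import operations as P

x = ms.Tensor(np.arange(6, dtype=np.float32).reshape(2, 3))
reduce_sum = P.ReduceSum(keep_dims=True)
# Calling with an explicit empty axis and with no axis should agree.
assert (reduce_sum(x, ()).asnumpy() == reduce_sum(x).asnumpy()).all()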
@@ -234,8 +234,8 @@ class TrainOneStepWithLossScaleCell(Cell):
         if scale_update_cell:
             self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
                                         name="loss_scale")
-        self.add_flags(has_effect=True)

+    @C.add_flags(has_effect=True)
     def construct(self, data, label, sens=None):
         weights = self.weights
         loss = self.network(data, label)
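The same pattern repeats in the cells below: the has_effect flag moves from a runtime self.add_flags(has_effect=True) call in __init__ to a @C.add_flags(has_effect=True) decorator on construct, so the flag is attached to the function itself. A toy sketch of the new style (illustrative cell, not the real network):

import mindspore.nn as nn
from mindspore.ops import composite as C
from mindspore.ops import operations as P

class ToyCell(nn.Cell):
    """Illustrative only: shows the decorator form of add_flags on construct."""
    def __init__(self):
        super(ToyCell, self).__init__()
        self.neg = P.Neg()

    @C.add_flags(has_effect=True)  # flag stored in construct._mindspore_flags
    def construct(self, x):
        return self.neg(x)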
@@ -30,16 +30,16 @@ from ...common.parameter import Parameter
 __all__ = [EnvInstance_, TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_]


-def add_flags(fn, **flags):
+def add_flags(fn=None, **flags):
     """
-    An interface to add flag for a function.
+    A decorator to add flag for a function.

     Note:
         Only supports bool value.

     Args:
-        fn (Function): Function or cell to add flag.
-        flags (bool): Flags use kwargs.
+        fn (Function): Function or cell to add flag. Default: None.
+        flags (dict): Flags use kwargs. Default: None.

     Returns:
         Function, the fn added flags.
@@ -47,11 +47,17 @@ def add_flags(fn, **flags):
     Examples:
         >>> add_flags(net, predit=True)
     """
-    # need set the attr and access on c++
-    if not hasattr(fn, "_mindspore_flags"):
-        fn._mindspore_flags = {}
-    fn._mindspore_flags.update({**flags})
-    return fn
+    def deco(fn):
+        # need set the attr and access on c++
+        if not hasattr(fn, "_mindspore_flags"):
+            fn._mindspore_flags = {}
+
+        fn._mindspore_flags.update({**flags})
+        return fn
+    ret = deco
+    if fn is not None:
+        ret = deco(fn)
+    return ret


 def core(fn=None, **flags):
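With fn defaulting to None, add_flags can now be used either as a direct call or as a decorator factory. A short usage sketch (toy functions, flag names illustrative):

from mindspore.ops.composite import add_flags

def net(x):
    return x

net = add_flags(net, predit=True)  # direct call, as in the docstring example

@add_flags(loop_can_unroll=True)   # decorator form enabled by fn=None
def unrolled(x):
    return x

assert net._mindspore_flags == {"predit": True}
assert unrolled._mindspore_flags == {"loop_can_unroll": True}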
@@ -277,8 +277,8 @@ class TransformerTrainOneStepWithLossScaleCell(nn.Cell):
         if scale_update_cell:
             self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
                                         name="loss_scale")
-        self.add_flags(has_effect=True)

+    @C.add_flags(has_effect=True)
     def construct(self,
                   source_eos_ids,
                   source_eos_mask,
@@ -132,9 +132,9 @@ class GetNextSentenceOutput(nn.Cell):
     def __init__(self, config):
         super(GetNextSentenceOutput, self).__init__()
         self.log_softmax = _selected_ops.LogSoftmax()
-        self.weight_init = TruncatedNormal(config.initializer_range)
+        weight_init = TruncatedNormal(config.initializer_range)
         self.dense = nn.Dense(config.hidden_size, 2,
-                              weight_init=self.weight_init, has_bias=True).to_float(config.compute_type)
+                              weight_init=weight_init, has_bias=True).to_float(config.compute_type)
         self.dtype = config.dtype
         self.cast = P.Cast()

@@ -321,7 +321,6 @@ class BertTrainOneStepCell(nn.Cell):
         if self.reducer_flag:
             # apply grad reducer on grads
             grads = self.grad_reducer(grads)

         succ = self.optimizer(grads)
         return F.depend(loss, succ)
@@ -380,8 +379,8 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
         if scale_update_cell:
             self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
                                         name="loss_scale")
-        self.add_flags(has_effect=True)

+    @C.add_flags(has_effect=True)
     def construct(self,
                   input_ids,
                   input_mask,
@@ -17,6 +17,7 @@
 import numpy as np
 import mindspore.nn as nn
 from mindspore.ops import operations as P
+from mindspore.ops.composite import add_flags
 from .backbone.resnet_deeplab import _conv_bn_relu, resnet50_dl, _deep_conv_bn_relu, \
     DepthwiseConv2dNative, SpaceToBatch, BatchToSpace

@@ -121,6 +122,7 @@ class ASPP(nn.Cell):
         self.feature_shape = feature_shape
         self.concat = P.Concat(axis=1)

+    @add_flags(loop_can_unroll=True)
     def construct(self, x, scale_index=0):
         aspp0 = self.aspp0(x)
         aspp1 = self.global_poolings[scale_index](x)
@@ -276,7 +278,7 @@ class SingleDeepLabV3(nn.Cell):
                          atrous_rates=atrous_rates,
                          output_stride=output_stride,
                          fine_tune_batch_norm=fine_tune_batch_norm)
-        self.aspp.add_flags(loop_can_unroll=True)
         atrous_rates_len = 0
         if atrous_rates is not None:
             atrous_rates_len = len(atrous_rates)

@@ -379,8 +379,8 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
         if scale_update_cell:
             self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
                                         name="loss_scale")
-        self.add_flags(has_effect=True)

+    @C.add_flags(has_effect=True)
     def construct(self,
                   input_ids,
                   input_mask,
@@ -133,8 +133,8 @@ def test_keep_order_io_effect_exception_return_dtype():
             self.dtype = P.DType()
             self.sub = P.Sub()
             self.neg = P.Neg()
-            self.add_flags(has_effect=True)

+        @C.add_flags(has_effect=True)
         def construct(self, x):
             init = self.alloc_status()
             self.clear_status(init)
@@ -268,8 +268,8 @@ class NpuFloatNet(nn.Cell):
         self.reduce_sum = P.ReduceSum(keep_dims=True)
         self.sub = P.Sub()
         self.neg = P.Neg()
-        self.add_flags(has_effect=True)

+    @C.add_flags(has_effect=True)
     def construct(self, x):
         init = self.alloc_status()
         self.clear_status(init)
@@ -14,13 +14,13 @@
 # ============================================================================
 """ test_lenet_model """
 import numpy as np
+import pytest

 import mindspore.nn as nn
 from mindspore.common.tensor import Tensor
 from mindspore.nn import WithGradCell, WithLossCell
 from mindspore.nn.optim import Momentum
 from mindspore.ops import operations as P
-from ....ut_filter import non_graph_engine


 class LeNet5(nn.Cell):
@@ -47,7 +47,7 @@ class LeNet5(nn.Cell):
         return x


-@non_graph_engine
+@pytest.mark.skip(reason="need ge backend")
 def test_lenet_pynative_train_net():
     """ test_lenet_pynative_train_net """
     data = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32) * 0.01)