forked from mindspore-Ecosystem/mindspore
Allow tensor to be set const for network argument
This commit is contained in:
parent
47b9fd0a42
commit
06510b0649
|
@ -0,0 +1,15 @@
|
||||||
|
mindspore.Tensor.set_const_arg
|
||||||
|
==============================
|
||||||
|
|
||||||
|
.. py:method:: mindspore.Tensor.set_const_arg(const_arg=True)
|
||||||
|
|
||||||
|
指定该Tensor在作为网络入参时是否是一个常量。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- **const_arg** (bool) - Tensor在作为网络入参时是否是一个常量。默认值:True。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
Tensor,被指定了是否是一个常量网络入参。
|
||||||
|
|
||||||
|
异常:
|
||||||
|
- **TypeError** - 如果`const_arg`不是一个布尔值。
|
|
@ -11,6 +11,7 @@ mindspore.Tensor
|
||||||
- **shape** (Union[tuple, list, int]) - 用于定义该Tensor的形状。如果指定了 `input_data` ,则无需设置该参数。默认值:None。
|
- **shape** (Union[tuple, list, int]) - 用于定义该Tensor的形状。如果指定了 `input_data` ,则无需设置该参数。默认值:None。
|
||||||
- **init** (Initializer) - 用于在并行模式中延迟Tensor的数据的初始化,如果指定该参数,则 `dtype` 和 `shape` 也必须被指定。不推荐在非自动并行之外的场景下使用该接口。只有当调用 `Tensor.init_data` 时,才会使用指定的 `init` 来初始化Tensor数据。默认值:None。
|
- **init** (Initializer) - 用于在并行模式中延迟Tensor的数据的初始化,如果指定该参数,则 `dtype` 和 `shape` 也必须被指定。不推荐在非自动并行之外的场景下使用该接口。只有当调用 `Tensor.init_data` 时,才会使用指定的 `init` 来初始化Tensor数据。默认值:None。
|
||||||
- **internal** (bool) - Tensor是否由框架创建。如果为True,表示Tensor是由框架创建的,如果为False,表示Tensor是由用户创建的。默认值:False。
|
- **internal** (bool) - Tensor是否由框架创建。如果为True,表示Tensor是由框架创建的,如果为False,表示Tensor是由用户创建的。默认值:False。
|
||||||
|
- **const_arg** (bool) - 指定该Tensor作为网络输入时是否为常量。默认值:False。
|
||||||
|
|
||||||
输出:
|
输出:
|
||||||
Tensor。
|
Tensor。
|
||||||
|
@ -251,3 +252,4 @@ Parameter操作方法
|
||||||
:nosignatures:
|
:nosignatures:
|
||||||
|
|
||||||
mindspore.Tensor.flush_from_cache
|
mindspore.Tensor.flush_from_cache
|
||||||
|
mindspore.Tensor.set_const_arg
|
||||||
|
|
|
@ -16,7 +16,6 @@ mindspore.mutable
|
||||||
.. warning::
|
.. warning::
|
||||||
- 这是一个实验特性,未来有可能被修改或删除。
|
- 这是一个实验特性,未来有可能被修改或删除。
|
||||||
- 目前运行时暂时不支持处理标量数据流,所以我们目前只支持Tensor、tuple[Tensor]、list[Tensor]或dict[Tensor]作为输入,主要解决重复编译的问题。
|
- 目前运行时暂时不支持处理标量数据流,所以我们目前只支持Tensor、tuple[Tensor]、list[Tensor]或dict[Tensor]作为输入,主要解决重复编译的问题。
|
||||||
- Tensor默认就是可变的,当 `input_data` 为Tensor时,我们不做任何处理直接返回原Tensor。
|
|
||||||
- 当前暂时只支持在网络外部使用该接口。
|
- 当前暂时只支持在网络外部使用该接口。
|
||||||
- 当前该接口只在图模式下生效。
|
- 当前该接口只在图模式下生效。
|
||||||
|
|
||||||
|
|
|
@ -464,8 +464,7 @@ bool EnableGradForScalar(const AbstractBasePtr &abs) {
|
||||||
bool CanGradArgument(const AbstractTuplePtr &tuple_arg, size_t pos) {
|
bool CanGradArgument(const AbstractTuplePtr &tuple_arg, size_t pos) {
|
||||||
MS_EXCEPTION_IF_NULL(tuple_arg);
|
MS_EXCEPTION_IF_NULL(tuple_arg);
|
||||||
return tuple_arg->size() > pos && (*tuple_arg)[pos] != nullptr &&
|
return tuple_arg->size() > pos && (*tuple_arg)[pos] != nullptr &&
|
||||||
((*tuple_arg)[pos]->isa<AbstractUndetermined>() || (*tuple_arg)[pos]->BuildValue() == kAnyValue ||
|
((*tuple_arg)[pos]->BuildValue() == kAnyValue || EnableGradForScalar((*tuple_arg)[pos]));
|
||||||
EnableGradForScalar((*tuple_arg)[pos]));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void GenerateFuncGraphByPosition(const FuncGraphPtr &fg, const AbstractTuplePtr &tuple_arg,
|
void GenerateFuncGraphByPosition(const FuncGraphPtr &fg, const AbstractTuplePtr &tuple_arg,
|
||||||
|
|
|
@ -102,8 +102,7 @@ void UpdateFuncGraphParameter(const FuncGraphPtr &func_graph, const std::vector<
|
||||||
|
|
||||||
AbstractBasePtr param_abs = param_node->abstract();
|
AbstractBasePtr param_abs = param_node->abstract();
|
||||||
MS_EXCEPTION_IF_NULL(param_abs);
|
MS_EXCEPTION_IF_NULL(param_abs);
|
||||||
if (param_abs->isa<abstract::AbstractUndetermined>() || param_abs->BuildValue() == kAnyValue ||
|
if (param_abs->BuildValue() == kAnyValue || EnableGradForScalar(param_abs) || EnableTupleBroaden(param_abs)) {
|
||||||
EnableGradForScalar(param_abs) || EnableTupleBroaden(param_abs)) {
|
|
||||||
new_paras.push_back(param_node);
|
new_paras.push_back(param_node);
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(INFO) << "Remove the " << i << "th parameter, since it's passed a constant argument.";
|
MS_LOG(INFO) << "Remove the " << i << "th parameter, since it's passed a constant argument.";
|
||||||
|
|
|
@ -233,6 +233,11 @@ void ConvertLoadedGraph(const FuncGraphPtr &func_graph, const ValuePtr &value) {
|
||||||
BroadenCNodeAbstract(resolved_graph);
|
BroadenCNodeAbstract(resolved_graph);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool HasConstArgAttr(const py::object &obj) {
|
||||||
|
constexpr char const_arg_attr[] = "const_arg";
|
||||||
|
return py::hasattr(obj, const_arg_attr) && py::cast<bool>(py::getattr(obj, const_arg_attr));
|
||||||
|
}
|
||||||
|
|
||||||
AnfNodePtr ConvertObjectToNode(const AnfNodePtr &origin_node, const py::object &obj, const FuncGraphPtr &func_graph) {
|
AnfNodePtr ConvertObjectToNode(const AnfNodePtr &origin_node, const py::object &obj, const FuncGraphPtr &func_graph) {
|
||||||
// When the cell is set recomputed, it should not use old scope from cache.
|
// When the cell is set recomputed, it should not use old scope from cache.
|
||||||
MS_EXCEPTION_IF_NULL(origin_node);
|
MS_EXCEPTION_IF_NULL(origin_node);
|
||||||
|
@ -253,6 +258,10 @@ AnfNodePtr ConvertObjectToNode(const AnfNodePtr &origin_node, const py::object &
|
||||||
AnfNodePtr output = NewValueNode(convert_result);
|
AnfNodePtr output = NewValueNode(convert_result);
|
||||||
if (convert_result->isa<tensor::Tensor>()) {
|
if (convert_result->isa<tensor::Tensor>()) {
|
||||||
output = GetMixedPrecisionCastHelp(func_graph, output);
|
output = GetMixedPrecisionCastHelp(func_graph, output);
|
||||||
|
if (HasConstArgAttr(obj)) {
|
||||||
|
MS_LOG(WARNING) << "The tensor " << convert_result->ToString()
|
||||||
|
<< " which is not used for network input argument should not be set const.";
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return output;
|
return output;
|
||||||
}
|
}
|
||||||
|
|
|
@ -129,11 +129,30 @@ bool CheckAllTensor(const ValueTuplePtr &value_tuple) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
AbstractBasePtr ArgsToAbstract(const ValuePtr &value, bool enable_tuple_broaden = false, bool set_mutable = false) {
|
bool Mutable(const py::object &obj) {
|
||||||
MS_EXCEPTION_IF_NULL(value);
|
constexpr char mutable_attr[] = "__ms_mutable__";
|
||||||
bool broaden = value->isa<MetaTensor>() || set_mutable || value->isa<MetaSparseTensor>() ||
|
return py::hasattr(obj, mutable_attr) && py::cast<bool>(py::getattr(obj, mutable_attr));
|
||||||
(enable_tuple_broaden && value->isa<ValueTuple>() && CheckAllTensor(value->cast<ValueTuplePtr>())) ||
|
}
|
||||||
(MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && value->isa<Scalar>());
|
|
||||||
|
bool TensorArgMutable(const py::object &obj, const ValuePtr &value) {
|
||||||
|
if (!value->isa<MetaTensor>()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
constexpr char const_arg_attr[] = "const_arg";
|
||||||
|
return !py::hasattr(obj, const_arg_attr) || !py::cast<bool>(py::getattr(obj, const_arg_attr));
|
||||||
|
}
|
||||||
|
|
||||||
|
bool EnableTupleBroaden(const ValuePtr &value, bool enable_tuple_broaden) {
|
||||||
|
return enable_tuple_broaden && value->isa<ValueTuple>() && CheckAllTensor(value->cast<ValueTuplePtr>());
|
||||||
|
}
|
||||||
|
|
||||||
|
bool GradForScalar(const ValuePtr &value) {
|
||||||
|
return MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && value->isa<Scalar>();
|
||||||
|
}
|
||||||
|
|
||||||
|
AbstractBasePtr ArgsToAbstract(const py::object &arg, const ValuePtr &value, bool enable_tuple_broaden = false) {
|
||||||
|
bool broaden = TensorArgMutable(arg, value) || Mutable(arg) || value->isa<MetaSparseTensor>() ||
|
||||||
|
EnableTupleBroaden(value, enable_tuple_broaden) || GradForScalar(value);
|
||||||
|
|
||||||
return abstract::FromValue(value, broaden);
|
return abstract::FromValue(value, broaden);
|
||||||
}
|
}
|
||||||
|
@ -208,33 +227,6 @@ void RecordInitStatus() {
|
||||||
|
|
||||||
void RecordExitStatus() { MS_LOG(INFO) << "Status record: system exit."; }
|
void RecordExitStatus() { MS_LOG(INFO) << "Status record: system exit."; }
|
||||||
|
|
||||||
void SetValueMutable(const abstract::AbstractBasePtr &abs) {
|
|
||||||
MS_EXCEPTION_IF_NULL(abs);
|
|
||||||
if (abs->isa<abstract::AbstractTensor>()) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto abs_sequence = abs->cast_ptr<abstract::AbstractSequence>();
|
|
||||||
if (abs_sequence != nullptr) {
|
|
||||||
const auto &elements = abs_sequence->elements();
|
|
||||||
for (auto &ele : elements) {
|
|
||||||
SetValueMutable(ele);
|
|
||||||
}
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto abs_dict = abs->cast_ptr<abstract::AbstractDictionary>();
|
|
||||||
if (abs_dict != nullptr) {
|
|
||||||
const auto &elements = abs_dict->elements();
|
|
||||||
for (auto &ele : elements) {
|
|
||||||
SetValueMutable(ele.second);
|
|
||||||
}
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
abs->set_value_mutable(true);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string ToOrdinal(const size_t &i) {
|
std::string ToOrdinal(const size_t &i) {
|
||||||
auto suffix = "th";
|
auto suffix = "th";
|
||||||
if (i == kIndex1) {
|
if (i == kIndex1) {
|
||||||
|
@ -291,13 +283,7 @@ py::object GraphExecutorPy::GenerateArgumentsKey(const py::tuple &args, bool ena
|
||||||
MS_EXCEPTION(TypeError) << "parse::ConvertData for " << i << "th argument failed, the argument type is "
|
MS_EXCEPTION(TypeError) << "parse::ConvertData for " << i << "th argument failed, the argument type is "
|
||||||
<< args[i].get_type() << ", value is '" << py::str(args[i]) << "'.";
|
<< args[i].get_type() << ", value is '" << py::str(args[i]) << "'.";
|
||||||
}
|
}
|
||||||
constexpr char mutable_attr[] = "__ms_mutable__";
|
AbstractBasePtr abs = ArgsToAbstract(args[i], converted, enable_tuple_broaden);
|
||||||
bool set_mutable = false;
|
|
||||||
if (py::hasattr(args[i], mutable_attr) && py::cast<bool>(py::getattr(args[i], mutable_attr))) {
|
|
||||||
SetValueMutable(converted->ToAbstract());
|
|
||||||
set_mutable = true;
|
|
||||||
}
|
|
||||||
AbstractBasePtr abs = ArgsToAbstract(converted, enable_tuple_broaden, set_mutable);
|
|
||||||
(void)args_abs.emplace_back(abs);
|
(void)args_abs.emplace_back(abs);
|
||||||
// The 'converted' maybe a Parameter, we need connect it to the Parameter of func graph,
|
// The 'converted' maybe a Parameter, we need connect it to the Parameter of func graph,
|
||||||
// so we keep all inputs for subsequent procedure.
|
// so we keep all inputs for subsequent procedure.
|
||||||
|
@ -880,7 +866,7 @@ bool GraphExecutorPy::CompileInner(const py::object &source_obj, const py::tuple
|
||||||
MS_LOG(EXCEPTION) << "Fail to convert the " << i << "th argument, args[" << i << "]: " << py::str(args[i]);
|
MS_LOG(EXCEPTION) << "Fail to convert the " << i << "th argument, args[" << i << "]: " << py::str(args[i]);
|
||||||
}
|
}
|
||||||
(void)arguments.emplace_back(converted);
|
(void)arguments.emplace_back(converted);
|
||||||
auto args_abstract_item = ArgsToAbstract(converted, enable_tuple_broaden_);
|
auto args_abstract_item = ArgsToAbstract(args[i], converted, enable_tuple_broaden_);
|
||||||
if (is_auto_parallel) {
|
if (is_auto_parallel) {
|
||||||
(void)parallel::ExtendInputArgsAbstractShape(args_abstract_item, i);
|
(void)parallel::ExtendInputArgsAbstractShape(args_abstract_item, i);
|
||||||
}
|
}
|
||||||
|
|
|
@ -2134,6 +2134,10 @@ class PyInterpretEvaluator : public TransitionPrimEvaluator {
|
||||||
MS_LOG(EXCEPTION) << "Convert the python object failed";
|
MS_LOG(EXCEPTION) << "Convert the python object failed";
|
||||||
}
|
}
|
||||||
MS_EXCEPTION_IF_NULL(converted_val);
|
MS_EXCEPTION_IF_NULL(converted_val);
|
||||||
|
if (converted_val->isa<tensor::Tensor>() && HasConstArgAttr(obj)) {
|
||||||
|
MS_LOG(WARNING) << "The tensor " << converted_val->ToString()
|
||||||
|
<< " which is not used for network input argument should not be set const.";
|
||||||
|
}
|
||||||
AbstractBasePtr res = ToAbstract(converted_val, AnalysisContext::DummyContext(), out_conf);
|
AbstractBasePtr res = ToAbstract(converted_val, AnalysisContext::DummyContext(), out_conf);
|
||||||
auto infer_result = std::make_shared<EvalResult>(res, std::make_shared<AttrValueMap>());
|
auto infer_result = std::make_shared<EvalResult>(res, std::make_shared<AttrValueMap>());
|
||||||
evaluator_cache_mgr_->SetValue(args_spec_list, infer_result);
|
evaluator_cache_mgr_->SetValue(args_spec_list, infer_result);
|
||||||
|
@ -2248,6 +2252,11 @@ class PyInterpretEvaluator : public TransitionPrimEvaluator {
|
||||||
});
|
});
|
||||||
return std::make_shared<AbstractDictionary>(kv);
|
return std::make_shared<AbstractDictionary>(kv);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool HasConstArgAttr(const py::object &obj) {
|
||||||
|
constexpr char const_arg_attr[] = "const_arg";
|
||||||
|
return py::hasattr(obj, const_arg_attr) && py::cast<bool>(py::getattr(obj, const_arg_attr));
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
class PartialEvaluator : public Evaluator {
|
class PartialEvaluator : public Evaluator {
|
||||||
|
|
|
@ -221,6 +221,27 @@ def _restore_mutable_attr(args_list, compile_args):
|
||||||
return new_compile_args
|
return new_compile_args
|
||||||
|
|
||||||
|
|
||||||
|
def _get_args_for_run(obj, args_list):
|
||||||
|
"""Get the actual input args for runtime."""
|
||||||
|
inputs = []
|
||||||
|
for i in args_list:
|
||||||
|
if isinstance(i, PythonTensor):
|
||||||
|
if i.has_init:
|
||||||
|
i.init_data()
|
||||||
|
if not i.const_arg:
|
||||||
|
inputs.append(i)
|
||||||
|
elif isinstance(i, (Tensor, CSRTensor, COOTensor)):
|
||||||
|
inputs.append(i)
|
||||||
|
elif hasattr(i, "__ms_mutable__") and getattr(i, "__ms_mutable__"):
|
||||||
|
inputs.append(i)
|
||||||
|
elif context.get_context("grad_for_scalar") and isinstance(i, (int, float)):
|
||||||
|
inputs.append(i)
|
||||||
|
elif hasattr(obj, "enable_tuple_broaden") and obj.enable_tuple_broaden and isinstance(i, tuple) and \
|
||||||
|
_check_all_tensor(i):
|
||||||
|
inputs.append(i)
|
||||||
|
return inputs
|
||||||
|
|
||||||
|
|
||||||
class _MindsporeFunctionExecutor:
|
class _MindsporeFunctionExecutor:
|
||||||
"""
|
"""
|
||||||
Represents a function compiled by graph compiler.
|
Represents a function compiled by graph compiler.
|
||||||
|
@ -443,17 +464,7 @@ class _MindsporeFunctionExecutor:
|
||||||
Returns:
|
Returns:
|
||||||
new_inputs, new input args, which are required for running.
|
new_inputs, new input args, which are required for running.
|
||||||
"""
|
"""
|
||||||
new_inputs = []
|
return _get_args_for_run(self, args_list)
|
||||||
for i in args_list:
|
|
||||||
if isinstance(i, (Tensor, CSRTensor, COOTensor)):
|
|
||||||
new_inputs.append(i)
|
|
||||||
elif hasattr(i, "__ms_mutable__") and getattr(i, "__ms_mutable__"):
|
|
||||||
new_inputs.append(i)
|
|
||||||
elif context.get_context("grad_for_scalar") and isinstance(i, (int, float)):
|
|
||||||
new_inputs.append(i)
|
|
||||||
elif self.enable_tuple_broaden and isinstance(i, tuple) and _check_all_tensor(i):
|
|
||||||
new_inputs.append(i)
|
|
||||||
return new_inputs
|
|
||||||
|
|
||||||
|
|
||||||
# The attributes used to identify a given object.
|
# The attributes used to identify a given object.
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
|
|
||||||
from mindspore.common.tensor import Tensor
|
from mindspore.common.tensor import Tensor
|
||||||
|
from mindspore._c_expression import Tensor as Tensor_
|
||||||
|
|
||||||
|
|
||||||
class _Tuple(tuple):
|
class _Tuple(tuple):
|
||||||
|
@ -42,7 +43,7 @@ def _check_all_tensor(value):
|
||||||
if not _check_all_tensor(element):
|
if not _check_all_tensor(element):
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
return isinstance(value, Tensor)
|
return isinstance(value, Tensor_)
|
||||||
|
|
||||||
|
|
||||||
def mutable(input_data):
|
def mutable(input_data):
|
||||||
|
@ -67,8 +68,6 @@ def mutable(input_data):
|
||||||
- This is an experimental prototype that is subject to change or deletion.
|
- This is an experimental prototype that is subject to change or deletion.
|
||||||
- The runtime has not yet supported to handle the scalar data flow. So we only support tuple[Tensor],
|
- The runtime has not yet supported to handle the scalar data flow. So we only support tuple[Tensor],
|
||||||
list[Tensor] or dict[Tensor] for network input to avoid the re-compiled problem now.
|
list[Tensor] or dict[Tensor] for network input to avoid the re-compiled problem now.
|
||||||
- Tensor is mutable by default, when the `input_data` is Tensor, we just return the origin Tensor and nothing
|
|
||||||
is done.
|
|
||||||
- Currently we only support to use this api outside the network temporarily.
|
- Currently we only support to use this api outside the network temporarily.
|
||||||
- Currently this api only works in GRAPH mode.
|
- Currently this api only works in GRAPH mode.
|
||||||
|
|
||||||
|
@ -122,9 +121,6 @@ def mutable(input_data):
|
||||||
[ 1.50000000e+00, 1.50000000e+00, 1.50000000e+00]]))
|
[ 1.50000000e+00, 1.50000000e+00, 1.50000000e+00]]))
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if isinstance(input_data, Tensor):
|
|
||||||
return input_data
|
|
||||||
|
|
||||||
if not _check_all_tensor(input_data):
|
if not _check_all_tensor(input_data):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
f"For 'mutable', the 'input_data' should be one of (Tensor, tuple[Tensor], list[Tensor], dict[Tensor]) "
|
f"For 'mutable', the 'input_data' should be one of (Tensor, tuple[Tensor], list[Tensor], dict[Tensor]) "
|
||||||
|
@ -137,6 +133,11 @@ def mutable(input_data):
|
||||||
ret = _Tuple(input_data)
|
ret = _Tuple(input_data)
|
||||||
elif isinstance(input_data, dict):
|
elif isinstance(input_data, dict):
|
||||||
ret = _Dict(input_data)
|
ret = _Dict(input_data)
|
||||||
|
elif isinstance(input_data, Tensor):
|
||||||
|
ret.set_const_arg(False)
|
||||||
|
elif isinstance(input_data, Tensor_):
|
||||||
|
ret = Tensor(input_data, internal=True)
|
||||||
|
ret.set_const_arg(False)
|
||||||
|
|
||||||
setattr(ret, "__ms_mutable__", True)
|
setattr(ret, "__ms_mutable__", True)
|
||||||
return ret
|
return ret
|
||||||
|
|
|
@ -60,6 +60,8 @@ class Tensor(Tensor_):
|
||||||
'True' means that the tensor is created by framework.
|
'True' means that the tensor is created by framework.
|
||||||
'False' means that the tensor is created by user.
|
'False' means that the tensor is created by user.
|
||||||
Default: False
|
Default: False
|
||||||
|
const_arg (bool): Whether the tensor is a constant when it is used for the argument of a network.
|
||||||
|
Default: False.
|
||||||
|
|
||||||
Outputs:
|
Outputs:
|
||||||
Tensor.
|
Tensor.
|
||||||
|
@ -116,7 +118,7 @@ class Tensor(Tensor_):
|
||||||
"""
|
"""
|
||||||
delta_seed = 0
|
delta_seed = 0
|
||||||
|
|
||||||
def __init__(self, input_data=None, dtype=None, shape=None, init=None, internal=False):
|
def __init__(self, input_data=None, dtype=None, shape=None, init=None, internal=False, const_arg=False):
|
||||||
self.init_finished = False
|
self.init_finished = False
|
||||||
if internal:
|
if internal:
|
||||||
Tensor_.__init__(self, input_data)
|
Tensor_.__init__(self, input_data)
|
||||||
|
@ -166,6 +168,8 @@ class Tensor(Tensor_):
|
||||||
else:
|
else:
|
||||||
Tensor_.__init__(self, input_data)
|
Tensor_.__init__(self, input_data)
|
||||||
|
|
||||||
|
validator.check_value_type('const_arg', const_arg, bool, 'Tensor')
|
||||||
|
self.const_arg = const_arg
|
||||||
self.virtual_flag = False
|
self.virtual_flag = False
|
||||||
self.init = init
|
self.init = init
|
||||||
self.init_finished = True
|
self.init_finished = True
|
||||||
|
@ -191,6 +195,7 @@ class Tensor(Tensor_):
|
||||||
new_obj = Tensor(self)
|
new_obj = Tensor(self)
|
||||||
new_obj.init = self.init
|
new_obj.init = self.init
|
||||||
new_obj.virtual_flag = self.virtual_flag
|
new_obj.virtual_flag = self.virtual_flag
|
||||||
|
new_obj.const_arg = self.const_arg
|
||||||
return new_obj
|
return new_obj
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
|
@ -445,6 +450,33 @@ class Tensor(Tensor_):
|
||||||
|
|
||||||
return Tensor(Tensor_.from_numpy(array))
|
return Tensor(Tensor_.from_numpy(array))
|
||||||
|
|
||||||
|
def set_const_arg(self, const_arg=True):
|
||||||
|
"""
|
||||||
|
Specify whether the tensor is a constant when it is used for the argument of a network.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
const_arg (bool): Whether the tensor is a constant when it is used for the argument of a network.
|
||||||
|
Default: True.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor, has been specified whether to be a const network argument.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If `const_arg` is not a bool.
|
||||||
|
|
||||||
|
Supported Platforms:
|
||||||
|
``Ascend`` ``GPU`` ``CPU``
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> import numpy as np
|
||||||
|
>>> from mindspore import Tensor
|
||||||
|
>>> x = Tensor(np.array([[1,2,3],[4,5,6]], dtype=np.float32))
|
||||||
|
>>> x.set_const_arg(True)
|
||||||
|
"""
|
||||||
|
validator.check_value_type('const_arg', const_arg, bool, 'set_const_arg')
|
||||||
|
self.const_arg = const_arg
|
||||||
|
return self
|
||||||
|
|
||||||
def assign_value(self, value):
|
def assign_value(self, value):
|
||||||
"""
|
"""
|
||||||
Assign another tensor value to this tensor.
|
Assign another tensor value to this tensor.
|
||||||
|
|
|
@ -34,9 +34,9 @@ from mindspore import context
|
||||||
from mindspore._c_expression import init_pipeline, update_func_graph_hyper_params, Cell_, FuncGraph, MixedPrecisionType
|
from mindspore._c_expression import init_pipeline, update_func_graph_hyper_params, Cell_, FuncGraph, MixedPrecisionType
|
||||||
from mindspore._checkparam import Validator
|
from mindspore._checkparam import Validator
|
||||||
from mindspore.common import dtype as mstype
|
from mindspore.common import dtype as mstype
|
||||||
from mindspore.common.api import _cell_graph_executor, _pynative_executor, _check_all_tensor, cells_compile_cache
|
from mindspore.common.api import _cell_graph_executor, _pynative_executor, _get_args_for_run, cells_compile_cache
|
||||||
from mindspore.common.parameter import Parameter, ParameterTuple
|
from mindspore.common.parameter import Parameter, ParameterTuple
|
||||||
from mindspore.common.tensor import Tensor, CSRTensor, COOTensor
|
from mindspore.common.tensor import Tensor
|
||||||
from mindspore.ops.operations import Cast
|
from mindspore.ops.operations import Cast
|
||||||
from mindspore.ops.primitive import Primitive
|
from mindspore.ops.primitive import Primitive
|
||||||
from mindspore.ops.operations import _inner_ops as inner
|
from mindspore.ops.operations import _inner_ops as inner
|
||||||
|
@ -965,22 +965,7 @@ class Cell(Cell_):
|
||||||
self._auto_parallel_compile_and_run = True
|
self._auto_parallel_compile_and_run = True
|
||||||
self.compile(*inputs)
|
self.compile(*inputs)
|
||||||
|
|
||||||
new_inputs = []
|
new_inputs = _get_args_for_run(self, inputs)
|
||||||
for i in inputs:
|
|
||||||
if isinstance(i, Tensor):
|
|
||||||
if i.has_init:
|
|
||||||
i.init_data()
|
|
||||||
new_inputs.append(i)
|
|
||||||
elif isinstance(i, (CSRTensor, COOTensor)):
|
|
||||||
new_inputs.append(i)
|
|
||||||
elif hasattr(i, "__ms_mutable__") and getattr(i, "__ms_mutable__"):
|
|
||||||
new_inputs.append(i)
|
|
||||||
elif context.get_context("grad_for_scalar") and isinstance(i, (int, float)):
|
|
||||||
new_inputs.append(i)
|
|
||||||
elif hasattr(self, "enable_tuple_broaden") and self.enable_tuple_broaden and isinstance(i, tuple) and \
|
|
||||||
_check_all_tensor(i):
|
|
||||||
new_inputs.append(i)
|
|
||||||
|
|
||||||
return _cell_graph_executor(self, *new_inputs, phase=self.phase)
|
return _cell_graph_executor(self, *new_inputs, phase=self.phase)
|
||||||
|
|
||||||
def auto_parallel_compile_and_run(self):
|
def auto_parallel_compile_and_run(self):
|
||||||
|
|
|
@ -0,0 +1,162 @@
|
||||||
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""test mutable or constant tensor feature"""
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore.ops.composite import GradOperation
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.common import dtype as mstype
|
||||||
|
from mindspore.common import mutable
|
||||||
|
from mindspore import ms_function
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_cal_constant_tensor():
|
||||||
|
"""
|
||||||
|
Feature: Set mutable tensor input to constant.
|
||||||
|
Description: Get the matmul result for two constant tensor.
|
||||||
|
Expectation: Get the correct result.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.matmul = P.MatMul()
|
||||||
|
|
||||||
|
def construct(self, x, y):
|
||||||
|
out = self.matmul(x, y)
|
||||||
|
return out
|
||||||
|
|
||||||
|
x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32, const_arg=True)
|
||||||
|
y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32, const_arg=True)
|
||||||
|
net = Net()
|
||||||
|
output = net(x, y)
|
||||||
|
p = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
|
||||||
|
q = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
|
||||||
|
expect_output = net(p, q)
|
||||||
|
assert np.allclose(output.asnumpy(), expect_output.asnumpy())
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_cal_constant_tensor_ms_function():
|
||||||
|
"""
|
||||||
|
Feature: Set mutable tensor input to constant.
|
||||||
|
Description: Get the matmul result for two constant tensor in ms_function.
|
||||||
|
Expectation: Get the correct result.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@ms_function
|
||||||
|
def net(x, y):
|
||||||
|
out = P.MatMul()(x, y)
|
||||||
|
return out
|
||||||
|
|
||||||
|
x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32, const_arg=True)
|
||||||
|
y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32, const_arg=True)
|
||||||
|
output = net(x, y)
|
||||||
|
p = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
|
||||||
|
q = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
|
||||||
|
expect_output = net(p, q)
|
||||||
|
assert np.allclose(output.asnumpy(), expect_output.asnumpy())
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_grad_const_arg_tensor_to_mutable():
|
||||||
|
"""
|
||||||
|
Feature: Set mutable tensor input to constant.
|
||||||
|
Description: Get gradient with respect to constant tensor input.
|
||||||
|
Expectation: Get the correct gradients.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.matmul = P.MatMul()
|
||||||
|
|
||||||
|
def construct(self, x, y):
|
||||||
|
out = self.matmul(x, y)
|
||||||
|
return out
|
||||||
|
|
||||||
|
class GradNetWrtX(nn.Cell):
|
||||||
|
def __init__(self, net):
|
||||||
|
super(GradNetWrtX, self).__init__()
|
||||||
|
self.net = net
|
||||||
|
self.grad_op = GradOperation()
|
||||||
|
|
||||||
|
def construct(self, x, y):
|
||||||
|
gradient_function = self.grad_op(self.net)
|
||||||
|
return gradient_function(x, y)
|
||||||
|
|
||||||
|
x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32, const_arg=True)
|
||||||
|
y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32, const_arg=True)
|
||||||
|
grad_net = GradNetWrtX(Net())
|
||||||
|
# mutable api
|
||||||
|
output = grad_net(mutable(x), y)
|
||||||
|
expect_output = np.array([[1.4100001, 1.5999999, 6.6],
|
||||||
|
[1.4100001, 1.5999999, 6.6]]).astype(np.float32)
|
||||||
|
assert np.allclose(output.asnumpy(), expect_output)
|
||||||
|
# tensor set_const_arg api
|
||||||
|
x.set_const_arg(False)
|
||||||
|
output = grad_net(x, y)
|
||||||
|
expect_output = np.array([[1.4100001, 1.5999999, 6.6],
|
||||||
|
[1.4100001, 1.5999999, 6.6]]).astype(np.float32)
|
||||||
|
assert np.allclose(output.asnumpy(), expect_output)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_ms_function_grad_const_arg_tensor_to_mutable():
|
||||||
|
"""
|
||||||
|
Feature: Set mutable tensor input to constant.
|
||||||
|
Description: Get gradient with respect to constant tensor input for ms_function.
|
||||||
|
Expectation: Get the correct gradients.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.matmul = P.MatMul()
|
||||||
|
|
||||||
|
def construct(self, x, y):
|
||||||
|
out = self.matmul(x, y)
|
||||||
|
return out
|
||||||
|
|
||||||
|
@ms_function
|
||||||
|
def fn(x, y):
|
||||||
|
net = Net()
|
||||||
|
grad_op = GradOperation()
|
||||||
|
return grad_op(net)(x, y)
|
||||||
|
|
||||||
|
x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32, const_arg=True)
|
||||||
|
y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32, const_arg=True)
|
||||||
|
# mutable api
|
||||||
|
output = fn(mutable(x), y)
|
||||||
|
expect_output = np.array([[1.4100001, 1.5999999, 6.6],
|
||||||
|
[1.4100001, 1.5999999, 6.6]]).astype(np.float32)
|
||||||
|
assert np.allclose(output.asnumpy(), expect_output)
|
||||||
|
# tensor set_const_arg api
|
||||||
|
x.set_const_arg(False)
|
||||||
|
output = fn(x, y)
|
||||||
|
expect_output = np.array([[1.4100001, 1.5999999, 6.6],
|
||||||
|
[1.4100001, 1.5999999, 6.6]]).astype(np.float32)
|
||||||
|
assert np.allclose(output.asnumpy(), expect_output)
|
|
@ -0,0 +1,348 @@
|
||||||
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""test const tensor for network arg"""
|
||||||
|
import time
|
||||||
|
import numpy as np
|
||||||
|
from mindspore.ops.composite import GradOperation
|
||||||
|
from mindspore.common import mutable
|
||||||
|
from mindspore.common.api import _CellGraphExecutor, _MindsporeFunctionExecutor
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
import mindspore.nn as nn
|
||||||
|
import mindspore.common.dtype as mstype
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore import ms_function
|
||||||
|
|
||||||
|
|
||||||
|
def test_tensor_compile_phase1():
|
||||||
|
"""
|
||||||
|
Feature: Set mutable tensor input to constant.
|
||||||
|
Description: Test whether the compilation phase for tensor inputs twice are the same.
|
||||||
|
Expectation: The phases are the same only when the tensor inputs are set mutable.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.matmul = P.MatMul()
|
||||||
|
|
||||||
|
def construct(self, x, y):
|
||||||
|
out = self.matmul(x, y)
|
||||||
|
return out
|
||||||
|
|
||||||
|
# Init the tensors as const arguments.
|
||||||
|
x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32, const_arg=True)
|
||||||
|
y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32, const_arg=True)
|
||||||
|
p = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32, const_arg=True)
|
||||||
|
q = Tensor([[0.01, 3.0, 1.1], [1.0, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32, const_arg=True)
|
||||||
|
net = Net()
|
||||||
|
_cell_graph_executor = _CellGraphExecutor()
|
||||||
|
phase1, _ = _cell_graph_executor.compile(net, x, y)
|
||||||
|
phase2, _ = _cell_graph_executor.compile(net, p, q)
|
||||||
|
assert phase1 != phase2
|
||||||
|
# mutable api
|
||||||
|
phase1, _ = _cell_graph_executor.compile(net, mutable(x), mutable(y))
|
||||||
|
phase2, _ = _cell_graph_executor.compile(net, mutable(p), mutable(q))
|
||||||
|
assert phase1 == phase2
|
||||||
|
# set_mutable api of Tensor
|
||||||
|
x.set_const_arg(False)
|
||||||
|
y.set_const_arg(False)
|
||||||
|
p.set_const_arg(False)
|
||||||
|
q.set_const_arg(False)
|
||||||
|
phase1, _ = _cell_graph_executor.compile(net, x, y)
|
||||||
|
phase2, _ = _cell_graph_executor.compile(net, p, q)
|
||||||
|
assert phase1 == phase2
|
||||||
|
|
||||||
|
|
||||||
|
def test_ms_function_tensor_compile_phase1():
|
||||||
|
"""
|
||||||
|
Feature: Set mutable tensor input to constant.
|
||||||
|
Description: Test whether the compilation phase for tensor inputs twice are the same of ms_function.
|
||||||
|
Expectation: The phases are the same only when the tensor inputs are set mutable.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@ms_function
|
||||||
|
def fn(x, y):
|
||||||
|
out = P.MatMul()(x, y)
|
||||||
|
return out
|
||||||
|
|
||||||
|
x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32, const_arg=True)
|
||||||
|
y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32, const_arg=True)
|
||||||
|
p = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32, const_arg=True)
|
||||||
|
q = Tensor([[0.01, 3.0, 1.1], [1.0, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32, const_arg=True)
|
||||||
|
ms_create_time = int(time.time() * 1e9)
|
||||||
|
_ms_function_executor = _MindsporeFunctionExecutor(fn, ms_create_time)
|
||||||
|
# The ms_function makes the tensor inputs mutable by default
|
||||||
|
phase1 = _ms_function_executor.compile((x, y), "fn")
|
||||||
|
phase2 = _ms_function_executor.compile((p, q), "fn")
|
||||||
|
assert phase1 != phase2
|
||||||
|
# mutable api
|
||||||
|
phase1 = _ms_function_executor.compile((mutable(x), mutable(y)), "fn")
|
||||||
|
phase2 = _ms_function_executor.compile((mutable(p), mutable(q)), "fn")
|
||||||
|
assert phase1 == phase2
|
||||||
|
# set_mutable api of Tensor
|
||||||
|
x.set_const_arg(False)
|
||||||
|
y.set_const_arg(False)
|
||||||
|
p.set_const_arg(False)
|
||||||
|
q.set_const_arg(False)
|
||||||
|
phase1 = _ms_function_executor.compile((x, y), "fn")
|
||||||
|
phase2 = _ms_function_executor.compile((p, q), "fn")
|
||||||
|
assert phase1 == phase2
|
||||||
|
|
||||||
|
|
||||||
|
def test_tensor_compile_phase2():
|
||||||
|
"""
|
||||||
|
Feature: Set mutable tensor input to constant.
|
||||||
|
Description: Test whether the compilation phase for constant tensor inputs twice are the same.
|
||||||
|
Expectation: The phases are the same only when the tensor inputs are set mutable.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.matmul = P.MatMul()
|
||||||
|
|
||||||
|
def construct(self, x, y):
|
||||||
|
out = self.matmul(x, y)
|
||||||
|
return out
|
||||||
|
|
||||||
|
x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
|
||||||
|
y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
|
||||||
|
p = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
|
||||||
|
q = Tensor([[0.01, 3.0, 1.1], [1.0, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
|
||||||
|
net = Net()
|
||||||
|
_cell_graph_executor = _CellGraphExecutor()
|
||||||
|
phase1, _ = _cell_graph_executor.compile(net, x, y)
|
||||||
|
phase2, _ = _cell_graph_executor.compile(net, p, q)
|
||||||
|
assert phase1 == phase2
|
||||||
|
# Set const arg.
|
||||||
|
x.set_const_arg()
|
||||||
|
y.set_const_arg()
|
||||||
|
p.set_const_arg()
|
||||||
|
q.set_const_arg()
|
||||||
|
phase1, _ = _cell_graph_executor.compile(net, x, y)
|
||||||
|
phase2, _ = _cell_graph_executor.compile(net, p, q)
|
||||||
|
assert phase1 != phase2
|
||||||
|
# mutable api
|
||||||
|
phase1, _ = _cell_graph_executor.compile(net, mutable(x), mutable(y))
|
||||||
|
phase2, _ = _cell_graph_executor.compile(net, mutable(p), mutable(q))
|
||||||
|
assert phase1 == phase2
|
||||||
|
|
||||||
|
|
||||||
|
def test_ms_function_tensor_compile_phase2():
|
||||||
|
"""
|
||||||
|
Feature: Set mutable tensor input to constant.
|
||||||
|
Description: Test whether the compilation phase for constant tensor inputs twice are the same of ms_function.
|
||||||
|
Expectation: The phases are the same only when the tensor inputs are set mutable.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@ms_function
|
||||||
|
def fn(x, y):
|
||||||
|
out = P.MatMul()(x, y)
|
||||||
|
return out
|
||||||
|
|
||||||
|
x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
|
||||||
|
y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
|
||||||
|
p = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
|
||||||
|
q = Tensor([[0.01, 3.0, 1.1], [1.0, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
|
||||||
|
ms_create_time = int(time.time() * 1e9)
|
||||||
|
_ms_function_executor = _MindsporeFunctionExecutor(fn, ms_create_time)
|
||||||
|
phase1 = _ms_function_executor.compile((x, y), "fn")
|
||||||
|
phase2 = _ms_function_executor.compile((p, q), "fn")
|
||||||
|
assert phase1 == phase2
|
||||||
|
# Set const arg.
|
||||||
|
x.set_const_arg()
|
||||||
|
y.set_const_arg()
|
||||||
|
p.set_const_arg()
|
||||||
|
q.set_const_arg()
|
||||||
|
phase1 = _ms_function_executor.compile((x, y), "fn")
|
||||||
|
phase2 = _ms_function_executor.compile((p, q), "fn")
|
||||||
|
assert phase1 != phase2
|
||||||
|
# mutable api
|
||||||
|
phase1 = _ms_function_executor.compile((mutable(x), mutable(y)), "fn")
|
||||||
|
phase2 = _ms_function_executor.compile((mutable(p), mutable(q)), "fn")
|
||||||
|
assert phase1 == phase2
|
||||||
|
|
||||||
|
|
||||||
|
def test_grad_constant_tensor():
|
||||||
|
"""
|
||||||
|
Feature: Set mutable tensor input to constant.
|
||||||
|
Description: Get gradient with respect to the constant tensor input.
|
||||||
|
Expectation: Get an empty gradient.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.matmul = P.MatMul()
|
||||||
|
|
||||||
|
def construct(self, x, y):
|
||||||
|
out = self.matmul(x, y)
|
||||||
|
return out
|
||||||
|
|
||||||
|
class GradNetWrtX(nn.Cell):
|
||||||
|
def __init__(self, net):
|
||||||
|
super(GradNetWrtX, self).__init__()
|
||||||
|
self.net = net
|
||||||
|
self.grad_op = GradOperation()
|
||||||
|
|
||||||
|
def construct(self, x, y):
|
||||||
|
gradient_function = self.grad_op(self.net)
|
||||||
|
return gradient_function(x, y)
|
||||||
|
|
||||||
|
x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32, const_arg=True)
|
||||||
|
y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
|
||||||
|
grad_net = GradNetWrtX(Net())
|
||||||
|
output = grad_net(x, y)
|
||||||
|
assert isinstance(output, tuple)
|
||||||
|
assert output == ()
|
||||||
|
|
||||||
|
|
||||||
|
def test_ms_function_grad_constant_tensor():
|
||||||
|
"""
|
||||||
|
Feature: Set mutable tensor input to constant.
|
||||||
|
Description: Get gradient with respect to the constant tensor input of ms_function.
|
||||||
|
Expectation: Get an empty gradient.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.matmul = P.MatMul()
|
||||||
|
|
||||||
|
def construct(self, x, y):
|
||||||
|
out = self.matmul(x, y)
|
||||||
|
return out
|
||||||
|
|
||||||
|
@ms_function
|
||||||
|
def fn(x, y):
|
||||||
|
net = Net()
|
||||||
|
grad_op = GradOperation()
|
||||||
|
return grad_op(net)(x, y)
|
||||||
|
|
||||||
|
x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32, const_arg=True)
|
||||||
|
y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
|
||||||
|
output = fn(x, y)
|
||||||
|
assert isinstance(output, tuple)
|
||||||
|
assert output == ()
|
||||||
|
|
||||||
|
|
||||||
|
def test_tensor_constant_folding():
|
||||||
|
"""
|
||||||
|
Feature: Set mutable tensor input to constant.
|
||||||
|
Description: Get result of add operator for two constant tensor by constant folding in frontend.
|
||||||
|
Expectation: Get a correct result.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.add = P.Add()
|
||||||
|
|
||||||
|
def construct(self, x, y):
|
||||||
|
out = self.add(x, y)
|
||||||
|
return out
|
||||||
|
|
||||||
|
x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32, const_arg=True)
|
||||||
|
y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3]], dtype=mstype.float32, const_arg=True)
|
||||||
|
net = Net()
|
||||||
|
output = net(x, y)
|
||||||
|
expect_output = np.array([[0.51, 0.9, 1.5],
|
||||||
|
[1.3, 1.5, 2.4]]).astype(np.float32)
|
||||||
|
assert np.allclose(output.asnumpy(), expect_output)
|
||||||
|
|
||||||
|
|
||||||
|
def test_ms_function_tensor_constant_folding():
|
||||||
|
"""
|
||||||
|
Feature: Set mutable tensor input to constant.
|
||||||
|
Description: Get result of add operator of ms_function for two constant tensor by constant folding in frontend.
|
||||||
|
Expectation: Get a correct result.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@ms_function
|
||||||
|
def fn(x, y):
|
||||||
|
return P.Add()(x, y)
|
||||||
|
|
||||||
|
x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32, const_arg=True)
|
||||||
|
y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3]], dtype=mstype.float32, const_arg=True)
|
||||||
|
output = fn(x, y)
|
||||||
|
expect_output = np.array([[0.51, 0.9, 1.5],
|
||||||
|
[1.3, 1.5, 2.4]]).astype(np.float32)
|
||||||
|
assert np.allclose(output.asnumpy(), expect_output)
|
||||||
|
|
||||||
|
|
||||||
|
def test_constant_tensor_if():
|
||||||
|
"""
|
||||||
|
Feature: Set mutable tensor input to constant.
|
||||||
|
Description: Get result of control flow with if for constant tensor.
|
||||||
|
Expectation: Get the correct result.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.z = Tensor([3], dtype=mstype.int32)
|
||||||
|
|
||||||
|
def construct(self, x, y):
|
||||||
|
out = y
|
||||||
|
if x < self.z:
|
||||||
|
out = out + y
|
||||||
|
return out
|
||||||
|
|
||||||
|
x = Tensor([0], dtype=mstype.int32, const_arg=True)
|
||||||
|
y = Tensor([1], dtype=mstype.int32, const_arg=True)
|
||||||
|
net = Net()
|
||||||
|
output = net(x, y)
|
||||||
|
expect_output = np.array([2]).astype(np.int32)
|
||||||
|
assert np.allclose(output.asnumpy(), expect_output)
|
||||||
|
|
||||||
|
|
||||||
|
def test_ms_function_constant_tensor_if():
|
||||||
|
"""
|
||||||
|
Feature: Set mutable tensor input to constant.
|
||||||
|
Description: Get result of control flow with if of ms_function for constant tensor.
|
||||||
|
Expectation: Get the correct result.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@ms_function
|
||||||
|
def fn(x, y):
|
||||||
|
z = Tensor([3], dtype=mstype.int32)
|
||||||
|
out = y
|
||||||
|
if x < z:
|
||||||
|
out = out + y
|
||||||
|
return out
|
||||||
|
|
||||||
|
x = Tensor([0], dtype=mstype.int32, const_arg=True)
|
||||||
|
y = Tensor([1], dtype=mstype.int32, const_arg=True)
|
||||||
|
output = fn(x, y)
|
||||||
|
expect_output = np.array([2]).astype(np.int32)
|
||||||
|
assert np.allclose(output.asnumpy(), expect_output)
|
||||||
|
|
||||||
|
|
||||||
|
def test_check_mutable_value():
|
||||||
|
"""
|
||||||
|
Feature: Set mutable tensor input to constant.
|
||||||
|
Description: Check the illegal arg.
|
||||||
|
Expectation: Raise the correct error log.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
x = Tensor([0], dtype=mstype.int32, const_arg=1)
|
||||||
|
except TypeError as e:
|
||||||
|
assert str(e) == "For 'Tensor', the type of 'const_arg' should be 'bool', but got '1' with type 'int'."
|
||||||
|
|
||||||
|
try:
|
||||||
|
x = Tensor([0], dtype=mstype.int32)
|
||||||
|
x.set_const_arg(1)
|
||||||
|
except TypeError as e:
|
||||||
|
assert str(e) == "For 'set_const_arg', the type of 'const_arg' should be 'bool', but got '1' with type 'int'."
|
|
@ -13,7 +13,6 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
"""test mutable"""
|
"""test mutable"""
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
from mindspore.ops.composite import GradOperation
|
from mindspore.ops.composite import GradOperation
|
||||||
|
@ -23,6 +22,7 @@ from mindspore.ops import operations as P
|
||||||
import mindspore.nn as nn
|
import mindspore.nn as nn
|
||||||
import mindspore.common.dtype as mstype
|
import mindspore.common.dtype as mstype
|
||||||
from mindspore import Tensor
|
from mindspore import Tensor
|
||||||
|
from mindspore._c_expression import Tensor as Tensor_
|
||||||
from mindspore import Parameter
|
from mindspore import Parameter
|
||||||
|
|
||||||
|
|
||||||
|
@ -243,6 +243,49 @@ def test_dict_inputs_compile_phase():
|
||||||
assert phase1 == phase2
|
assert phase1 == phase2
|
||||||
|
|
||||||
|
|
||||||
|
def test_tensor_inputs_compile_phase():
|
||||||
|
"""
|
||||||
|
Feature: Set Constants mutable.
|
||||||
|
Description: Test whether the compilation phase for Tensor input twice are the same.
|
||||||
|
Expectation: The phases are the same.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.matmul = P.MatMul()
|
||||||
|
self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
|
||||||
|
|
||||||
|
def construct(self, x, y):
|
||||||
|
x = x * self.z
|
||||||
|
out = self.matmul(x, y)
|
||||||
|
return out
|
||||||
|
|
||||||
|
x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
|
||||||
|
y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
|
||||||
|
p = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
|
||||||
|
q = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
|
||||||
|
net = Net()
|
||||||
|
_cell_graph_executor = _CellGraphExecutor()
|
||||||
|
# tuple of Tensor
|
||||||
|
phase1, _ = _cell_graph_executor.compile(net, x, y)
|
||||||
|
phase2, _ = _cell_graph_executor.compile(net, p, q)
|
||||||
|
assert phase1 == phase2
|
||||||
|
phase1, _ = _cell_graph_executor.compile(net, mutable(x), mutable(y))
|
||||||
|
phase2, _ = _cell_graph_executor.compile(net, mutable(p), mutable(q))
|
||||||
|
assert phase1 == phase2
|
||||||
|
x = Tensor_(x)
|
||||||
|
y = Tensor_(y)
|
||||||
|
p = Tensor_(p)
|
||||||
|
q = Tensor_(q)
|
||||||
|
phase1, _ = _cell_graph_executor.compile(net, x, y)
|
||||||
|
phase2, _ = _cell_graph_executor.compile(net, p, q)
|
||||||
|
assert phase1 == phase2
|
||||||
|
phase1, _ = _cell_graph_executor.compile(net, mutable(x), mutable(y))
|
||||||
|
phase2, _ = _cell_graph_executor.compile(net, mutable(p), mutable(q))
|
||||||
|
assert phase1 == phase2
|
||||||
|
|
||||||
|
|
||||||
def test_check_mutable_value():
|
def test_check_mutable_value():
|
||||||
"""
|
"""
|
||||||
Feature: Set Constants mutable.
|
Feature: Set Constants mutable.
|
||||||
|
|
Loading…
Reference in New Issue