Allow tensor to be set const for network argument

This commit is contained in:
yujianfeng 2022-07-04 17:17:47 +08:00
parent 47b9fd0a42
commit 06510b0649
16 changed files with 682 additions and 82 deletions

View File

@ -0,0 +1,15 @@
mindspore.Tensor.set_const_arg
==============================
.. py:method:: mindspore.Tensor.set_const_arg(const_arg=True)
指定该Tensor在作为网络入参时是否是一个常量。
参数:
- **const_arg** (bool) - Tensor在作为网络入参时是否是一个常量。默认值True。
返回:
Tensor被指定了是否是一个常量网络入参。
异常:
- **TypeError** - 如果`const_arg`不是一个布尔值。

View File

@ -11,6 +11,7 @@ mindspore.Tensor
- **shape** (Union[tuple, list, int]) - 用于定义该Tensor的形状。如果指定了 `input_data` 则无需设置该参数。默认值None。 - **shape** (Union[tuple, list, int]) - 用于定义该Tensor的形状。如果指定了 `input_data` 则无需设置该参数。默认值None。
- **init** (Initializer) - 用于在并行模式中延迟Tensor的数据的初始化如果指定该参数`dtype``shape` 也必须被指定。不推荐在非自动并行之外的场景下使用该接口。只有当调用 `Tensor.init_data` 时,才会使用指定的 `init` 来初始化Tensor数据。默认值None。 - **init** (Initializer) - 用于在并行模式中延迟Tensor的数据的初始化如果指定该参数`dtype``shape` 也必须被指定。不推荐在非自动并行之外的场景下使用该接口。只有当调用 `Tensor.init_data` 时,才会使用指定的 `init` 来初始化Tensor数据。默认值None。
- **internal** (bool) - Tensor是否由框架创建。如果为True表示Tensor是由框架创建的如果为False表示Tensor是由用户创建的。默认值False。 - **internal** (bool) - Tensor是否由框架创建。如果为True表示Tensor是由框架创建的如果为False表示Tensor是由用户创建的。默认值False。
- **const_arg** (bool) - 指定该Tensor作为网络输入时是否为常量。默认值False。
输出: 输出:
Tensor。 Tensor。
@ -251,3 +252,4 @@ Parameter操作方法
:nosignatures: :nosignatures:
mindspore.Tensor.flush_from_cache mindspore.Tensor.flush_from_cache
mindspore.Tensor.set_const_arg

View File

@ -16,7 +16,6 @@ mindspore.mutable
.. warning:: .. warning::
- 这是一个实验特性,未来有可能被修改或删除。 - 这是一个实验特性,未来有可能被修改或删除。
- 目前运行时暂时不支持处理标量数据流所以我们目前只支持Tensor、tuple[Tensor]、list[Tensor]或dict[Tensor]作为输入,主要解决重复编译的问题。 - 目前运行时暂时不支持处理标量数据流所以我们目前只支持Tensor、tuple[Tensor]、list[Tensor]或dict[Tensor]作为输入,主要解决重复编译的问题。
- Tensor默认就是可变的`input_data` 为Tensor时我们不做任何处理直接返回原Tensor。
- 当前暂时只支持在网络外部使用该接口。 - 当前暂时只支持在网络外部使用该接口。
- 当前该接口只在图模式下生效。 - 当前该接口只在图模式下生效。

View File

@ -464,8 +464,7 @@ bool EnableGradForScalar(const AbstractBasePtr &abs) {
bool CanGradArgument(const AbstractTuplePtr &tuple_arg, size_t pos) { bool CanGradArgument(const AbstractTuplePtr &tuple_arg, size_t pos) {
MS_EXCEPTION_IF_NULL(tuple_arg); MS_EXCEPTION_IF_NULL(tuple_arg);
return tuple_arg->size() > pos && (*tuple_arg)[pos] != nullptr && return tuple_arg->size() > pos && (*tuple_arg)[pos] != nullptr &&
((*tuple_arg)[pos]->isa<AbstractUndetermined>() || (*tuple_arg)[pos]->BuildValue() == kAnyValue || ((*tuple_arg)[pos]->BuildValue() == kAnyValue || EnableGradForScalar((*tuple_arg)[pos]));
EnableGradForScalar((*tuple_arg)[pos]));
} }
void GenerateFuncGraphByPosition(const FuncGraphPtr &fg, const AbstractTuplePtr &tuple_arg, void GenerateFuncGraphByPosition(const FuncGraphPtr &fg, const AbstractTuplePtr &tuple_arg,

View File

@ -102,8 +102,7 @@ void UpdateFuncGraphParameter(const FuncGraphPtr &func_graph, const std::vector<
AbstractBasePtr param_abs = param_node->abstract(); AbstractBasePtr param_abs = param_node->abstract();
MS_EXCEPTION_IF_NULL(param_abs); MS_EXCEPTION_IF_NULL(param_abs);
if (param_abs->isa<abstract::AbstractUndetermined>() || param_abs->BuildValue() == kAnyValue || if (param_abs->BuildValue() == kAnyValue || EnableGradForScalar(param_abs) || EnableTupleBroaden(param_abs)) {
EnableGradForScalar(param_abs) || EnableTupleBroaden(param_abs)) {
new_paras.push_back(param_node); new_paras.push_back(param_node);
} else { } else {
MS_LOG(INFO) << "Remove the " << i << "th parameter, since it's passed a constant argument."; MS_LOG(INFO) << "Remove the " << i << "th parameter, since it's passed a constant argument.";

View File

@ -233,6 +233,11 @@ void ConvertLoadedGraph(const FuncGraphPtr &func_graph, const ValuePtr &value) {
BroadenCNodeAbstract(resolved_graph); BroadenCNodeAbstract(resolved_graph);
} }
bool HasConstArgAttr(const py::object &obj) {
constexpr char const_arg_attr[] = "const_arg";
return py::hasattr(obj, const_arg_attr) && py::cast<bool>(py::getattr(obj, const_arg_attr));
}
AnfNodePtr ConvertObjectToNode(const AnfNodePtr &origin_node, const py::object &obj, const FuncGraphPtr &func_graph) { AnfNodePtr ConvertObjectToNode(const AnfNodePtr &origin_node, const py::object &obj, const FuncGraphPtr &func_graph) {
// When the cell is set recomputed, it should not use old scope from cache. // When the cell is set recomputed, it should not use old scope from cache.
MS_EXCEPTION_IF_NULL(origin_node); MS_EXCEPTION_IF_NULL(origin_node);
@ -253,6 +258,10 @@ AnfNodePtr ConvertObjectToNode(const AnfNodePtr &origin_node, const py::object &
AnfNodePtr output = NewValueNode(convert_result); AnfNodePtr output = NewValueNode(convert_result);
if (convert_result->isa<tensor::Tensor>()) { if (convert_result->isa<tensor::Tensor>()) {
output = GetMixedPrecisionCastHelp(func_graph, output); output = GetMixedPrecisionCastHelp(func_graph, output);
if (HasConstArgAttr(obj)) {
MS_LOG(WARNING) << "The tensor " << convert_result->ToString()
<< " which is not used for network input argument should not be set const.";
}
} }
return output; return output;
} }

View File

@ -129,11 +129,30 @@ bool CheckAllTensor(const ValueTuplePtr &value_tuple) {
return true; return true;
} }
AbstractBasePtr ArgsToAbstract(const ValuePtr &value, bool enable_tuple_broaden = false, bool set_mutable = false) { bool Mutable(const py::object &obj) {
MS_EXCEPTION_IF_NULL(value); constexpr char mutable_attr[] = "__ms_mutable__";
bool broaden = value->isa<MetaTensor>() || set_mutable || value->isa<MetaSparseTensor>() || return py::hasattr(obj, mutable_attr) && py::cast<bool>(py::getattr(obj, mutable_attr));
(enable_tuple_broaden && value->isa<ValueTuple>() && CheckAllTensor(value->cast<ValueTuplePtr>())) || }
(MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && value->isa<Scalar>());
bool TensorArgMutable(const py::object &obj, const ValuePtr &value) {
if (!value->isa<MetaTensor>()) {
return false;
}
constexpr char const_arg_attr[] = "const_arg";
return !py::hasattr(obj, const_arg_attr) || !py::cast<bool>(py::getattr(obj, const_arg_attr));
}
bool EnableTupleBroaden(const ValuePtr &value, bool enable_tuple_broaden) {
return enable_tuple_broaden && value->isa<ValueTuple>() && CheckAllTensor(value->cast<ValueTuplePtr>());
}
bool GradForScalar(const ValuePtr &value) {
return MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && value->isa<Scalar>();
}
AbstractBasePtr ArgsToAbstract(const py::object &arg, const ValuePtr &value, bool enable_tuple_broaden = false) {
bool broaden = TensorArgMutable(arg, value) || Mutable(arg) || value->isa<MetaSparseTensor>() ||
EnableTupleBroaden(value, enable_tuple_broaden) || GradForScalar(value);
return abstract::FromValue(value, broaden); return abstract::FromValue(value, broaden);
} }
@ -208,33 +227,6 @@ void RecordInitStatus() {
void RecordExitStatus() { MS_LOG(INFO) << "Status record: system exit."; } void RecordExitStatus() { MS_LOG(INFO) << "Status record: system exit."; }
void SetValueMutable(const abstract::AbstractBasePtr &abs) {
MS_EXCEPTION_IF_NULL(abs);
if (abs->isa<abstract::AbstractTensor>()) {
return;
}
auto abs_sequence = abs->cast_ptr<abstract::AbstractSequence>();
if (abs_sequence != nullptr) {
const auto &elements = abs_sequence->elements();
for (auto &ele : elements) {
SetValueMutable(ele);
}
return;
}
auto abs_dict = abs->cast_ptr<abstract::AbstractDictionary>();
if (abs_dict != nullptr) {
const auto &elements = abs_dict->elements();
for (auto &ele : elements) {
SetValueMutable(ele.second);
}
return;
}
abs->set_value_mutable(true);
}
std::string ToOrdinal(const size_t &i) { std::string ToOrdinal(const size_t &i) {
auto suffix = "th"; auto suffix = "th";
if (i == kIndex1) { if (i == kIndex1) {
@ -291,13 +283,7 @@ py::object GraphExecutorPy::GenerateArgumentsKey(const py::tuple &args, bool ena
MS_EXCEPTION(TypeError) << "parse::ConvertData for " << i << "th argument failed, the argument type is " MS_EXCEPTION(TypeError) << "parse::ConvertData for " << i << "th argument failed, the argument type is "
<< args[i].get_type() << ", value is '" << py::str(args[i]) << "'."; << args[i].get_type() << ", value is '" << py::str(args[i]) << "'.";
} }
constexpr char mutable_attr[] = "__ms_mutable__"; AbstractBasePtr abs = ArgsToAbstract(args[i], converted, enable_tuple_broaden);
bool set_mutable = false;
if (py::hasattr(args[i], mutable_attr) && py::cast<bool>(py::getattr(args[i], mutable_attr))) {
SetValueMutable(converted->ToAbstract());
set_mutable = true;
}
AbstractBasePtr abs = ArgsToAbstract(converted, enable_tuple_broaden, set_mutable);
(void)args_abs.emplace_back(abs); (void)args_abs.emplace_back(abs);
// The 'converted' maybe a Parameter, we need connect it to the Parameter of func graph, // The 'converted' maybe a Parameter, we need connect it to the Parameter of func graph,
// so we keep all inputs for subsequent procedure. // so we keep all inputs for subsequent procedure.
@ -880,7 +866,7 @@ bool GraphExecutorPy::CompileInner(const py::object &source_obj, const py::tuple
MS_LOG(EXCEPTION) << "Fail to convert the " << i << "th argument, args[" << i << "]: " << py::str(args[i]); MS_LOG(EXCEPTION) << "Fail to convert the " << i << "th argument, args[" << i << "]: " << py::str(args[i]);
} }
(void)arguments.emplace_back(converted); (void)arguments.emplace_back(converted);
auto args_abstract_item = ArgsToAbstract(converted, enable_tuple_broaden_); auto args_abstract_item = ArgsToAbstract(args[i], converted, enable_tuple_broaden_);
if (is_auto_parallel) { if (is_auto_parallel) {
(void)parallel::ExtendInputArgsAbstractShape(args_abstract_item, i); (void)parallel::ExtendInputArgsAbstractShape(args_abstract_item, i);
} }

View File

@ -2134,6 +2134,10 @@ class PyInterpretEvaluator : public TransitionPrimEvaluator {
MS_LOG(EXCEPTION) << "Convert the python object failed"; MS_LOG(EXCEPTION) << "Convert the python object failed";
} }
MS_EXCEPTION_IF_NULL(converted_val); MS_EXCEPTION_IF_NULL(converted_val);
if (converted_val->isa<tensor::Tensor>() && HasConstArgAttr(obj)) {
MS_LOG(WARNING) << "The tensor " << converted_val->ToString()
<< " which is not used for network input argument should not be set const.";
}
AbstractBasePtr res = ToAbstract(converted_val, AnalysisContext::DummyContext(), out_conf); AbstractBasePtr res = ToAbstract(converted_val, AnalysisContext::DummyContext(), out_conf);
auto infer_result = std::make_shared<EvalResult>(res, std::make_shared<AttrValueMap>()); auto infer_result = std::make_shared<EvalResult>(res, std::make_shared<AttrValueMap>());
evaluator_cache_mgr_->SetValue(args_spec_list, infer_result); evaluator_cache_mgr_->SetValue(args_spec_list, infer_result);
@ -2248,6 +2252,11 @@ class PyInterpretEvaluator : public TransitionPrimEvaluator {
}); });
return std::make_shared<AbstractDictionary>(kv); return std::make_shared<AbstractDictionary>(kv);
} }
bool HasConstArgAttr(const py::object &obj) {
constexpr char const_arg_attr[] = "const_arg";
return py::hasattr(obj, const_arg_attr) && py::cast<bool>(py::getattr(obj, const_arg_attr));
}
}; };
class PartialEvaluator : public Evaluator { class PartialEvaluator : public Evaluator {

View File

@ -221,6 +221,27 @@ def _restore_mutable_attr(args_list, compile_args):
return new_compile_args return new_compile_args
def _get_args_for_run(obj, args_list):
"""Get the actual input args for runtime."""
inputs = []
for i in args_list:
if isinstance(i, PythonTensor):
if i.has_init:
i.init_data()
if not i.const_arg:
inputs.append(i)
elif isinstance(i, (Tensor, CSRTensor, COOTensor)):
inputs.append(i)
elif hasattr(i, "__ms_mutable__") and getattr(i, "__ms_mutable__"):
inputs.append(i)
elif context.get_context("grad_for_scalar") and isinstance(i, (int, float)):
inputs.append(i)
elif hasattr(obj, "enable_tuple_broaden") and obj.enable_tuple_broaden and isinstance(i, tuple) and \
_check_all_tensor(i):
inputs.append(i)
return inputs
class _MindsporeFunctionExecutor: class _MindsporeFunctionExecutor:
""" """
Represents a function compiled by graph compiler. Represents a function compiled by graph compiler.
@ -443,17 +464,7 @@ class _MindsporeFunctionExecutor:
Returns: Returns:
new_inputs, new input args, which are required for running. new_inputs, new input args, which are required for running.
""" """
new_inputs = [] return _get_args_for_run(self, args_list)
for i in args_list:
if isinstance(i, (Tensor, CSRTensor, COOTensor)):
new_inputs.append(i)
elif hasattr(i, "__ms_mutable__") and getattr(i, "__ms_mutable__"):
new_inputs.append(i)
elif context.get_context("grad_for_scalar") and isinstance(i, (int, float)):
new_inputs.append(i)
elif self.enable_tuple_broaden and isinstance(i, tuple) and _check_all_tensor(i):
new_inputs.append(i)
return new_inputs
# The attributes used to identify a given object. # The attributes used to identify a given object.

View File

@ -16,6 +16,7 @@
from __future__ import absolute_import from __future__ import absolute_import
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
from mindspore._c_expression import Tensor as Tensor_
class _Tuple(tuple): class _Tuple(tuple):
@ -42,7 +43,7 @@ def _check_all_tensor(value):
if not _check_all_tensor(element): if not _check_all_tensor(element):
return False return False
return True return True
return isinstance(value, Tensor) return isinstance(value, Tensor_)
def mutable(input_data): def mutable(input_data):
@ -67,8 +68,6 @@ def mutable(input_data):
- This is an experimental prototype that is subject to change or deletion. - This is an experimental prototype that is subject to change or deletion.
- The runtime has not yet supported to handle the scalar data flow. So we only support tuple[Tensor], - The runtime has not yet supported to handle the scalar data flow. So we only support tuple[Tensor],
list[Tensor] or dict[Tensor] for network input to avoid the re-compiled problem now. list[Tensor] or dict[Tensor] for network input to avoid the re-compiled problem now.
- Tensor is mutable by default, when the `input_data` is Tensor, we just return the origin Tensor and nothing
is done.
- Currently we only support to use this api outside the network temporarily. - Currently we only support to use this api outside the network temporarily.
- Currently this api only works in GRAPH mode. - Currently this api only works in GRAPH mode.
@ -122,9 +121,6 @@ def mutable(input_data):
[ 1.50000000e+00, 1.50000000e+00, 1.50000000e+00]])) [ 1.50000000e+00, 1.50000000e+00, 1.50000000e+00]]))
""" """
if isinstance(input_data, Tensor):
return input_data
if not _check_all_tensor(input_data): if not _check_all_tensor(input_data):
raise TypeError( raise TypeError(
f"For 'mutable', the 'input_data' should be one of (Tensor, tuple[Tensor], list[Tensor], dict[Tensor]) " f"For 'mutable', the 'input_data' should be one of (Tensor, tuple[Tensor], list[Tensor], dict[Tensor]) "
@ -137,6 +133,11 @@ def mutable(input_data):
ret = _Tuple(input_data) ret = _Tuple(input_data)
elif isinstance(input_data, dict): elif isinstance(input_data, dict):
ret = _Dict(input_data) ret = _Dict(input_data)
elif isinstance(input_data, Tensor):
ret.set_const_arg(False)
elif isinstance(input_data, Tensor_):
ret = Tensor(input_data, internal=True)
ret.set_const_arg(False)
setattr(ret, "__ms_mutable__", True) setattr(ret, "__ms_mutable__", True)
return ret return ret

View File

@ -60,6 +60,8 @@ class Tensor(Tensor_):
'True' means that the tensor is created by framework. 'True' means that the tensor is created by framework.
'False' means that the tensor is created by user. 'False' means that the tensor is created by user.
Default: False Default: False
const_arg (bool): Whether the tensor is a constant when it is used for the argument of a network.
Default: False.
Outputs: Outputs:
Tensor. Tensor.
@ -116,7 +118,7 @@ class Tensor(Tensor_):
""" """
delta_seed = 0 delta_seed = 0
def __init__(self, input_data=None, dtype=None, shape=None, init=None, internal=False): def __init__(self, input_data=None, dtype=None, shape=None, init=None, internal=False, const_arg=False):
self.init_finished = False self.init_finished = False
if internal: if internal:
Tensor_.__init__(self, input_data) Tensor_.__init__(self, input_data)
@ -166,6 +168,8 @@ class Tensor(Tensor_):
else: else:
Tensor_.__init__(self, input_data) Tensor_.__init__(self, input_data)
validator.check_value_type('const_arg', const_arg, bool, 'Tensor')
self.const_arg = const_arg
self.virtual_flag = False self.virtual_flag = False
self.init = init self.init = init
self.init_finished = True self.init_finished = True
@ -191,6 +195,7 @@ class Tensor(Tensor_):
new_obj = Tensor(self) new_obj = Tensor(self)
new_obj.init = self.init new_obj.init = self.init
new_obj.virtual_flag = self.virtual_flag new_obj.virtual_flag = self.virtual_flag
new_obj.const_arg = self.const_arg
return new_obj return new_obj
def __repr__(self): def __repr__(self):
@ -445,6 +450,33 @@ class Tensor(Tensor_):
return Tensor(Tensor_.from_numpy(array)) return Tensor(Tensor_.from_numpy(array))
def set_const_arg(self, const_arg=True):
"""
Specify whether the tensor is a constant when it is used for the argument of a network.
Args:
const_arg (bool): Whether the tensor is a constant when it is used for the argument of a network.
Default: True.
Returns:
Tensor, has been specified whether to be a const network argument.
Raises:
TypeError: If `const_arg` is not a bool.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import numpy as np
>>> from mindspore import Tensor
>>> x = Tensor(np.array([[1,2,3],[4,5,6]], dtype=np.float32))
>>> x.set_const_arg(True)
"""
validator.check_value_type('const_arg', const_arg, bool, 'set_const_arg')
self.const_arg = const_arg
return self
def assign_value(self, value): def assign_value(self, value):
""" """
Assign another tensor value to this tensor. Assign another tensor value to this tensor.

View File

@ -34,9 +34,9 @@ from mindspore import context
from mindspore._c_expression import init_pipeline, update_func_graph_hyper_params, Cell_, FuncGraph, MixedPrecisionType from mindspore._c_expression import init_pipeline, update_func_graph_hyper_params, Cell_, FuncGraph, MixedPrecisionType
from mindspore._checkparam import Validator from mindspore._checkparam import Validator
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore.common.api import _cell_graph_executor, _pynative_executor, _check_all_tensor, cells_compile_cache from mindspore.common.api import _cell_graph_executor, _pynative_executor, _get_args_for_run, cells_compile_cache
from mindspore.common.parameter import Parameter, ParameterTuple from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common.tensor import Tensor, CSRTensor, COOTensor from mindspore.common.tensor import Tensor
from mindspore.ops.operations import Cast from mindspore.ops.operations import Cast
from mindspore.ops.primitive import Primitive from mindspore.ops.primitive import Primitive
from mindspore.ops.operations import _inner_ops as inner from mindspore.ops.operations import _inner_ops as inner
@ -965,22 +965,7 @@ class Cell(Cell_):
self._auto_parallel_compile_and_run = True self._auto_parallel_compile_and_run = True
self.compile(*inputs) self.compile(*inputs)
new_inputs = [] new_inputs = _get_args_for_run(self, inputs)
for i in inputs:
if isinstance(i, Tensor):
if i.has_init:
i.init_data()
new_inputs.append(i)
elif isinstance(i, (CSRTensor, COOTensor)):
new_inputs.append(i)
elif hasattr(i, "__ms_mutable__") and getattr(i, "__ms_mutable__"):
new_inputs.append(i)
elif context.get_context("grad_for_scalar") and isinstance(i, (int, float)):
new_inputs.append(i)
elif hasattr(self, "enable_tuple_broaden") and self.enable_tuple_broaden and isinstance(i, tuple) and \
_check_all_tensor(i):
new_inputs.append(i)
return _cell_graph_executor(self, *new_inputs, phase=self.phase) return _cell_graph_executor(self, *new_inputs, phase=self.phase)
def auto_parallel_compile_and_run(self): def auto_parallel_compile_and_run(self):

View File

@ -0,0 +1,162 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test mutable or constant tensor feature"""
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops.composite import GradOperation
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
from mindspore.common import mutable
from mindspore import ms_function
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cal_constant_tensor():
"""
Feature: Set mutable tensor input to constant.
Description: Get the matmul result for two constant tensor.
Expectation: Get the correct result.
"""
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.matmul = P.MatMul()
def construct(self, x, y):
out = self.matmul(x, y)
return out
x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32, const_arg=True)
y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32, const_arg=True)
net = Net()
output = net(x, y)
p = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
q = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
expect_output = net(p, q)
assert np.allclose(output.asnumpy(), expect_output.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cal_constant_tensor_ms_function():
"""
Feature: Set mutable tensor input to constant.
Description: Get the matmul result for two constant tensor in ms_function.
Expectation: Get the correct result.
"""
@ms_function
def net(x, y):
out = P.MatMul()(x, y)
return out
x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32, const_arg=True)
y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32, const_arg=True)
output = net(x, y)
p = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
q = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
expect_output = net(p, q)
assert np.allclose(output.asnumpy(), expect_output.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_const_arg_tensor_to_mutable():
"""
Feature: Set mutable tensor input to constant.
Description: Get gradient with respect to constant tensor input.
Expectation: Get the correct gradients.
"""
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.matmul = P.MatMul()
def construct(self, x, y):
out = self.matmul(x, y)
return out
class GradNetWrtX(nn.Cell):
def __init__(self, net):
super(GradNetWrtX, self).__init__()
self.net = net
self.grad_op = GradOperation()
def construct(self, x, y):
gradient_function = self.grad_op(self.net)
return gradient_function(x, y)
x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32, const_arg=True)
y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32, const_arg=True)
grad_net = GradNetWrtX(Net())
# mutable api
output = grad_net(mutable(x), y)
expect_output = np.array([[1.4100001, 1.5999999, 6.6],
[1.4100001, 1.5999999, 6.6]]).astype(np.float32)
assert np.allclose(output.asnumpy(), expect_output)
# tensor set_const_arg api
x.set_const_arg(False)
output = grad_net(x, y)
expect_output = np.array([[1.4100001, 1.5999999, 6.6],
[1.4100001, 1.5999999, 6.6]]).astype(np.float32)
assert np.allclose(output.asnumpy(), expect_output)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_ms_function_grad_const_arg_tensor_to_mutable():
"""
Feature: Set mutable tensor input to constant.
Description: Get gradient with respect to constant tensor input for ms_function.
Expectation: Get the correct gradients.
"""
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.matmul = P.MatMul()
def construct(self, x, y):
out = self.matmul(x, y)
return out
@ms_function
def fn(x, y):
net = Net()
grad_op = GradOperation()
return grad_op(net)(x, y)
x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32, const_arg=True)
y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32, const_arg=True)
# mutable api
output = fn(mutable(x), y)
expect_output = np.array([[1.4100001, 1.5999999, 6.6],
[1.4100001, 1.5999999, 6.6]]).astype(np.float32)
assert np.allclose(output.asnumpy(), expect_output)
# tensor set_const_arg api
x.set_const_arg(False)
output = fn(x, y)
expect_output = np.array([[1.4100001, 1.5999999, 6.6],
[1.4100001, 1.5999999, 6.6]]).astype(np.float32)
assert np.allclose(output.asnumpy(), expect_output)

View File

@ -0,0 +1,348 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test const tensor for network arg"""
import time
import numpy as np
from mindspore.ops.composite import GradOperation
from mindspore.common import mutable
from mindspore.common.api import _CellGraphExecutor, _MindsporeFunctionExecutor
from mindspore.ops import operations as P
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor
from mindspore import ms_function
def test_tensor_compile_phase1():
"""
Feature: Set mutable tensor input to constant.
Description: Test whether the compilation phase for tensor inputs twice are the same.
Expectation: The phases are the same only when the tensor inputs are set mutable.
"""
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.matmul = P.MatMul()
def construct(self, x, y):
out = self.matmul(x, y)
return out
# Init the tensors as const arguments.
x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32, const_arg=True)
y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32, const_arg=True)
p = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32, const_arg=True)
q = Tensor([[0.01, 3.0, 1.1], [1.0, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32, const_arg=True)
net = Net()
_cell_graph_executor = _CellGraphExecutor()
phase1, _ = _cell_graph_executor.compile(net, x, y)
phase2, _ = _cell_graph_executor.compile(net, p, q)
assert phase1 != phase2
# mutable api
phase1, _ = _cell_graph_executor.compile(net, mutable(x), mutable(y))
phase2, _ = _cell_graph_executor.compile(net, mutable(p), mutable(q))
assert phase1 == phase2
# set_mutable api of Tensor
x.set_const_arg(False)
y.set_const_arg(False)
p.set_const_arg(False)
q.set_const_arg(False)
phase1, _ = _cell_graph_executor.compile(net, x, y)
phase2, _ = _cell_graph_executor.compile(net, p, q)
assert phase1 == phase2
def test_ms_function_tensor_compile_phase1():
"""
Feature: Set mutable tensor input to constant.
Description: Test whether the compilation phase for tensor inputs twice are the same of ms_function.
Expectation: The phases are the same only when the tensor inputs are set mutable.
"""
@ms_function
def fn(x, y):
out = P.MatMul()(x, y)
return out
x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32, const_arg=True)
y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32, const_arg=True)
p = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32, const_arg=True)
q = Tensor([[0.01, 3.0, 1.1], [1.0, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32, const_arg=True)
ms_create_time = int(time.time() * 1e9)
_ms_function_executor = _MindsporeFunctionExecutor(fn, ms_create_time)
# The ms_function makes the tensor inputs mutable by default
phase1 = _ms_function_executor.compile((x, y), "fn")
phase2 = _ms_function_executor.compile((p, q), "fn")
assert phase1 != phase2
# mutable api
phase1 = _ms_function_executor.compile((mutable(x), mutable(y)), "fn")
phase2 = _ms_function_executor.compile((mutable(p), mutable(q)), "fn")
assert phase1 == phase2
# set_mutable api of Tensor
x.set_const_arg(False)
y.set_const_arg(False)
p.set_const_arg(False)
q.set_const_arg(False)
phase1 = _ms_function_executor.compile((x, y), "fn")
phase2 = _ms_function_executor.compile((p, q), "fn")
assert phase1 == phase2
def test_tensor_compile_phase2():
"""
Feature: Set mutable tensor input to constant.
Description: Test whether the compilation phase for constant tensor inputs twice are the same.
Expectation: The phases are the same only when the tensor inputs are set mutable.
"""
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.matmul = P.MatMul()
def construct(self, x, y):
out = self.matmul(x, y)
return out
x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
p = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
q = Tensor([[0.01, 3.0, 1.1], [1.0, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
net = Net()
_cell_graph_executor = _CellGraphExecutor()
phase1, _ = _cell_graph_executor.compile(net, x, y)
phase2, _ = _cell_graph_executor.compile(net, p, q)
assert phase1 == phase2
# Set const arg.
x.set_const_arg()
y.set_const_arg()
p.set_const_arg()
q.set_const_arg()
phase1, _ = _cell_graph_executor.compile(net, x, y)
phase2, _ = _cell_graph_executor.compile(net, p, q)
assert phase1 != phase2
# mutable api
phase1, _ = _cell_graph_executor.compile(net, mutable(x), mutable(y))
phase2, _ = _cell_graph_executor.compile(net, mutable(p), mutable(q))
assert phase1 == phase2
def test_ms_function_tensor_compile_phase2():
"""
Feature: Set mutable tensor input to constant.
Description: Test whether the compilation phase for constant tensor inputs twice are the same of ms_function.
Expectation: The phases are the same only when the tensor inputs are set mutable.
"""
@ms_function
def fn(x, y):
out = P.MatMul()(x, y)
return out
x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
p = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
q = Tensor([[0.01, 3.0, 1.1], [1.0, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
ms_create_time = int(time.time() * 1e9)
_ms_function_executor = _MindsporeFunctionExecutor(fn, ms_create_time)
phase1 = _ms_function_executor.compile((x, y), "fn")
phase2 = _ms_function_executor.compile((p, q), "fn")
assert phase1 == phase2
# Set const arg.
x.set_const_arg()
y.set_const_arg()
p.set_const_arg()
q.set_const_arg()
phase1 = _ms_function_executor.compile((x, y), "fn")
phase2 = _ms_function_executor.compile((p, q), "fn")
assert phase1 != phase2
# mutable api
phase1 = _ms_function_executor.compile((mutable(x), mutable(y)), "fn")
phase2 = _ms_function_executor.compile((mutable(p), mutable(q)), "fn")
assert phase1 == phase2
def test_grad_constant_tensor():
"""
Feature: Set mutable tensor input to constant.
Description: Get gradient with respect to the constant tensor input.
Expectation: Get an empty gradient.
"""
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.matmul = P.MatMul()
def construct(self, x, y):
out = self.matmul(x, y)
return out
class GradNetWrtX(nn.Cell):
def __init__(self, net):
super(GradNetWrtX, self).__init__()
self.net = net
self.grad_op = GradOperation()
def construct(self, x, y):
gradient_function = self.grad_op(self.net)
return gradient_function(x, y)
x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32, const_arg=True)
y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
grad_net = GradNetWrtX(Net())
output = grad_net(x, y)
assert isinstance(output, tuple)
assert output == ()
def test_ms_function_grad_constant_tensor():
"""
Feature: Set mutable tensor input to constant.
Description: Get gradient with respect to the constant tensor input of ms_function.
Expectation: Get an empty gradient.
"""
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.matmul = P.MatMul()
def construct(self, x, y):
out = self.matmul(x, y)
return out
@ms_function
def fn(x, y):
net = Net()
grad_op = GradOperation()
return grad_op(net)(x, y)
x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32, const_arg=True)
y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
output = fn(x, y)
assert isinstance(output, tuple)
assert output == ()
def test_tensor_constant_folding():
"""
Feature: Set mutable tensor input to constant.
Description: Get result of add operator for two constant tensor by constant folding in frontend.
Expectation: Get a correct result.
"""
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.add = P.Add()
def construct(self, x, y):
out = self.add(x, y)
return out
x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32, const_arg=True)
y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3]], dtype=mstype.float32, const_arg=True)
net = Net()
output = net(x, y)
expect_output = np.array([[0.51, 0.9, 1.5],
[1.3, 1.5, 2.4]]).astype(np.float32)
assert np.allclose(output.asnumpy(), expect_output)
def test_ms_function_tensor_constant_folding():
"""
Feature: Set mutable tensor input to constant.
Description: Get result of add operator of ms_function for two constant tensor by constant folding in frontend.
Expectation: Get a correct result.
"""
@ms_function
def fn(x, y):
return P.Add()(x, y)
x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32, const_arg=True)
y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3]], dtype=mstype.float32, const_arg=True)
output = fn(x, y)
expect_output = np.array([[0.51, 0.9, 1.5],
[1.3, 1.5, 2.4]]).astype(np.float32)
assert np.allclose(output.asnumpy(), expect_output)
def test_constant_tensor_if():
"""
Feature: Set mutable tensor input to constant.
Description: Get result of control flow with if for constant tensor.
Expectation: Get the correct result.
"""
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.z = Tensor([3], dtype=mstype.int32)
def construct(self, x, y):
out = y
if x < self.z:
out = out + y
return out
x = Tensor([0], dtype=mstype.int32, const_arg=True)
y = Tensor([1], dtype=mstype.int32, const_arg=True)
net = Net()
output = net(x, y)
expect_output = np.array([2]).astype(np.int32)
assert np.allclose(output.asnumpy(), expect_output)
def test_ms_function_constant_tensor_if():
"""
Feature: Set mutable tensor input to constant.
Description: Get result of control flow with if of ms_function for constant tensor.
Expectation: Get the correct result.
"""
@ms_function
def fn(x, y):
z = Tensor([3], dtype=mstype.int32)
out = y
if x < z:
out = out + y
return out
x = Tensor([0], dtype=mstype.int32, const_arg=True)
y = Tensor([1], dtype=mstype.int32, const_arg=True)
output = fn(x, y)
expect_output = np.array([2]).astype(np.int32)
assert np.allclose(output.asnumpy(), expect_output)
def test_check_mutable_value():
"""
Feature: Set mutable tensor input to constant.
Description: Check the illegal arg.
Expectation: Raise the correct error log.
"""
try:
x = Tensor([0], dtype=mstype.int32, const_arg=1)
except TypeError as e:
assert str(e) == "For 'Tensor', the type of 'const_arg' should be 'bool', but got '1' with type 'int'."
try:
x = Tensor([0], dtype=mstype.int32)
x.set_const_arg(1)
except TypeError as e:
assert str(e) == "For 'set_const_arg', the type of 'const_arg' should be 'bool', but got '1' with type 'int'."

View File

@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""test mutable""" """test mutable"""
import numpy as np import numpy as np
import pytest import pytest
from mindspore.ops.composite import GradOperation from mindspore.ops.composite import GradOperation
@ -23,6 +22,7 @@ from mindspore.ops import operations as P
import mindspore.nn as nn import mindspore.nn as nn
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
from mindspore import Tensor from mindspore import Tensor
from mindspore._c_expression import Tensor as Tensor_
from mindspore import Parameter from mindspore import Parameter
@ -243,6 +243,49 @@ def test_dict_inputs_compile_phase():
assert phase1 == phase2 assert phase1 == phase2
def test_tensor_inputs_compile_phase():
"""
Feature: Set Constants mutable.
Description: Test whether the compilation phase for Tensor input twice are the same.
Expectation: The phases are the same.
"""
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.matmul = P.MatMul()
self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
def construct(self, x, y):
x = x * self.z
out = self.matmul(x, y)
return out
x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
p = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
q = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
net = Net()
_cell_graph_executor = _CellGraphExecutor()
# tuple of Tensor
phase1, _ = _cell_graph_executor.compile(net, x, y)
phase2, _ = _cell_graph_executor.compile(net, p, q)
assert phase1 == phase2
phase1, _ = _cell_graph_executor.compile(net, mutable(x), mutable(y))
phase2, _ = _cell_graph_executor.compile(net, mutable(p), mutable(q))
assert phase1 == phase2
x = Tensor_(x)
y = Tensor_(y)
p = Tensor_(p)
q = Tensor_(q)
phase1, _ = _cell_graph_executor.compile(net, x, y)
phase2, _ = _cell_graph_executor.compile(net, p, q)
assert phase1 == phase2
phase1, _ = _cell_graph_executor.compile(net, mutable(x), mutable(y))
phase2, _ = _cell_graph_executor.compile(net, mutable(p), mutable(q))
assert phase1 == phase2
def test_check_mutable_value(): def test_check_mutable_value():
""" """
Feature: Set Constants mutable. Feature: Set Constants mutable.