forked from mindspore-Ecosystem/mindspore
Allow tensor to be set const for network argument
This commit is contained in:
parent
47b9fd0a42
commit
06510b0649
|
@ -0,0 +1,15 @@
|
|||
mindspore.Tensor.set_const_arg
|
||||
==============================
|
||||
|
||||
.. py:method:: mindspore.Tensor.set_const_arg(const_arg=True)
|
||||
|
||||
指定该Tensor在作为网络入参时是否是一个常量。
|
||||
|
||||
参数:
|
||||
- **const_arg** (bool) - Tensor在作为网络入参时是否是一个常量。默认值:True。
|
||||
|
||||
返回:
|
||||
Tensor,被指定了是否是一个常量网络入参。
|
||||
|
||||
异常:
|
||||
- **TypeError** - 如果`const_arg`不是一个布尔值。
|
|
@ -11,6 +11,7 @@ mindspore.Tensor
|
|||
- **shape** (Union[tuple, list, int]) - 用于定义该Tensor的形状。如果指定了 `input_data` ,则无需设置该参数。默认值:None。
|
||||
- **init** (Initializer) - 用于在并行模式中延迟Tensor的数据的初始化,如果指定该参数,则 `dtype` 和 `shape` 也必须被指定。不推荐在非自动并行之外的场景下使用该接口。只有当调用 `Tensor.init_data` 时,才会使用指定的 `init` 来初始化Tensor数据。默认值:None。
|
||||
- **internal** (bool) - Tensor是否由框架创建。如果为True,表示Tensor是由框架创建的,如果为False,表示Tensor是由用户创建的。默认值:False。
|
||||
- **const_arg** (bool) - 指定该Tensor作为网络输入时是否为常量。默认值:False。
|
||||
|
||||
输出:
|
||||
Tensor。
|
||||
|
@ -251,3 +252,4 @@ Parameter操作方法
|
|||
:nosignatures:
|
||||
|
||||
mindspore.Tensor.flush_from_cache
|
||||
mindspore.Tensor.set_const_arg
|
||||
|
|
|
@ -16,7 +16,6 @@ mindspore.mutable
|
|||
.. warning::
|
||||
- 这是一个实验特性,未来有可能被修改或删除。
|
||||
- 目前运行时暂时不支持处理标量数据流,所以我们目前只支持Tensor、tuple[Tensor]、list[Tensor]或dict[Tensor]作为输入,主要解决重复编译的问题。
|
||||
- Tensor默认就是可变的,当 `input_data` 为Tensor时,我们不做任何处理直接返回原Tensor。
|
||||
- 当前暂时只支持在网络外部使用该接口。
|
||||
- 当前该接口只在图模式下生效。
|
||||
|
||||
|
|
|
@ -464,8 +464,7 @@ bool EnableGradForScalar(const AbstractBasePtr &abs) {
|
|||
bool CanGradArgument(const AbstractTuplePtr &tuple_arg, size_t pos) {
|
||||
MS_EXCEPTION_IF_NULL(tuple_arg);
|
||||
return tuple_arg->size() > pos && (*tuple_arg)[pos] != nullptr &&
|
||||
((*tuple_arg)[pos]->isa<AbstractUndetermined>() || (*tuple_arg)[pos]->BuildValue() == kAnyValue ||
|
||||
EnableGradForScalar((*tuple_arg)[pos]));
|
||||
((*tuple_arg)[pos]->BuildValue() == kAnyValue || EnableGradForScalar((*tuple_arg)[pos]));
|
||||
}
|
||||
|
||||
void GenerateFuncGraphByPosition(const FuncGraphPtr &fg, const AbstractTuplePtr &tuple_arg,
|
||||
|
|
|
@ -102,8 +102,7 @@ void UpdateFuncGraphParameter(const FuncGraphPtr &func_graph, const std::vector<
|
|||
|
||||
AbstractBasePtr param_abs = param_node->abstract();
|
||||
MS_EXCEPTION_IF_NULL(param_abs);
|
||||
if (param_abs->isa<abstract::AbstractUndetermined>() || param_abs->BuildValue() == kAnyValue ||
|
||||
EnableGradForScalar(param_abs) || EnableTupleBroaden(param_abs)) {
|
||||
if (param_abs->BuildValue() == kAnyValue || EnableGradForScalar(param_abs) || EnableTupleBroaden(param_abs)) {
|
||||
new_paras.push_back(param_node);
|
||||
} else {
|
||||
MS_LOG(INFO) << "Remove the " << i << "th parameter, since it's passed a constant argument.";
|
||||
|
|
|
@ -233,6 +233,11 @@ void ConvertLoadedGraph(const FuncGraphPtr &func_graph, const ValuePtr &value) {
|
|||
BroadenCNodeAbstract(resolved_graph);
|
||||
}
|
||||
|
||||
bool HasConstArgAttr(const py::object &obj) {
|
||||
constexpr char const_arg_attr[] = "const_arg";
|
||||
return py::hasattr(obj, const_arg_attr) && py::cast<bool>(py::getattr(obj, const_arg_attr));
|
||||
}
|
||||
|
||||
AnfNodePtr ConvertObjectToNode(const AnfNodePtr &origin_node, const py::object &obj, const FuncGraphPtr &func_graph) {
|
||||
// When the cell is set recomputed, it should not use old scope from cache.
|
||||
MS_EXCEPTION_IF_NULL(origin_node);
|
||||
|
@ -253,6 +258,10 @@ AnfNodePtr ConvertObjectToNode(const AnfNodePtr &origin_node, const py::object &
|
|||
AnfNodePtr output = NewValueNode(convert_result);
|
||||
if (convert_result->isa<tensor::Tensor>()) {
|
||||
output = GetMixedPrecisionCastHelp(func_graph, output);
|
||||
if (HasConstArgAttr(obj)) {
|
||||
MS_LOG(WARNING) << "The tensor " << convert_result->ToString()
|
||||
<< " which is not used for network input argument should not be set const.";
|
||||
}
|
||||
}
|
||||
return output;
|
||||
}
|
||||
|
|
|
@ -129,11 +129,30 @@ bool CheckAllTensor(const ValueTuplePtr &value_tuple) {
|
|||
return true;
|
||||
}
|
||||
|
||||
AbstractBasePtr ArgsToAbstract(const ValuePtr &value, bool enable_tuple_broaden = false, bool set_mutable = false) {
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
bool broaden = value->isa<MetaTensor>() || set_mutable || value->isa<MetaSparseTensor>() ||
|
||||
(enable_tuple_broaden && value->isa<ValueTuple>() && CheckAllTensor(value->cast<ValueTuplePtr>())) ||
|
||||
(MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && value->isa<Scalar>());
|
||||
bool Mutable(const py::object &obj) {
|
||||
constexpr char mutable_attr[] = "__ms_mutable__";
|
||||
return py::hasattr(obj, mutable_attr) && py::cast<bool>(py::getattr(obj, mutable_attr));
|
||||
}
|
||||
|
||||
bool TensorArgMutable(const py::object &obj, const ValuePtr &value) {
|
||||
if (!value->isa<MetaTensor>()) {
|
||||
return false;
|
||||
}
|
||||
constexpr char const_arg_attr[] = "const_arg";
|
||||
return !py::hasattr(obj, const_arg_attr) || !py::cast<bool>(py::getattr(obj, const_arg_attr));
|
||||
}
|
||||
|
||||
bool EnableTupleBroaden(const ValuePtr &value, bool enable_tuple_broaden) {
|
||||
return enable_tuple_broaden && value->isa<ValueTuple>() && CheckAllTensor(value->cast<ValueTuplePtr>());
|
||||
}
|
||||
|
||||
bool GradForScalar(const ValuePtr &value) {
|
||||
return MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && value->isa<Scalar>();
|
||||
}
|
||||
|
||||
AbstractBasePtr ArgsToAbstract(const py::object &arg, const ValuePtr &value, bool enable_tuple_broaden = false) {
|
||||
bool broaden = TensorArgMutable(arg, value) || Mutable(arg) || value->isa<MetaSparseTensor>() ||
|
||||
EnableTupleBroaden(value, enable_tuple_broaden) || GradForScalar(value);
|
||||
|
||||
return abstract::FromValue(value, broaden);
|
||||
}
|
||||
|
@ -208,33 +227,6 @@ void RecordInitStatus() {
|
|||
|
||||
void RecordExitStatus() { MS_LOG(INFO) << "Status record: system exit."; }
|
||||
|
||||
void SetValueMutable(const abstract::AbstractBasePtr &abs) {
|
||||
MS_EXCEPTION_IF_NULL(abs);
|
||||
if (abs->isa<abstract::AbstractTensor>()) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto abs_sequence = abs->cast_ptr<abstract::AbstractSequence>();
|
||||
if (abs_sequence != nullptr) {
|
||||
const auto &elements = abs_sequence->elements();
|
||||
for (auto &ele : elements) {
|
||||
SetValueMutable(ele);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
auto abs_dict = abs->cast_ptr<abstract::AbstractDictionary>();
|
||||
if (abs_dict != nullptr) {
|
||||
const auto &elements = abs_dict->elements();
|
||||
for (auto &ele : elements) {
|
||||
SetValueMutable(ele.second);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
abs->set_value_mutable(true);
|
||||
}
|
||||
|
||||
std::string ToOrdinal(const size_t &i) {
|
||||
auto suffix = "th";
|
||||
if (i == kIndex1) {
|
||||
|
@ -291,13 +283,7 @@ py::object GraphExecutorPy::GenerateArgumentsKey(const py::tuple &args, bool ena
|
|||
MS_EXCEPTION(TypeError) << "parse::ConvertData for " << i << "th argument failed, the argument type is "
|
||||
<< args[i].get_type() << ", value is '" << py::str(args[i]) << "'.";
|
||||
}
|
||||
constexpr char mutable_attr[] = "__ms_mutable__";
|
||||
bool set_mutable = false;
|
||||
if (py::hasattr(args[i], mutable_attr) && py::cast<bool>(py::getattr(args[i], mutable_attr))) {
|
||||
SetValueMutable(converted->ToAbstract());
|
||||
set_mutable = true;
|
||||
}
|
||||
AbstractBasePtr abs = ArgsToAbstract(converted, enable_tuple_broaden, set_mutable);
|
||||
AbstractBasePtr abs = ArgsToAbstract(args[i], converted, enable_tuple_broaden);
|
||||
(void)args_abs.emplace_back(abs);
|
||||
// The 'converted' maybe a Parameter, we need connect it to the Parameter of func graph,
|
||||
// so we keep all inputs for subsequent procedure.
|
||||
|
@ -880,7 +866,7 @@ bool GraphExecutorPy::CompileInner(const py::object &source_obj, const py::tuple
|
|||
MS_LOG(EXCEPTION) << "Fail to convert the " << i << "th argument, args[" << i << "]: " << py::str(args[i]);
|
||||
}
|
||||
(void)arguments.emplace_back(converted);
|
||||
auto args_abstract_item = ArgsToAbstract(converted, enable_tuple_broaden_);
|
||||
auto args_abstract_item = ArgsToAbstract(args[i], converted, enable_tuple_broaden_);
|
||||
if (is_auto_parallel) {
|
||||
(void)parallel::ExtendInputArgsAbstractShape(args_abstract_item, i);
|
||||
}
|
||||
|
|
|
@ -2134,6 +2134,10 @@ class PyInterpretEvaluator : public TransitionPrimEvaluator {
|
|||
MS_LOG(EXCEPTION) << "Convert the python object failed";
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(converted_val);
|
||||
if (converted_val->isa<tensor::Tensor>() && HasConstArgAttr(obj)) {
|
||||
MS_LOG(WARNING) << "The tensor " << converted_val->ToString()
|
||||
<< " which is not used for network input argument should not be set const.";
|
||||
}
|
||||
AbstractBasePtr res = ToAbstract(converted_val, AnalysisContext::DummyContext(), out_conf);
|
||||
auto infer_result = std::make_shared<EvalResult>(res, std::make_shared<AttrValueMap>());
|
||||
evaluator_cache_mgr_->SetValue(args_spec_list, infer_result);
|
||||
|
@ -2248,6 +2252,11 @@ class PyInterpretEvaluator : public TransitionPrimEvaluator {
|
|||
});
|
||||
return std::make_shared<AbstractDictionary>(kv);
|
||||
}
|
||||
|
||||
bool HasConstArgAttr(const py::object &obj) {
|
||||
constexpr char const_arg_attr[] = "const_arg";
|
||||
return py::hasattr(obj, const_arg_attr) && py::cast<bool>(py::getattr(obj, const_arg_attr));
|
||||
}
|
||||
};
|
||||
|
||||
class PartialEvaluator : public Evaluator {
|
||||
|
|
|
@ -221,6 +221,27 @@ def _restore_mutable_attr(args_list, compile_args):
|
|||
return new_compile_args
|
||||
|
||||
|
||||
def _get_args_for_run(obj, args_list):
|
||||
"""Get the actual input args for runtime."""
|
||||
inputs = []
|
||||
for i in args_list:
|
||||
if isinstance(i, PythonTensor):
|
||||
if i.has_init:
|
||||
i.init_data()
|
||||
if not i.const_arg:
|
||||
inputs.append(i)
|
||||
elif isinstance(i, (Tensor, CSRTensor, COOTensor)):
|
||||
inputs.append(i)
|
||||
elif hasattr(i, "__ms_mutable__") and getattr(i, "__ms_mutable__"):
|
||||
inputs.append(i)
|
||||
elif context.get_context("grad_for_scalar") and isinstance(i, (int, float)):
|
||||
inputs.append(i)
|
||||
elif hasattr(obj, "enable_tuple_broaden") and obj.enable_tuple_broaden and isinstance(i, tuple) and \
|
||||
_check_all_tensor(i):
|
||||
inputs.append(i)
|
||||
return inputs
|
||||
|
||||
|
||||
class _MindsporeFunctionExecutor:
|
||||
"""
|
||||
Represents a function compiled by graph compiler.
|
||||
|
@ -443,17 +464,7 @@ class _MindsporeFunctionExecutor:
|
|||
Returns:
|
||||
new_inputs, new input args, which are required for running.
|
||||
"""
|
||||
new_inputs = []
|
||||
for i in args_list:
|
||||
if isinstance(i, (Tensor, CSRTensor, COOTensor)):
|
||||
new_inputs.append(i)
|
||||
elif hasattr(i, "__ms_mutable__") and getattr(i, "__ms_mutable__"):
|
||||
new_inputs.append(i)
|
||||
elif context.get_context("grad_for_scalar") and isinstance(i, (int, float)):
|
||||
new_inputs.append(i)
|
||||
elif self.enable_tuple_broaden and isinstance(i, tuple) and _check_all_tensor(i):
|
||||
new_inputs.append(i)
|
||||
return new_inputs
|
||||
return _get_args_for_run(self, args_list)
|
||||
|
||||
|
||||
# The attributes used to identify a given object.
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
from __future__ import absolute_import
|
||||
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore._c_expression import Tensor as Tensor_
|
||||
|
||||
|
||||
class _Tuple(tuple):
|
||||
|
@ -42,7 +43,7 @@ def _check_all_tensor(value):
|
|||
if not _check_all_tensor(element):
|
||||
return False
|
||||
return True
|
||||
return isinstance(value, Tensor)
|
||||
return isinstance(value, Tensor_)
|
||||
|
||||
|
||||
def mutable(input_data):
|
||||
|
@ -67,8 +68,6 @@ def mutable(input_data):
|
|||
- This is an experimental prototype that is subject to change or deletion.
|
||||
- The runtime has not yet supported to handle the scalar data flow. So we only support tuple[Tensor],
|
||||
list[Tensor] or dict[Tensor] for network input to avoid the re-compiled problem now.
|
||||
- Tensor is mutable by default, when the `input_data` is Tensor, we just return the origin Tensor and nothing
|
||||
is done.
|
||||
- Currently we only support to use this api outside the network temporarily.
|
||||
- Currently this api only works in GRAPH mode.
|
||||
|
||||
|
@ -122,9 +121,6 @@ def mutable(input_data):
|
|||
[ 1.50000000e+00, 1.50000000e+00, 1.50000000e+00]]))
|
||||
"""
|
||||
|
||||
if isinstance(input_data, Tensor):
|
||||
return input_data
|
||||
|
||||
if not _check_all_tensor(input_data):
|
||||
raise TypeError(
|
||||
f"For 'mutable', the 'input_data' should be one of (Tensor, tuple[Tensor], list[Tensor], dict[Tensor]) "
|
||||
|
@ -137,6 +133,11 @@ def mutable(input_data):
|
|||
ret = _Tuple(input_data)
|
||||
elif isinstance(input_data, dict):
|
||||
ret = _Dict(input_data)
|
||||
elif isinstance(input_data, Tensor):
|
||||
ret.set_const_arg(False)
|
||||
elif isinstance(input_data, Tensor_):
|
||||
ret = Tensor(input_data, internal=True)
|
||||
ret.set_const_arg(False)
|
||||
|
||||
setattr(ret, "__ms_mutable__", True)
|
||||
return ret
|
||||
|
|
|
@ -60,6 +60,8 @@ class Tensor(Tensor_):
|
|||
'True' means that the tensor is created by framework.
|
||||
'False' means that the tensor is created by user.
|
||||
Default: False
|
||||
const_arg (bool): Whether the tensor is a constant when it is used for the argument of a network.
|
||||
Default: False.
|
||||
|
||||
Outputs:
|
||||
Tensor.
|
||||
|
@ -116,7 +118,7 @@ class Tensor(Tensor_):
|
|||
"""
|
||||
delta_seed = 0
|
||||
|
||||
def __init__(self, input_data=None, dtype=None, shape=None, init=None, internal=False):
|
||||
def __init__(self, input_data=None, dtype=None, shape=None, init=None, internal=False, const_arg=False):
|
||||
self.init_finished = False
|
||||
if internal:
|
||||
Tensor_.__init__(self, input_data)
|
||||
|
@ -166,6 +168,8 @@ class Tensor(Tensor_):
|
|||
else:
|
||||
Tensor_.__init__(self, input_data)
|
||||
|
||||
validator.check_value_type('const_arg', const_arg, bool, 'Tensor')
|
||||
self.const_arg = const_arg
|
||||
self.virtual_flag = False
|
||||
self.init = init
|
||||
self.init_finished = True
|
||||
|
@ -191,6 +195,7 @@ class Tensor(Tensor_):
|
|||
new_obj = Tensor(self)
|
||||
new_obj.init = self.init
|
||||
new_obj.virtual_flag = self.virtual_flag
|
||||
new_obj.const_arg = self.const_arg
|
||||
return new_obj
|
||||
|
||||
def __repr__(self):
|
||||
|
@ -445,6 +450,33 @@ class Tensor(Tensor_):
|
|||
|
||||
return Tensor(Tensor_.from_numpy(array))
|
||||
|
||||
def set_const_arg(self, const_arg=True):
|
||||
"""
|
||||
Specify whether the tensor is a constant when it is used for the argument of a network.
|
||||
|
||||
Args:
|
||||
const_arg (bool): Whether the tensor is a constant when it is used for the argument of a network.
|
||||
Default: True.
|
||||
|
||||
Returns:
|
||||
Tensor, has been specified whether to be a const network argument.
|
||||
|
||||
Raises:
|
||||
TypeError: If `const_arg` is not a bool.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import numpy as np
|
||||
>>> from mindspore import Tensor
|
||||
>>> x = Tensor(np.array([[1,2,3],[4,5,6]], dtype=np.float32))
|
||||
>>> x.set_const_arg(True)
|
||||
"""
|
||||
validator.check_value_type('const_arg', const_arg, bool, 'set_const_arg')
|
||||
self.const_arg = const_arg
|
||||
return self
|
||||
|
||||
def assign_value(self, value):
|
||||
"""
|
||||
Assign another tensor value to this tensor.
|
||||
|
|
|
@ -34,9 +34,9 @@ from mindspore import context
|
|||
from mindspore._c_expression import init_pipeline, update_func_graph_hyper_params, Cell_, FuncGraph, MixedPrecisionType
|
||||
from mindspore._checkparam import Validator
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.api import _cell_graph_executor, _pynative_executor, _check_all_tensor, cells_compile_cache
|
||||
from mindspore.common.api import _cell_graph_executor, _pynative_executor, _get_args_for_run, cells_compile_cache
|
||||
from mindspore.common.parameter import Parameter, ParameterTuple
|
||||
from mindspore.common.tensor import Tensor, CSRTensor, COOTensor
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.ops.operations import Cast
|
||||
from mindspore.ops.primitive import Primitive
|
||||
from mindspore.ops.operations import _inner_ops as inner
|
||||
|
@ -965,22 +965,7 @@ class Cell(Cell_):
|
|||
self._auto_parallel_compile_and_run = True
|
||||
self.compile(*inputs)
|
||||
|
||||
new_inputs = []
|
||||
for i in inputs:
|
||||
if isinstance(i, Tensor):
|
||||
if i.has_init:
|
||||
i.init_data()
|
||||
new_inputs.append(i)
|
||||
elif isinstance(i, (CSRTensor, COOTensor)):
|
||||
new_inputs.append(i)
|
||||
elif hasattr(i, "__ms_mutable__") and getattr(i, "__ms_mutable__"):
|
||||
new_inputs.append(i)
|
||||
elif context.get_context("grad_for_scalar") and isinstance(i, (int, float)):
|
||||
new_inputs.append(i)
|
||||
elif hasattr(self, "enable_tuple_broaden") and self.enable_tuple_broaden and isinstance(i, tuple) and \
|
||||
_check_all_tensor(i):
|
||||
new_inputs.append(i)
|
||||
|
||||
new_inputs = _get_args_for_run(self, inputs)
|
||||
return _cell_graph_executor(self, *new_inputs, phase=self.phase)
|
||||
|
||||
def auto_parallel_compile_and_run(self):
|
||||
|
|
|
@ -0,0 +1,162 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""test mutable or constant tensor feature"""
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops.composite import GradOperation
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common import mutable
|
||||
from mindspore import ms_function
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_cal_constant_tensor():
|
||||
"""
|
||||
Feature: Set mutable tensor input to constant.
|
||||
Description: Get the matmul result for two constant tensor.
|
||||
Expectation: Get the correct result.
|
||||
"""
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.matmul = P.MatMul()
|
||||
|
||||
def construct(self, x, y):
|
||||
out = self.matmul(x, y)
|
||||
return out
|
||||
|
||||
x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32, const_arg=True)
|
||||
y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32, const_arg=True)
|
||||
net = Net()
|
||||
output = net(x, y)
|
||||
p = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
|
||||
q = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
|
||||
expect_output = net(p, q)
|
||||
assert np.allclose(output.asnumpy(), expect_output.asnumpy())
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_cal_constant_tensor_ms_function():
|
||||
"""
|
||||
Feature: Set mutable tensor input to constant.
|
||||
Description: Get the matmul result for two constant tensor in ms_function.
|
||||
Expectation: Get the correct result.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def net(x, y):
|
||||
out = P.MatMul()(x, y)
|
||||
return out
|
||||
|
||||
x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32, const_arg=True)
|
||||
y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32, const_arg=True)
|
||||
output = net(x, y)
|
||||
p = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
|
||||
q = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
|
||||
expect_output = net(p, q)
|
||||
assert np.allclose(output.asnumpy(), expect_output.asnumpy())
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_grad_const_arg_tensor_to_mutable():
|
||||
"""
|
||||
Feature: Set mutable tensor input to constant.
|
||||
Description: Get gradient with respect to constant tensor input.
|
||||
Expectation: Get the correct gradients.
|
||||
"""
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.matmul = P.MatMul()
|
||||
|
||||
def construct(self, x, y):
|
||||
out = self.matmul(x, y)
|
||||
return out
|
||||
|
||||
class GradNetWrtX(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super(GradNetWrtX, self).__init__()
|
||||
self.net = net
|
||||
self.grad_op = GradOperation()
|
||||
|
||||
def construct(self, x, y):
|
||||
gradient_function = self.grad_op(self.net)
|
||||
return gradient_function(x, y)
|
||||
|
||||
x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32, const_arg=True)
|
||||
y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32, const_arg=True)
|
||||
grad_net = GradNetWrtX(Net())
|
||||
# mutable api
|
||||
output = grad_net(mutable(x), y)
|
||||
expect_output = np.array([[1.4100001, 1.5999999, 6.6],
|
||||
[1.4100001, 1.5999999, 6.6]]).astype(np.float32)
|
||||
assert np.allclose(output.asnumpy(), expect_output)
|
||||
# tensor set_const_arg api
|
||||
x.set_const_arg(False)
|
||||
output = grad_net(x, y)
|
||||
expect_output = np.array([[1.4100001, 1.5999999, 6.6],
|
||||
[1.4100001, 1.5999999, 6.6]]).astype(np.float32)
|
||||
assert np.allclose(output.asnumpy(), expect_output)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_ms_function_grad_const_arg_tensor_to_mutable():
|
||||
"""
|
||||
Feature: Set mutable tensor input to constant.
|
||||
Description: Get gradient with respect to constant tensor input for ms_function.
|
||||
Expectation: Get the correct gradients.
|
||||
"""
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.matmul = P.MatMul()
|
||||
|
||||
def construct(self, x, y):
|
||||
out = self.matmul(x, y)
|
||||
return out
|
||||
|
||||
@ms_function
|
||||
def fn(x, y):
|
||||
net = Net()
|
||||
grad_op = GradOperation()
|
||||
return grad_op(net)(x, y)
|
||||
|
||||
x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32, const_arg=True)
|
||||
y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32, const_arg=True)
|
||||
# mutable api
|
||||
output = fn(mutable(x), y)
|
||||
expect_output = np.array([[1.4100001, 1.5999999, 6.6],
|
||||
[1.4100001, 1.5999999, 6.6]]).astype(np.float32)
|
||||
assert np.allclose(output.asnumpy(), expect_output)
|
||||
# tensor set_const_arg api
|
||||
x.set_const_arg(False)
|
||||
output = fn(x, y)
|
||||
expect_output = np.array([[1.4100001, 1.5999999, 6.6],
|
||||
[1.4100001, 1.5999999, 6.6]]).astype(np.float32)
|
||||
assert np.allclose(output.asnumpy(), expect_output)
|
|
@ -0,0 +1,348 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""test const tensor for network arg"""
|
||||
import time
|
||||
import numpy as np
|
||||
from mindspore.ops.composite import GradOperation
|
||||
from mindspore.common import mutable
|
||||
from mindspore.common.api import _CellGraphExecutor, _MindsporeFunctionExecutor
|
||||
from mindspore.ops import operations as P
|
||||
import mindspore.nn as nn
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import Tensor
|
||||
from mindspore import ms_function
|
||||
|
||||
|
||||
def test_tensor_compile_phase1():
|
||||
"""
|
||||
Feature: Set mutable tensor input to constant.
|
||||
Description: Test whether the compilation phase for tensor inputs twice are the same.
|
||||
Expectation: The phases are the same only when the tensor inputs are set mutable.
|
||||
"""
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.matmul = P.MatMul()
|
||||
|
||||
def construct(self, x, y):
|
||||
out = self.matmul(x, y)
|
||||
return out
|
||||
|
||||
# Init the tensors as const arguments.
|
||||
x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32, const_arg=True)
|
||||
y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32, const_arg=True)
|
||||
p = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32, const_arg=True)
|
||||
q = Tensor([[0.01, 3.0, 1.1], [1.0, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32, const_arg=True)
|
||||
net = Net()
|
||||
_cell_graph_executor = _CellGraphExecutor()
|
||||
phase1, _ = _cell_graph_executor.compile(net, x, y)
|
||||
phase2, _ = _cell_graph_executor.compile(net, p, q)
|
||||
assert phase1 != phase2
|
||||
# mutable api
|
||||
phase1, _ = _cell_graph_executor.compile(net, mutable(x), mutable(y))
|
||||
phase2, _ = _cell_graph_executor.compile(net, mutable(p), mutable(q))
|
||||
assert phase1 == phase2
|
||||
# set_mutable api of Tensor
|
||||
x.set_const_arg(False)
|
||||
y.set_const_arg(False)
|
||||
p.set_const_arg(False)
|
||||
q.set_const_arg(False)
|
||||
phase1, _ = _cell_graph_executor.compile(net, x, y)
|
||||
phase2, _ = _cell_graph_executor.compile(net, p, q)
|
||||
assert phase1 == phase2
|
||||
|
||||
|
||||
def test_ms_function_tensor_compile_phase1():
|
||||
"""
|
||||
Feature: Set mutable tensor input to constant.
|
||||
Description: Test whether the compilation phase for tensor inputs twice are the same of ms_function.
|
||||
Expectation: The phases are the same only when the tensor inputs are set mutable.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def fn(x, y):
|
||||
out = P.MatMul()(x, y)
|
||||
return out
|
||||
|
||||
x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32, const_arg=True)
|
||||
y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32, const_arg=True)
|
||||
p = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32, const_arg=True)
|
||||
q = Tensor([[0.01, 3.0, 1.1], [1.0, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32, const_arg=True)
|
||||
ms_create_time = int(time.time() * 1e9)
|
||||
_ms_function_executor = _MindsporeFunctionExecutor(fn, ms_create_time)
|
||||
# The ms_function makes the tensor inputs mutable by default
|
||||
phase1 = _ms_function_executor.compile((x, y), "fn")
|
||||
phase2 = _ms_function_executor.compile((p, q), "fn")
|
||||
assert phase1 != phase2
|
||||
# mutable api
|
||||
phase1 = _ms_function_executor.compile((mutable(x), mutable(y)), "fn")
|
||||
phase2 = _ms_function_executor.compile((mutable(p), mutable(q)), "fn")
|
||||
assert phase1 == phase2
|
||||
# set_mutable api of Tensor
|
||||
x.set_const_arg(False)
|
||||
y.set_const_arg(False)
|
||||
p.set_const_arg(False)
|
||||
q.set_const_arg(False)
|
||||
phase1 = _ms_function_executor.compile((x, y), "fn")
|
||||
phase2 = _ms_function_executor.compile((p, q), "fn")
|
||||
assert phase1 == phase2
|
||||
|
||||
|
||||
def test_tensor_compile_phase2():
|
||||
"""
|
||||
Feature: Set mutable tensor input to constant.
|
||||
Description: Test whether the compilation phase for constant tensor inputs twice are the same.
|
||||
Expectation: The phases are the same only when the tensor inputs are set mutable.
|
||||
"""
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.matmul = P.MatMul()
|
||||
|
||||
def construct(self, x, y):
|
||||
out = self.matmul(x, y)
|
||||
return out
|
||||
|
||||
x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
|
||||
y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
|
||||
p = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
|
||||
q = Tensor([[0.01, 3.0, 1.1], [1.0, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
|
||||
net = Net()
|
||||
_cell_graph_executor = _CellGraphExecutor()
|
||||
phase1, _ = _cell_graph_executor.compile(net, x, y)
|
||||
phase2, _ = _cell_graph_executor.compile(net, p, q)
|
||||
assert phase1 == phase2
|
||||
# Set const arg.
|
||||
x.set_const_arg()
|
||||
y.set_const_arg()
|
||||
p.set_const_arg()
|
||||
q.set_const_arg()
|
||||
phase1, _ = _cell_graph_executor.compile(net, x, y)
|
||||
phase2, _ = _cell_graph_executor.compile(net, p, q)
|
||||
assert phase1 != phase2
|
||||
# mutable api
|
||||
phase1, _ = _cell_graph_executor.compile(net, mutable(x), mutable(y))
|
||||
phase2, _ = _cell_graph_executor.compile(net, mutable(p), mutable(q))
|
||||
assert phase1 == phase2
|
||||
|
||||
|
||||
def test_ms_function_tensor_compile_phase2():
|
||||
"""
|
||||
Feature: Set mutable tensor input to constant.
|
||||
Description: Test whether the compilation phase for constant tensor inputs twice are the same of ms_function.
|
||||
Expectation: The phases are the same only when the tensor inputs are set mutable.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def fn(x, y):
|
||||
out = P.MatMul()(x, y)
|
||||
return out
|
||||
|
||||
x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
|
||||
y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
|
||||
p = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
|
||||
q = Tensor([[0.01, 3.0, 1.1], [1.0, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
|
||||
ms_create_time = int(time.time() * 1e9)
|
||||
_ms_function_executor = _MindsporeFunctionExecutor(fn, ms_create_time)
|
||||
phase1 = _ms_function_executor.compile((x, y), "fn")
|
||||
phase2 = _ms_function_executor.compile((p, q), "fn")
|
||||
assert phase1 == phase2
|
||||
# Set const arg.
|
||||
x.set_const_arg()
|
||||
y.set_const_arg()
|
||||
p.set_const_arg()
|
||||
q.set_const_arg()
|
||||
phase1 = _ms_function_executor.compile((x, y), "fn")
|
||||
phase2 = _ms_function_executor.compile((p, q), "fn")
|
||||
assert phase1 != phase2
|
||||
# mutable api
|
||||
phase1 = _ms_function_executor.compile((mutable(x), mutable(y)), "fn")
|
||||
phase2 = _ms_function_executor.compile((mutable(p), mutable(q)), "fn")
|
||||
assert phase1 == phase2
|
||||
|
||||
|
||||
def test_grad_constant_tensor():
|
||||
"""
|
||||
Feature: Set mutable tensor input to constant.
|
||||
Description: Get gradient with respect to the constant tensor input.
|
||||
Expectation: Get an empty gradient.
|
||||
"""
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.matmul = P.MatMul()
|
||||
|
||||
def construct(self, x, y):
|
||||
out = self.matmul(x, y)
|
||||
return out
|
||||
|
||||
class GradNetWrtX(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super(GradNetWrtX, self).__init__()
|
||||
self.net = net
|
||||
self.grad_op = GradOperation()
|
||||
|
||||
def construct(self, x, y):
|
||||
gradient_function = self.grad_op(self.net)
|
||||
return gradient_function(x, y)
|
||||
|
||||
x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32, const_arg=True)
|
||||
y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
|
||||
grad_net = GradNetWrtX(Net())
|
||||
output = grad_net(x, y)
|
||||
assert isinstance(output, tuple)
|
||||
assert output == ()
|
||||
|
||||
|
||||
def test_ms_function_grad_constant_tensor():
|
||||
"""
|
||||
Feature: Set mutable tensor input to constant.
|
||||
Description: Get gradient with respect to the constant tensor input of ms_function.
|
||||
Expectation: Get an empty gradient.
|
||||
"""
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.matmul = P.MatMul()
|
||||
|
||||
def construct(self, x, y):
|
||||
out = self.matmul(x, y)
|
||||
return out
|
||||
|
||||
@ms_function
|
||||
def fn(x, y):
|
||||
net = Net()
|
||||
grad_op = GradOperation()
|
||||
return grad_op(net)(x, y)
|
||||
|
||||
x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32, const_arg=True)
|
||||
y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
|
||||
output = fn(x, y)
|
||||
assert isinstance(output, tuple)
|
||||
assert output == ()
|
||||
|
||||
|
||||
def test_tensor_constant_folding():
|
||||
"""
|
||||
Feature: Set mutable tensor input to constant.
|
||||
Description: Get result of add operator for two constant tensor by constant folding in frontend.
|
||||
Expectation: Get a correct result.
|
||||
"""
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.add = P.Add()
|
||||
|
||||
def construct(self, x, y):
|
||||
out = self.add(x, y)
|
||||
return out
|
||||
|
||||
x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32, const_arg=True)
|
||||
y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3]], dtype=mstype.float32, const_arg=True)
|
||||
net = Net()
|
||||
output = net(x, y)
|
||||
expect_output = np.array([[0.51, 0.9, 1.5],
|
||||
[1.3, 1.5, 2.4]]).astype(np.float32)
|
||||
assert np.allclose(output.asnumpy(), expect_output)
|
||||
|
||||
|
||||
def test_ms_function_tensor_constant_folding():
|
||||
"""
|
||||
Feature: Set mutable tensor input to constant.
|
||||
Description: Get result of add operator of ms_function for two constant tensor by constant folding in frontend.
|
||||
Expectation: Get a correct result.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def fn(x, y):
|
||||
return P.Add()(x, y)
|
||||
|
||||
x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32, const_arg=True)
|
||||
y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3]], dtype=mstype.float32, const_arg=True)
|
||||
output = fn(x, y)
|
||||
expect_output = np.array([[0.51, 0.9, 1.5],
|
||||
[1.3, 1.5, 2.4]]).astype(np.float32)
|
||||
assert np.allclose(output.asnumpy(), expect_output)
|
||||
|
||||
|
||||
def test_constant_tensor_if():
|
||||
"""
|
||||
Feature: Set mutable tensor input to constant.
|
||||
Description: Get result of control flow with if for constant tensor.
|
||||
Expectation: Get the correct result.
|
||||
"""
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.z = Tensor([3], dtype=mstype.int32)
|
||||
|
||||
def construct(self, x, y):
|
||||
out = y
|
||||
if x < self.z:
|
||||
out = out + y
|
||||
return out
|
||||
|
||||
x = Tensor([0], dtype=mstype.int32, const_arg=True)
|
||||
y = Tensor([1], dtype=mstype.int32, const_arg=True)
|
||||
net = Net()
|
||||
output = net(x, y)
|
||||
expect_output = np.array([2]).astype(np.int32)
|
||||
assert np.allclose(output.asnumpy(), expect_output)
|
||||
|
||||
|
||||
def test_ms_function_constant_tensor_if():
|
||||
"""
|
||||
Feature: Set mutable tensor input to constant.
|
||||
Description: Get result of control flow with if of ms_function for constant tensor.
|
||||
Expectation: Get the correct result.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def fn(x, y):
|
||||
z = Tensor([3], dtype=mstype.int32)
|
||||
out = y
|
||||
if x < z:
|
||||
out = out + y
|
||||
return out
|
||||
|
||||
x = Tensor([0], dtype=mstype.int32, const_arg=True)
|
||||
y = Tensor([1], dtype=mstype.int32, const_arg=True)
|
||||
output = fn(x, y)
|
||||
expect_output = np.array([2]).astype(np.int32)
|
||||
assert np.allclose(output.asnumpy(), expect_output)
|
||||
|
||||
|
||||
def test_check_mutable_value():
|
||||
"""
|
||||
Feature: Set mutable tensor input to constant.
|
||||
Description: Check the illegal arg.
|
||||
Expectation: Raise the correct error log.
|
||||
"""
|
||||
try:
|
||||
x = Tensor([0], dtype=mstype.int32, const_arg=1)
|
||||
except TypeError as e:
|
||||
assert str(e) == "For 'Tensor', the type of 'const_arg' should be 'bool', but got '1' with type 'int'."
|
||||
|
||||
try:
|
||||
x = Tensor([0], dtype=mstype.int32)
|
||||
x.set_const_arg(1)
|
||||
except TypeError as e:
|
||||
assert str(e) == "For 'set_const_arg', the type of 'const_arg' should be 'bool', but got '1' with type 'int'."
|
|
@ -13,7 +13,6 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""test mutable"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from mindspore.ops.composite import GradOperation
|
||||
|
@ -23,6 +22,7 @@ from mindspore.ops import operations as P
|
|||
import mindspore.nn as nn
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import Tensor
|
||||
from mindspore._c_expression import Tensor as Tensor_
|
||||
from mindspore import Parameter
|
||||
|
||||
|
||||
|
@ -243,6 +243,49 @@ def test_dict_inputs_compile_phase():
|
|||
assert phase1 == phase2
|
||||
|
||||
|
||||
def test_tensor_inputs_compile_phase():
|
||||
"""
|
||||
Feature: Set Constants mutable.
|
||||
Description: Test whether the compilation phase for Tensor input twice are the same.
|
||||
Expectation: The phases are the same.
|
||||
"""
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.matmul = P.MatMul()
|
||||
self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
|
||||
|
||||
def construct(self, x, y):
|
||||
x = x * self.z
|
||||
out = self.matmul(x, y)
|
||||
return out
|
||||
|
||||
x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
|
||||
y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
|
||||
p = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
|
||||
q = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
|
||||
net = Net()
|
||||
_cell_graph_executor = _CellGraphExecutor()
|
||||
# tuple of Tensor
|
||||
phase1, _ = _cell_graph_executor.compile(net, x, y)
|
||||
phase2, _ = _cell_graph_executor.compile(net, p, q)
|
||||
assert phase1 == phase2
|
||||
phase1, _ = _cell_graph_executor.compile(net, mutable(x), mutable(y))
|
||||
phase2, _ = _cell_graph_executor.compile(net, mutable(p), mutable(q))
|
||||
assert phase1 == phase2
|
||||
x = Tensor_(x)
|
||||
y = Tensor_(y)
|
||||
p = Tensor_(p)
|
||||
q = Tensor_(q)
|
||||
phase1, _ = _cell_graph_executor.compile(net, x, y)
|
||||
phase2, _ = _cell_graph_executor.compile(net, p, q)
|
||||
assert phase1 == phase2
|
||||
phase1, _ = _cell_graph_executor.compile(net, mutable(x), mutable(y))
|
||||
phase2, _ = _cell_graph_executor.compile(net, mutable(p), mutable(q))
|
||||
assert phase1 == phase2
|
||||
|
||||
|
||||
def test_check_mutable_value():
|
||||
"""
|
||||
Feature: Set Constants mutable.
|
||||
|
|
Loading…
Reference in New Issue