!33066 Add docs and ut for Variable api
Merge pull request !33066 from YuJianfeng/mutable
This commit is contained in:
commit
7e75db0979
|
@ -15,7 +15,7 @@
|
|||
"""Variable class for setting constants mutable."""
|
||||
|
||||
from .._c_expression import Variable_
|
||||
from ..common.tensor import Tensor, CSRTensor, COOTensor
|
||||
from ..common.tensor import Tensor
|
||||
|
||||
|
||||
class Variable(Variable_):
|
||||
|
@ -26,11 +26,17 @@ class Variable(Variable_):
|
|||
A 'mutable' constant input means that it is changed to be a variable input just like Tensor and the most important
|
||||
thing is that it is differentiable from now on.
|
||||
|
||||
Besides, currently when the network input is tuple[Tensor], list[Tensor] or dict[Tensor], if the value of tensor is
|
||||
changed without changing the shape and dtype, the network will be re-compiled because the these inputs are regarded
|
||||
as constant values. Now we can avoid this problem by using 'Variable' to store these inputs.
|
||||
|
||||
.. warning::
|
||||
This is an experimental prototype that is subject to change or deletion.
|
||||
- This is an experimental prototype that is subject to change or deletion.
|
||||
- The runtime has not yet supported to handle the scalar data flow. So we only support tuple[Tensor],
|
||||
list[Tensor] or dict[Tensor] for network input to avoid the re-compiled problem now.
|
||||
|
||||
Args:
|
||||
value (Union[bool, float, int, tuple, list, dict, Tensor]): The value to be stored.
|
||||
value (Union[tuple[Tensor], list[Tensor], dict[Tensor]]): The value to be stored.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.nn as nn
|
||||
|
@ -60,10 +66,10 @@ class Variable(Variable_):
|
|||
"""
|
||||
|
||||
def __init__(self, value):
|
||||
if not isinstance(value, (bool, int, float, tuple, list, dict, Tensor, COOTensor, CSRTensor)):
|
||||
if isinstance(value, Tensor) or not self._check_all_tensor(value):
|
||||
raise TypeError(
|
||||
f"For 'Varibale', the 'value' should be one of (int, float, tuple, list, dict, Tensor, COOTensor, "
|
||||
f"CSRTensor), but got {type(value).__name__}")
|
||||
f"For 'Varibale', the 'value' should be one of (tuple[Tensor], list[Tensor], dict[Tensor]) "
|
||||
f"or their nested structures, but got {value}")
|
||||
Variable_.__init__(self, value)
|
||||
self._value = value
|
||||
|
||||
|
@ -74,3 +80,17 @@ class Variable(Variable_):
|
|||
@value.setter
|
||||
def value(self, value):
|
||||
self._value = value
|
||||
|
||||
def _check_all_tensor(self, value):
|
||||
"""Check if all the elements are Tensor."""
|
||||
if isinstance(value, (tuple, list)):
|
||||
for element in value:
|
||||
if not self._check_all_tensor(element):
|
||||
return False
|
||||
return True
|
||||
if isinstance(value, dict):
|
||||
for element in value.values():
|
||||
if not self._check_all_tensor(element):
|
||||
return False
|
||||
return True
|
||||
return isinstance(value, Tensor)
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
"""test variable"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from mindspore.ops.composite import GradOperation
|
||||
from mindspore.common.variable import Variable
|
||||
from mindspore.common.api import _CellGraphExecutor
|
||||
|
@ -25,6 +26,7 @@ from mindspore import Tensor
|
|||
from mindspore import Parameter
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="No runtime support")
|
||||
def test_variable_scalar_mul_grad_first():
|
||||
"""
|
||||
Feature: Set Constants mutable.
|
||||
|
@ -51,6 +53,7 @@ def test_variable_scalar_mul_grad_first():
|
|||
assert output == 3
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="No runtime support")
|
||||
def test_variable_scalar_mul_grad_all():
|
||||
"""
|
||||
Feature: Set Constants mutable.
|
||||
|
@ -78,6 +81,7 @@ def test_variable_scalar_mul_grad_all():
|
|||
assert output == (3, 2)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="No runtime support")
|
||||
def test_variable_tuple_or_list_scalar_mul_grad():
|
||||
"""
|
||||
Feature: Set Constants mutable.
|
||||
|
@ -108,6 +112,7 @@ def test_variable_tuple_or_list_scalar_mul_grad():
|
|||
assert output == (3, 2)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="No runtime support")
|
||||
def test_variable_dict_scalar_mul_grad():
|
||||
"""
|
||||
Feature: Set Constants mutable.
|
||||
|
@ -134,6 +139,7 @@ def test_variable_dict_scalar_mul_grad():
|
|||
assert output == (3, 2)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="No runtime support")
|
||||
def test_variable_mix_scalar_mul_grad_all():
|
||||
"""
|
||||
Feature: Set Constants mutable.
|
||||
|
@ -164,7 +170,7 @@ def test_variable_mix_scalar_mul_grad_all():
|
|||
def test_tuple_inputs_compile_phase():
|
||||
"""
|
||||
Feature: Set Constants mutable.
|
||||
Description: Test whether the compilation phase for tuple input twice are the same.
|
||||
Description: Test whether the compilation phase for tuple(Tensor) input twice are the same.
|
||||
Expectation: The phases are the same.
|
||||
"""
|
||||
|
||||
|
@ -187,9 +193,83 @@ def test_tuple_inputs_compile_phase():
|
|||
q = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
|
||||
net = Net()
|
||||
_cell_graph_executor = _CellGraphExecutor()
|
||||
# tuple of Tensor
|
||||
phase1, _ = _cell_graph_executor.compile(net, (x, y))
|
||||
phase2, _ = _cell_graph_executor.compile(net, (p, q))
|
||||
assert phase1 != phase2
|
||||
phase1, _ = _cell_graph_executor.compile(net, Variable((x, y)))
|
||||
phase2, _ = _cell_graph_executor.compile(net, Variable((p, q)))
|
||||
assert phase1 == phase2
|
||||
# list of Tensor
|
||||
phase1, _ = _cell_graph_executor.compile(net, [x, y])
|
||||
phase2, _ = _cell_graph_executor.compile(net, [p, q])
|
||||
assert phase1 != phase2
|
||||
phase1, _ = _cell_graph_executor.compile(net, Variable([x, y]))
|
||||
phase2, _ = _cell_graph_executor.compile(net, Variable([p, q]))
|
||||
assert phase1 == phase2
|
||||
|
||||
|
||||
def test_dict_inputs_compile_phase():
|
||||
"""
|
||||
Feature: Set Constants mutable.
|
||||
Description: Test whether the compilation phase for dict(Tensor) input twice are the same.
|
||||
Expectation: The phases are the same.
|
||||
"""
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.matmul = P.MatMul()
|
||||
self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
|
||||
|
||||
def construct(self, tuple_input):
|
||||
x = tuple_input['a']
|
||||
y = tuple_input['b']
|
||||
x = x * self.z
|
||||
out = self.matmul(x, y)
|
||||
return out
|
||||
|
||||
x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
|
||||
y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
|
||||
p = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
|
||||
q = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
|
||||
net = Net()
|
||||
_cell_graph_executor = _CellGraphExecutor()
|
||||
phase1, _ = _cell_graph_executor.compile(net, {'a': x, 'b': y})
|
||||
phase2, _ = _cell_graph_executor.compile(net, {'a': p, 'b': q})
|
||||
assert phase1 != phase2
|
||||
phase1, _ = _cell_graph_executor.compile(net, Variable({'a': x, 'b': y}))
|
||||
phase2, _ = _cell_graph_executor.compile(net, Variable({'a': p, 'b': q}))
|
||||
assert phase1 == phase2
|
||||
|
||||
|
||||
def test_check_variable_value():
|
||||
"""
|
||||
Feature: Set Constants mutable.
|
||||
Description: Check the illegal variable value.
|
||||
Expectation: Raise the correct error log.
|
||||
"""
|
||||
|
||||
try:
|
||||
Variable(1)
|
||||
except TypeError as e:
|
||||
assert "For 'Varibale', the 'value' should be one of (tuple[Tensor], list[Tensor], dict[Tensor]) or " \
|
||||
"their nested structures, but got" in str(e)
|
||||
|
||||
try:
|
||||
Variable((Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32), {'a': 2}))
|
||||
except TypeError as e:
|
||||
assert "For 'Varibale', the 'value' should be one of (tuple[Tensor], list[Tensor], dict[Tensor]) or " \
|
||||
"their nested structures, but got" in str(e)
|
||||
|
||||
try:
|
||||
Variable([Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32), (2,)])
|
||||
except TypeError as e:
|
||||
assert "For 'Varibale', the 'value' should be one of (tuple[Tensor], list[Tensor], dict[Tensor]) or " \
|
||||
"their nested structures, but got" in str(e)
|
||||
|
||||
try:
|
||||
Variable({'a': Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32), 'b': (2,)})
|
||||
except TypeError as e:
|
||||
assert "For 'Varibale', the 'value' should be one of (tuple[Tensor], list[Tensor], dict[Tensor]) or " \
|
||||
"their nested structures, but got" in str(e)
|
||||
|
|
Loading…
Reference in New Issue