!33066 Add docs and ut for Variable api

Merge pull request !33066 from YuJianfeng/mutable
This commit is contained in:
i-robot 2022-04-16 03:26:26 +00:00 committed by Gitee
commit 7e75db0979
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 107 additions and 7 deletions

View File

@ -15,7 +15,7 @@
"""Variable class for setting constants mutable."""
from .._c_expression import Variable_
from ..common.tensor import Tensor, CSRTensor, COOTensor
from ..common.tensor import Tensor
class Variable(Variable_):
@ -26,11 +26,17 @@ class Variable(Variable_):
A 'mutable' constant input means that it is changed to be a variable input just like Tensor and the most important
thing is that it is differentiable from now on.
Besides, currently when the network input is tuple[Tensor], list[Tensor] or dict[Tensor], if the value of tensor is
changed without changing the shape and dtype, the network will be re-compiled because the these inputs are regarded
as constant values. Now we can avoid this problem by using 'Variable' to store these inputs.
.. warning::
This is an experimental prototype that is subject to change or deletion.
- This is an experimental prototype that is subject to change or deletion.
- The runtime has not yet supported to handle the scalar data flow. So we only support tuple[Tensor],
list[Tensor] or dict[Tensor] for network input to avoid the re-compiled problem now.
Args:
value (Union[bool, float, int, tuple, list, dict, Tensor]): The value to be stored.
value (Union[tuple[Tensor], list[Tensor], dict[Tensor]]): The value to be stored.
Examples:
>>> import mindspore.nn as nn
@ -60,10 +66,10 @@ class Variable(Variable_):
"""
def __init__(self, value):
if not isinstance(value, (bool, int, float, tuple, list, dict, Tensor, COOTensor, CSRTensor)):
if isinstance(value, Tensor) or not self._check_all_tensor(value):
raise TypeError(
f"For 'Varibale', the 'value' should be one of (int, float, tuple, list, dict, Tensor, COOTensor, "
f"CSRTensor), but got {type(value).__name__}")
f"For 'Varibale', the 'value' should be one of (tuple[Tensor], list[Tensor], dict[Tensor]) "
f"or their nested structures, but got {value}")
Variable_.__init__(self, value)
self._value = value
@ -74,3 +80,17 @@ class Variable(Variable_):
@value.setter
def value(self, value):
self._value = value
def _check_all_tensor(self, value):
"""Check if all the elements are Tensor."""
if isinstance(value, (tuple, list)):
for element in value:
if not self._check_all_tensor(element):
return False
return True
if isinstance(value, dict):
for element in value.values():
if not self._check_all_tensor(element):
return False
return True
return isinstance(value, Tensor)

View File

@ -15,6 +15,7 @@
"""test variable"""
import numpy as np
import pytest
from mindspore.ops.composite import GradOperation
from mindspore.common.variable import Variable
from mindspore.common.api import _CellGraphExecutor
@ -25,6 +26,7 @@ from mindspore import Tensor
from mindspore import Parameter
@pytest.mark.skip(reason="No runtime support")
def test_variable_scalar_mul_grad_first():
"""
Feature: Set Constants mutable.
@ -51,6 +53,7 @@ def test_variable_scalar_mul_grad_first():
assert output == 3
@pytest.mark.skip(reason="No runtime support")
def test_variable_scalar_mul_grad_all():
"""
Feature: Set Constants mutable.
@ -78,6 +81,7 @@ def test_variable_scalar_mul_grad_all():
assert output == (3, 2)
@pytest.mark.skip(reason="No runtime support")
def test_variable_tuple_or_list_scalar_mul_grad():
"""
Feature: Set Constants mutable.
@ -108,6 +112,7 @@ def test_variable_tuple_or_list_scalar_mul_grad():
assert output == (3, 2)
@pytest.mark.skip(reason="No runtime support")
def test_variable_dict_scalar_mul_grad():
"""
Feature: Set Constants mutable.
@ -134,6 +139,7 @@ def test_variable_dict_scalar_mul_grad():
assert output == (3, 2)
@pytest.mark.skip(reason="No runtime support")
def test_variable_mix_scalar_mul_grad_all():
"""
Feature: Set Constants mutable.
@ -164,7 +170,7 @@ def test_variable_mix_scalar_mul_grad_all():
def test_tuple_inputs_compile_phase():
"""
Feature: Set Constants mutable.
Description: Test whether the compilation phase for tuple input twice are the same.
Description: Test whether the compilation phase for tuple(Tensor) input twice are the same.
Expectation: The phases are the same.
"""
@ -187,9 +193,83 @@ def test_tuple_inputs_compile_phase():
q = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
net = Net()
_cell_graph_executor = _CellGraphExecutor()
# tuple of Tensor
phase1, _ = _cell_graph_executor.compile(net, (x, y))
phase2, _ = _cell_graph_executor.compile(net, (p, q))
assert phase1 != phase2
phase1, _ = _cell_graph_executor.compile(net, Variable((x, y)))
phase2, _ = _cell_graph_executor.compile(net, Variable((p, q)))
assert phase1 == phase2
# list of Tensor
phase1, _ = _cell_graph_executor.compile(net, [x, y])
phase2, _ = _cell_graph_executor.compile(net, [p, q])
assert phase1 != phase2
phase1, _ = _cell_graph_executor.compile(net, Variable([x, y]))
phase2, _ = _cell_graph_executor.compile(net, Variable([p, q]))
assert phase1 == phase2
def test_dict_inputs_compile_phase():
"""
Feature: Set Constants mutable.
Description: Test whether the compilation phase for dict(Tensor) input twice are the same.
Expectation: The phases are the same.
"""
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.matmul = P.MatMul()
self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
def construct(self, tuple_input):
x = tuple_input['a']
y = tuple_input['b']
x = x * self.z
out = self.matmul(x, y)
return out
x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
p = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
q = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
net = Net()
_cell_graph_executor = _CellGraphExecutor()
phase1, _ = _cell_graph_executor.compile(net, {'a': x, 'b': y})
phase2, _ = _cell_graph_executor.compile(net, {'a': p, 'b': q})
assert phase1 != phase2
phase1, _ = _cell_graph_executor.compile(net, Variable({'a': x, 'b': y}))
phase2, _ = _cell_graph_executor.compile(net, Variable({'a': p, 'b': q}))
assert phase1 == phase2
def test_check_variable_value():
"""
Feature: Set Constants mutable.
Description: Check the illegal variable value.
Expectation: Raise the correct error log.
"""
try:
Variable(1)
except TypeError as e:
assert "For 'Varibale', the 'value' should be one of (tuple[Tensor], list[Tensor], dict[Tensor]) or " \
"their nested structures, but got" in str(e)
try:
Variable((Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32), {'a': 2}))
except TypeError as e:
assert "For 'Varibale', the 'value' should be one of (tuple[Tensor], list[Tensor], dict[Tensor]) or " \
"their nested structures, but got" in str(e)
try:
Variable([Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32), (2,)])
except TypeError as e:
assert "For 'Varibale', the 'value' should be one of (tuple[Tensor], list[Tensor], dict[Tensor]) or " \
"their nested structures, but got" in str(e)
try:
Variable({'a': Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32), 'b': (2,)})
except TypeError as e:
assert "For 'Varibale', the 'value' should be one of (tuple[Tensor], list[Tensor], dict[Tensor]) or " \
"their nested structures, but got" in str(e)