commit fd6dc1b060

@@ -17,13 +17,13 @@ from . import dtype
 from .api import ms_function
 from .dtype import *
 from .parameter import Parameter, ParameterTuple
-from .tensor import MetaTensor, Tensor, RowTensor, SparseTensor
+from .tensor import Tensor, RowTensor, SparseTensor
 from .seed import set_seed, get_seed


 __all__ = dtype.__all__
 __all__.extend([
-    "MetaTensor", "Tensor", "RowTensor", "SparseTensor",  # tensor
+    "Tensor", "RowTensor", "SparseTensor",  # tensor
     'ms_function',  # api
     'Parameter', 'ParameterTuple',  # parameter
     "dtype",

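The hunk above drops `MetaTensor` from the package's public exports, leaving only `Tensor`, `RowTensor` and `SparseTensor`. A minimal sketch of what this implies for callers, assuming a hypothetical session against a build with this commit applied:

# Sketch: public surface after this commit (hypothetical session).
from mindspore import Tensor, RowTensor, SparseTensor   # still exported

try:
    from mindspore import MetaTensor                    # no longer in __all__
except ImportError:
    print("MetaTensor was removed from the public API")
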
@@ -100,7 +100,6 @@ class Parameter(Tensor_):
         ...     def construct(self, x):
         ...         out = self.matmul(self.weight, x)
         ...         return out
-        >>> context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
         >>> net = Net()
         >>> x = Tensor(np.ones((2,1)))
         >>> print(net(x))

@@ -113,15 +112,14 @@ class Parameter(Tensor_):
     __base_type__ = {}

     def __new__(cls, default_input, *args, **kwargs):
+        init_data_flag = bool(isinstance(default_input, Tensor) and default_input.has_init)
         input_class, *class_init_args = Parameter._get_parameter_new_args(default_input)
         new_type = Parameter._get_base_class(input_class)
         obj = input_class.__new__(new_type)
         input_class.__init__(obj, *class_init_args)
         # it's better to make the Initializer a kind of tensor.
         obj.init_mode = None
-        obj.is_default_input_init = False
-        if isinstance(default_input, Tensor) and default_input.has_init:
-            obj.is_default_input_init = True
+        obj.is_default_input_init = init_data_flag
         if obj.has_init:
             obj.init_mode = default_input
         return obj

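The rewrite computes `init_data_flag` up front, before `default_input` is handed to `_get_parameter_new_args`, and collapses the three-line flag assignment into one. A hedged sketch of the behaviour this preserves, using the `initializer` helper from `mindspore.common.initializer` (the commented values are what the diff implies, not verified output):

import mindspore as ms
from mindspore import Parameter
from mindspore.common.initializer import initializer

# default_input is an initializer-backed Tensor, so has_init is True and
# __new__ records is_default_input_init / init_mode before any data exists.
w = Parameter(initializer('normal', [2, 3], ms.float32), name='w')
w.init_data()    # materializes the deferred initializer in place
print(w.shape)   # (2, 3)
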
@@ -19,11 +19,10 @@ from mindspore import log as logger
 from mindspore.communication.management import get_rank, get_group_size
 from . import dtype as mstype
 from ._register_for_tensor import tensor_operator_registry
-from .._c_expression import MetaTensor
 from .._c_expression import Tensor as Tensor_
 from .._checkparam import Validator as validator

-__all__ = ['Tensor', 'MetaTensor', 'RowTensor', 'SparseTensor']
+__all__ = ['Tensor', 'RowTensor', 'SparseTensor']
 np_types = (np.int8, np.int16, np.int32, np.int64,
             np.uint8, np.uint16, np.uint32, np.uint64, np.float16,
             np.float32, np.float64, np.bool_)

@@ -41,6 +40,9 @@ class Tensor(Tensor_):
         dtype (:class:`mindspore.dtype`): Input data should be None, bool or numeric type defined in `mindspore.dtype`.
             The argument is used to define the data type of the output tensor. If it is None, the data type of the
             output tensor will be as same as the `input_data`. Default: None.
+        shape (Union[tuple, list, int]): A list of integers, a tuple of integers or an integer as the shape of
+            output. Default: None.
+        init (class:'Initializer'): the information of init data.

     Outputs:
         Tensor, with the same shape as `input_data`.

@@ -65,6 +67,12 @@ class Tensor(Tensor_):
         if isinstance(input_data, np_types):
             input_data = np.array(input_data)

+        if input_data is not None and shape is not None and input_data.shape != shape:
+            raise ValueError("input_data.shape and shape should be same.")
+        if init is not None and (shape is None or dtype is None):
+            raise ValueError("init, dtype and shape must have values at the same time.")
+        if ((input_data is not None and init is None) or (input_data is None and init is not None)) is False:
+            raise TypeError("input_data and init can not be None at the same time.")

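Together, the three new checks enforce that a Tensor is built from exactly one of `input_data` or `init`, and that the deferred-`init` path always carries an explicit `shape` and `dtype`. A short sketch of the two legal call patterns, assuming the `One` initializer from `mindspore.common.initializer`:

import numpy as np
import mindspore as ms
from mindspore import Tensor
from mindspore.common.initializer import One

# Eager path: data is given, shape/dtype come from input_data.
t1 = Tensor(np.ones((2, 3)), dtype=ms.float32)

# Deferred path: init is given, so shape and dtype are mandatory.
t2 = Tensor(shape=(2, 3), dtype=ms.float32, init=One())

# Supplying both input_data and init, or neither, now raises TypeError.
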
@@ -306,10 +314,12 @@ class Tensor(Tensor_):

     def asnumpy(self):
         """Convert tensor to numpy array."""
+        self.init_check()
         return Tensor_.asnumpy(self)

     def _flush_from_cache(self):
         """Flush cache data to host if tensor is cache enable."""
+        self.init_check()
         Tensor_._flush_from_cache(self)

     def all(self, axis=(), keep_dims=False):

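`asnumpy` and `_flush_from_cache` now run `init_check()` first, so a deferred-init tensor is materialized on demand rather than handing an uninitialized buffer to the C++ layer. A minimal sketch, again assuming the `One` initializer:

import mindspore as ms
from mindspore import Tensor
from mindspore.common.initializer import One

t = Tensor(shape=(2, 3), dtype=ms.float32, init=One())
arr = t.asnumpy()   # init_check() triggers init_data() transparently
print(arr.sum())    # expected: 6.0
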
@@ -327,6 +337,7 @@ class Tensor(Tensor_):
             Tensor, has the same data type as x.
         """
+        self.init_check()
         if axis is None:
             axis = ()
         return tensor_operator_registry.get('all')(keep_dims)(self, axis)

@@ -346,6 +357,7 @@ class Tensor(Tensor_):
             Tensor, has the same data type as x.
         """
+        self.init_check()
         if axis is None:
             axis = ()
         return tensor_operator_registry.get('any')(keep_dims)(self, axis)

@@ -360,6 +372,7 @@ class Tensor(Tensor_):
         Returns:
             Tensor, has the same dimension as the input shape.
         """
+        self.init_check()
         if not shape:
             raise ValueError("The shape variable should not be empty")
         if isinstance(shape[0], tuple):

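`view` accepts the target shape either positionally or as a single tuple, which is why the body inspects `shape[0]`; `init_check()` now runs before that validation. A small usage sketch:

import numpy as np
import mindspore as ms
from mindspore import Tensor

t = Tensor(np.arange(6), ms.float32)
print(t.view(2, 3).shape)    # (2, 3): axes passed positionally
print(t.view((3, 2)).shape)  # (3, 2): or as one tuple, per the shape[0] check
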
@@ -379,6 +392,7 @@ class Tensor(Tensor_):
         Returns:
             Tensor, has the same dimension as input tensor.
         """
+        self.init_check()
         return tensor_operator_registry.get('broadcast_to')(x.shape)(self)

     def abs(self):

@@ -388,6 +402,7 @@ class Tensor(Tensor_):
         Returns:
             Tensor, has the same data type as x.
         """
+        self.init_check()
         return tensor_operator_registry.get('abs')()(self)

     def mean(self, axis=(), keep_dims=False):

@@ -404,6 +419,7 @@ class Tensor(Tensor_):
         Returns:
             Tensor, has the same data type as x.
         """
+        self.init_check()
         if axis is None:
             axis = ()
         return tensor_operator_registry.get('mean')(keep_dims)(self, axis)

@@ -429,6 +445,7 @@ class Tensor(Tensor_):
         Returns:
             Tensor, has the same dimension as input tensor, with axes suitably permuted.
         """
+        self.init_check()
         perm = validator.check_transpose_axis(axes, self.ndim)
         return tensor_operator_registry.get('transpose')()(self, perm)

@@ -446,6 +463,7 @@ class Tensor(Tensor_):
             reshaped_tensor(Tensor): This will be a new view object if possible;
                 otherwise, it will be a copy.
         """
+        self.init_check()
         new_shape = validator.check_reshape_shp(shape)
         return tensor_operator_registry.get('reshape')()(self, new_shape)

@@ -457,6 +475,7 @@ class Tensor(Tensor_):
         Returns:
             Tensor, has the same data type as x.
         """
+        self.init_check()
         reshape_op = tensor_operator_registry.get('reshape')()
         return reshape_op(self, (-1,))

@@ -472,6 +491,7 @@ class Tensor(Tensor_):
         Returns:
             Tensor, has the same data type as x.
         """
+        self.init_check()
         reshape_op = tensor_operator_registry.get('reshape')()
         trans_op = tensor_operator_registry.get('transpose')()

@@ -493,6 +513,7 @@ class Tensor(Tensor_):
         Returns:
             Transposed tensor, has the same data type as the original tensor x.
         """
+        self.init_check()
         axis1, axis2 = validator.check_swapaxes_axis((axis1, axis2), self.ndim)

         if axis1 == axis2:

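`swapaxes` normalizes both axes through the validator and, per the `axis1 == axis2` branch just below, returns early when they coincide. A quick sketch:

import numpy as np
import mindspore as ms
from mindspore import Tensor

t = Tensor(np.zeros((2, 3, 4)), ms.float32)
print(t.swapaxes(0, 2).shape)  # (4, 3, 2)
print(t.swapaxes(1, 1).shape)  # (2, 3, 4): equal axes short-circuit
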
@@ -521,6 +542,7 @@ class Tensor(Tensor_):
         Returns:
             Tensor, with all or a subset of the dimensions of length 1 removed.
         """
+        self.init_check()
         if axis is None:
             return tensor_operator_registry.get('squeeze')(self)
         new_shape = validator.prepare_shape_for_squeeze(self.shape, axis)

@@ -542,12 +564,18 @@ class Tensor(Tensor_):
         Returns:
             Tensor, with the designated dtype.
         """
+        self.init_check()
         dtype = validator.check_astype_dtype(dtype)
         if not copy and dtype == self.dtype:
             return self
         return tensor_operator_registry.get('cast')(self, dtype)

+    def init_check(self):
+        if self.has_init:
+            self.init_data()
+        return self
+
     def init_data(self, slice_index=None, shape=None, opt_shard_group=None):
         """
         Get the tensor format data of this Tensor.

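`init_check()` is the gate all of the methods above now pass through: if the tensor still carries a deferred initializer (`has_init`), it calls `init_data()` once and then returns `self`. A hedged sketch of the observable life cycle, assuming the `Normal` initializer:

import mindspore as ms
from mindspore import Tensor
from mindspore.common.initializer import Normal

t = Tensor(shape=(2, 3), dtype=ms.float32, init=Normal())
print(t.has_init)         # True: value generation is still deferred
t.init_data()             # runs the initializer; self.init is reset to None
print(t.asnumpy().shape)  # (2, 3); later init_check() calls are no-ops
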
@@ -601,7 +629,9 @@ class Tensor(Tensor_):
             rank = get_rank(opt_shard_group)
             size = get_group_size(opt_shard_group)
             data = np.split(data, size)[rank]
-        return Tensor(data, dtype=self.dtype)
+        self.init = None
+        self.assign_value(Tensor(data, dtype=self.dtype))
+        return self

     def to_tensor(self, slice_index=None, shape=None, opt_shard_group=None):

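The key behavioural change in `init_data` is at the end: instead of returning a fresh Tensor, it assigns the generated value into `self` and returns `self`, so the original object remains usable. The optimizer-shard branch above it keeps only this rank's slice of the full array; a NumPy-only sketch of that slicing, with `rank` and `size` as hypothetical stand-ins for `get_rank(opt_shard_group)` and `get_group_size(opt_shard_group)`:

import numpy as np

rank, size = 1, 4                      # hypothetical group coordinates
data = np.arange(8, dtype=np.float32)  # fully initialized value
shard = np.split(data, size)[rank]     # each rank keeps one equal slice
print(shard)                           # [2. 3.]
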
@@ -16,6 +16,7 @@
 import numpy as np

 import mindspore as ms
+import mindspore.common.initializer as init
 from mindspore.common.api import _executor
 from mindspore.nn import Cell
 from mindspore.ops import operations as P

@@ -99,6 +100,12 @@ def test_asnumpy():
     assert a.asnumpy().all() == npd.all()


+def test_initializer_asnumpy():
+    npd = np.ones((2, 3))
+    a = init.initializer('one', [2, 3], ms.int32)
+    assert a.asnumpy().all() == npd.all()
+
+
 def test_print():
     a = ms.Tensor(np.ones((2, 3)))
     a.set_dtype(ms.int32)

@@ -13,7 +13,9 @@
 # limitations under the License.
 # ============================================================================
 """ test expand_as"""
+import mindspore as ms
 import mindspore.nn as nn
+import mindspore.common.initializer as init
 from mindspore import Tensor
 from mindspore import context

@@ -34,6 +36,20 @@ def test_expand_as():
     net()


+def test_initializer_expand_as():
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+            self.t1 = init.initializer('one', [1, 3], ms.float32)
+            self.t2 = init.initializer('one', [2, 3], ms.float32)
+
+        def construct(self):
+            return self.t1.expand_as(self.t2)
+
+    net = Net()
+    net()
+
+
 def test_expand_as_parameter():
     class Net(nn.Cell):
         def __init__(self):

@@ -15,7 +15,9 @@
 """ test view"""
 import pytest

+import mindspore as ms
 import mindspore.nn as nn
+import mindspore.common.initializer as init
 from mindspore import Tensor
 from mindspore import context

@@ -35,6 +37,19 @@ def test_view():
     net()


+def test_view_initializer():
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+            self.value = init.initializer('normal', [2, 3], ms.float32)
+
+        def construct(self):
+            return self.value.view(-1)
+
+    net = Net()
+    net()
+
+
 def test_view_1():
     class Net(nn.Cell):
         def __init__(self):