Reduce data copy for Parameter init

This commit is contained in:
He Wei 2022-06-13 19:20:21 +08:00
parent 37ae36866c
commit ab3d34aa2f
2 changed files with 21 additions and 19 deletions

View File

@ -15,6 +15,7 @@
"""Parameter for cell."""
from copy import copy
import sys
import numbers
import numpy as np
from mindspore import log as logger
@ -172,7 +173,8 @@ class Parameter(Tensor_):
def __new__(cls, default_input, *args, **kwargs):
init_data_flag = bool(isinstance(default_input, Tensor) and default_input.has_init)
input_class, *class_init_args = Parameter._get_parameter_new_args(default_input)
rc = sys.getrefcount(default_input)
input_class, *class_init_args = Parameter._get_parameter_new_args(default_input, rc)
new_type = Parameter._get_base_class(input_class)
obj = input_class.__new__(new_type)
input_class.__init__(obj, *class_init_args)
@ -233,6 +235,16 @@ class Parameter(Tensor_):
new_obj._inited_param = self._inited_param
return new_obj
def __str__(self):
    """Return a human-readable one-line summary of this parameter."""
    return (f'Parameter (name={self.name}, shape={self.shape}, '
            f'dtype={self.dtype}, requires_grad={self.requires_grad})')
def __repr__(self):
    """Reuse the readable form from ``__str__`` for the repr."""
    return str(self)
def __parameter__(self):
    """For parse check.

    NOTE(review): marker method with an empty body — presumably the
    MindSpore graph parser probes for this attribute to recognize
    Parameter objects; confirm against the parser code.
    """
@staticmethod
def _get_base_class(input_class):
input_class_name = Parameter.__name__
@ -251,12 +263,16 @@ class Parameter(Tensor_):
return False
@staticmethod
def _get_parameter_new_args(data):
def _get_parameter_new_args(data, rc):
"""Set `set_data` of current `Parameter`."""
if isinstance(data, bool):
raise ValueError('Parameter data can not be `bool`')
if isinstance(data, Tensor):
if not data.has_init:
if rc == 4:
# When the ref count is 4, the input data is not referenced anywhere
# else, so we can build the Tensor without copying the data.
return (Tensor, data)
# make a copy of Tensor to init the parameter.
return (Tensor, data.asnumpy())
@ -272,16 +288,6 @@ class Parameter(Tensor_):
return (Tensor, data, mstype.float32)
return (Tensor, data)
def __str__(self):
return f'Parameter (name={self.name}, shape={self.shape}, dtype={self.dtype}, ' \
f'requires_grad={self.requires_grad})'
def __repr__(self):
return self.__str__()
def __parameter__(self):
"""For parse check."""
def set_param_ps(self, init_in_server=False):
"""
Set whether the trainable parameter is updated by parameter server and whether the
@ -522,7 +528,6 @@ class Parameter(Tensor_):
raise TypeError("The argument `key` must be int type.")
self.param_info.key = value
@property
def requires_grad(self):
"""

View File

@ -1687,8 +1687,6 @@ class Tensor(Tensor_):
self._init_check()
return tensor_operator_registry.get('col2im')(self, output_size, kernel_size, dilation, padding_value, stride)
def reshape(self, *shape):
"""
Give a new shape to a tensor without changing its data.
@ -2790,7 +2788,7 @@ class Tensor(Tensor_):
shape = self.shape
try:
arr = np.ndarray(shape, dtype=mstype.dtype_to_nptype(self.dtype))
data = np.ndarray(shape, dtype=mstype.dtype_to_nptype(self.dtype))
except ValueError:
msg = "Error shape={}".format(shape)
logger.critical(msg)
@ -2827,14 +2825,13 @@ class Tensor(Tensor_):
self.init.seed, _ = self.seed
with seed_context(self.init):
self.init(arr)
data = np.array(arr)
self.init(data)
if opt_shard_group:
rank = get_rank(opt_shard_group)
size = get_group_size(opt_shard_group)
data = np.split(data, size)[rank]
self.init = None
self.assign_value(Tensor(data, dtype=self.dtype))
self.assign_value(Tensor_.from_numpy(data))
return self
def to_tensor(self, slice_index=None, shape=None, opt_shard_group=None):