forked from mindspore-Ecosystem/mindspore
Support weight compile according to shape
Signed-off-by: candanzg <zhangshucheng@huawei.com>
This commit is contained in:
parent
25b2424f9b
commit
2cc85bdc93
|
@ -319,6 +319,10 @@ void BaseDigraph::FuncGraphParameters(const FuncGraphPtr &key) {
|
|||
std::shared_ptr<tensor::Tensor> m_tensor = py_p.cast<std::shared_ptr<tensor::Tensor>>();
|
||||
py::tuple shape = m_tensor->GetPyTupleShape();
|
||||
buffer_ << "[" << std::string(py::str(shape)) << "]";
|
||||
} else if (py::hasattr(py_p, PYTHON_META_TENSOR_FLAG)) {
|
||||
std::shared_ptr<tensor::MetaTensor> m_tensor = py_p.cast<std::shared_ptr<tensor::MetaTensor>>();
|
||||
py::tuple shape = m_tensor->GetPyTupleShape();
|
||||
buffer_ << "[" << std::string(py::str(shape)) << "]";
|
||||
}
|
||||
}
|
||||
buffer_ << "</td></tr>";
|
||||
|
|
|
@ -102,6 +102,26 @@ int MetaTensor::DimensionSize(const size_t index) const {
|
|||
return dim_size;
|
||||
}
|
||||
|
||||
abstract::AbstractBasePtr MetaTensor::ToAbstract() {
|
||||
auto tens = shared_from_base<MetaTensor>();
|
||||
auto dtype = tens->Dtype();
|
||||
if (!IsSubType(dtype, kNumber)) {
|
||||
MS_LOG(EXCEPTION) << "Expect MetaTensor type kNumber but got: " << dtype->ToString() << ".";
|
||||
}
|
||||
auto tensor_shape = tens->shape();
|
||||
auto abs_tensor = std::make_shared<abstract::AbstractTensor>(dtype, tensor_shape);
|
||||
abs_tensor->set_value(shared_from_base<MetaTensor>());
|
||||
return abs_tensor;
|
||||
}
|
||||
|
||||
py::tuple MetaTensor::GetPyTupleShape() const {
|
||||
py::tuple dims(shape_.size());
|
||||
for (size_t i = 0; i < dims.size(); ++i) {
|
||||
dims[i] = py::int_(shape_[i]);
|
||||
}
|
||||
return dims;
|
||||
}
|
||||
|
||||
int MetaTensor::ElementsNum() const {
|
||||
return std::accumulate(shape_.begin(), shape_.end(), 1LL, std::multiplies<int>());
|
||||
}
|
||||
|
@ -197,14 +217,6 @@ int Tensor::DataDim() const { return static_cast<int>(data_.ndim()); }
|
|||
|
||||
int Tensor::DataSize() const { return static_cast<int>(data_.size()); }
|
||||
|
||||
py::tuple Tensor::GetPyTupleShape() const {
|
||||
py::tuple dims(shape_.size());
|
||||
for (size_t i = 0; i < dims.size(); ++i) {
|
||||
dims[i] = py::int_(shape_[i]);
|
||||
}
|
||||
return dims;
|
||||
}
|
||||
|
||||
py::array Tensor::data() const { return data_; }
|
||||
|
||||
int Tensor::data_type_c() const { return static_cast<int>(data_type_); }
|
||||
|
@ -547,7 +559,10 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
|
|||
return tensor;
|
||||
}));
|
||||
(void)py::class_<MetaTensor, std::shared_ptr<MetaTensor>>(*m, "MetaTensor")
|
||||
.def(py::init<TypePtr, py::tuple>(), py::arg("dtype"), py::arg("shape"));
|
||||
.def(py::init<TypePtr, py::tuple>(), py::arg("dtype"), py::arg("shape"))
|
||||
.def_readonly(PYTHON_META_TENSOR_FLAG, &MetaTensor::parse_info_)
|
||||
.def("dtype", &MetaTensor::Dtype, "Get the MetaTensor's dtype.")
|
||||
.def("shape", &MetaTensor::GetPyTupleShape, "Get the MetaTensor's shape.");
|
||||
}));
|
||||
|
||||
} // namespace tensor
|
||||
|
|
|
@ -163,6 +163,8 @@ class MetaTensor : public Value {
|
|||
//
|
||||
// All the types are defined in "ir/dtype.h".
|
||||
TypePtr Dtype() const;
|
||||
abstract::AbstractBasePtr ToAbstract() override;
|
||||
py::tuple GetPyTupleShape() const;
|
||||
TypeId data_type() const { return data_type_; }
|
||||
std::string ToString() const override;
|
||||
std::string DumpText() const override;
|
||||
|
@ -230,6 +232,7 @@ class MetaTensor : public Value {
|
|||
return false;
|
||||
}
|
||||
}
|
||||
const bool parse_info_ = true;
|
||||
|
||||
protected:
|
||||
// brief Data type of the tensor.
|
||||
|
@ -348,11 +351,6 @@ class Tensor : public MetaTensor {
|
|||
// return The total number of elements of the tensor data.
|
||||
int DataSize() const;
|
||||
|
||||
// brief Get tensor's shape
|
||||
//
|
||||
// return [py::tuple] The tensor's shape
|
||||
py::tuple GetPyTupleShape() const;
|
||||
|
||||
// brief Tensor's data value.
|
||||
//
|
||||
// return [py::array] The tensor's data in py::array.
|
||||
|
@ -423,6 +421,7 @@ class Tensor : public MetaTensor {
|
|||
};
|
||||
|
||||
using TensorPtr = std::shared_ptr<Tensor>;
|
||||
using MetaTensorPtr = std::shared_ptr<MetaTensor>;
|
||||
using TensorPtrList = std::vector<std::shared_ptr<Tensor>>;
|
||||
|
||||
} // namespace tensor
|
||||
|
|
|
@ -36,6 +36,8 @@ namespace mindspore {
|
|||
namespace parse {
|
||||
using Tensor = mindspore::tensor::Tensor;
|
||||
using TensorPtr = mindspore::tensor::TensorPtr;
|
||||
using MetaTensor = mindspore::tensor::MetaTensor;
|
||||
using MetaTensorPtr = mindspore::tensor::MetaTensorPtr;
|
||||
|
||||
namespace {
|
||||
bool ConvertTuple(const py::object &obj, ValuePtr *const data, bool use_signature) {
|
||||
|
@ -181,6 +183,18 @@ bool ConvertDataType(const py::object &obj, ValuePtr *const data) {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool ConvertMetaTensor(const py::object &obj, ValuePtr *const data) {
|
||||
MS_LOG(DEBUG) << "Converting MetaTensor object.";
|
||||
|
||||
auto m_tensor = obj.cast<MetaTensorPtr>();
|
||||
if (m_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "Resolve MetaTensor error, get ptr is null.";
|
||||
return false;
|
||||
}
|
||||
*data = m_tensor;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ConvertTensor(const py::object &obj, ValuePtr *const data) {
|
||||
MS_LOG(DEBUG) << "Converting tensor object";
|
||||
|
||||
|
@ -283,6 +297,8 @@ bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature
|
|||
ret = ConvertDataType(obj, &converted);
|
||||
} else if (py::hasattr(obj, PYTHON_TENSOR_FLAG)) {
|
||||
ret = ConvertTensor(obj, &converted);
|
||||
} else if (py::hasattr(obj, PYTHON_META_TENSOR_FLAG)) {
|
||||
ret = ConvertMetaTensor(obj, &converted);
|
||||
} else if (py::hasattr(obj, PYTHON_ENVINSTANCE_FLAG)) {
|
||||
std::shared_ptr<EnvInstance> env = obj.cast<std::shared_ptr<EnvInstance>>();
|
||||
converted = env;
|
||||
|
|
|
@ -20,6 +20,7 @@ namespace mindspore {
|
|||
const char PYTHON_PRIMITIVE_FLAG[] = "__primitive_flag__";
|
||||
const char PYTHON_METAFUNCGRAPH_FLAG[] = "__metafuncgraph_flag__";
|
||||
const char PYTHON_TENSOR_FLAG[] = "__tensor_flag__";
|
||||
const char PYTHON_META_TENSOR_FLAG[] = "__meta_tensor_flag__";
|
||||
const char PYTHON_ENVINSTANCE_FLAG[] = "__envinstance_flag__";
|
||||
const char PYTHON_DTYPE_FLAG[] = "__dtype_flag__";
|
||||
const char PYTHON_CELL_AS_LIST[] = "__cell_as_list__";
|
||||
|
|
|
@ -22,6 +22,7 @@ namespace mindspore {
|
|||
extern const char PYTHON_PRIMITIVE_FLAG[];
|
||||
extern const char PYTHON_METAFUNCGRAPH_FLAG[];
|
||||
extern const char PYTHON_TENSOR_FLAG[];
|
||||
extern const char PYTHON_META_TENSOR_FLAG[];
|
||||
extern const char PYTHON_ENVINSTANCE_FLAG[];
|
||||
extern const char PYTHON_DTYPE_FLAG[];
|
||||
extern const char PYTHON_CELL_AS_LIST[];
|
||||
|
|
|
@ -71,6 +71,11 @@ py::object ValuePtrToPyData(const ValuePtr &value) {
|
|||
py::tuple v(1);
|
||||
v[0] = value->cast<tensor::TensorPtr>();
|
||||
ret = v[0];
|
||||
} else if (value->isa<tensor::MetaTensor>()) {
|
||||
MS_LOG(DEBUG) << "MetaTensor";
|
||||
py::tuple v(1);
|
||||
v[0] = value->cast<tensor::MetaTensorPtr>();
|
||||
ret = v[0];
|
||||
} else if (value->isa<RefKey>()) {
|
||||
MS_LOG(DEBUG) << "RefKey";
|
||||
py::tuple v(1);
|
||||
|
|
|
@ -326,6 +326,12 @@ class _Executor:
|
|||
raise TypeError('Parameters need OrderedDict type, but got {}'.
|
||||
format(type(params)))
|
||||
|
||||
def _params_init_data(self, obj, params):
|
||||
if params is not None:
|
||||
for _, param in params.items():
|
||||
param.init_data()
|
||||
obj.init_parameters_data()
|
||||
|
||||
def compile(self, obj, *args, phase='predict', params=None, do_convert=True, auto_parallel_mode=False):
|
||||
"""
|
||||
Compiles graph.
|
||||
|
@ -371,6 +377,7 @@ class _Executor:
|
|||
if not do_convert:
|
||||
return phase, True
|
||||
|
||||
self._params_init_data(obj, params)
|
||||
if not enable_debug_runtime or enable_ge:
|
||||
if auto_parallel_mode:
|
||||
obj.parameter_layout_dict = self._executor.get_parameter_layout(phase)
|
||||
|
|
|
@ -39,6 +39,8 @@ class Initializer:
|
|||
"""
|
||||
def __init__(self, **kwargs):
|
||||
self._kwargs = kwargs
|
||||
self.shape = None
|
||||
self.dtype = None
|
||||
|
||||
def _initialize(self, *kwargs):
|
||||
raise NotImplementedError('Must be overridden!')
|
||||
|
@ -46,6 +48,32 @@ class Initializer:
|
|||
def __call__(self, arr):
|
||||
return self._initialize(arr)
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
return self._shape
|
||||
|
||||
@shape.setter
|
||||
def shape(self, shape):
|
||||
self._shape = shape
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self._dtype
|
||||
|
||||
@dtype.setter
|
||||
def dtype(self, dtype):
|
||||
self._dtype = dtype
|
||||
|
||||
def to_tensor(self):
|
||||
arr = None
|
||||
try:
|
||||
arr = np.ndarray(self.shape)
|
||||
except ValueError:
|
||||
msg = "Error shape={}".format(self.shape)
|
||||
logger.error(msg)
|
||||
raise ValueError(msg)
|
||||
self.__call__(arr)
|
||||
return Tensor(arr, dtype=self.dtype)
|
||||
|
||||
def _register(*aliases):
|
||||
"""Return the alias register."""
|
||||
|
@ -279,13 +307,14 @@ def initializer(init, shape=None, dtype=mstype.float32):
|
|||
dtype (:class:`mindspore.dtype`): The type of data in initialized tensor. Default: mindspore.float32.
|
||||
|
||||
Returns:
|
||||
Tensor, initialized tensor.
|
||||
Union[Tensor, Initialized], When `init` is Tensor, the return is Tensor object,
|
||||
otherwise the return is Initialize object.
|
||||
|
||||
Examples:
|
||||
>>> tensor = initializer('ones', [1, 2, 3], mindspore.float32)
|
||||
"""
|
||||
if not isinstance(init, (Tensor, numbers.Number, str, Initializer)):
|
||||
raise TypeError('Unsupported init type.')
|
||||
raise TypeError("Unsupported init type '{}'.".format(type(init)))
|
||||
|
||||
if isinstance(init, Tensor):
|
||||
init_shape = init.shape()
|
||||
|
@ -295,23 +324,32 @@ def initializer(init, shape=None, dtype=mstype.float32):
|
|||
"the variable shape {}.".format(list(init.shape()), shape))
|
||||
return init
|
||||
|
||||
if isinstance(init, str):
|
||||
init_obj = _INITIALIZER_ALIAS[init.lower()]()
|
||||
if init_obj is None:
|
||||
raise ValueError("The class corresponding to '{}' was not found.".format(init))
|
||||
init = init_obj
|
||||
|
||||
if isinstance(shape, list):
|
||||
shape = tuple(shape)
|
||||
elif isinstance(shape, numbers.Number):
|
||||
shape = (shape,)
|
||||
try:
|
||||
arr = np.ndarray(shape)
|
||||
np.ndarray(shape)
|
||||
except ValueError:
|
||||
msg = "Error shape={}".format(shape)
|
||||
logger.error(msg)
|
||||
raise ValueError(msg)
|
||||
raise ValueError("Error shape={}".format(shape))
|
||||
|
||||
if isinstance(init, Initializer):
|
||||
init.shape = shape
|
||||
init.dtype = dtype
|
||||
return init
|
||||
|
||||
if isinstance(init, numbers.Number):
|
||||
init_obj = Constant(init)
|
||||
elif isinstance(init, str):
|
||||
init_obj = _INITIALIZER_ALIAS[init.lower()]()
|
||||
else:
|
||||
init_obj = init
|
||||
|
||||
init_obj(arr)
|
||||
return Tensor(arr, dtype=dtype)
|
||||
|
||||
init_obj.shape = shape
|
||||
init_obj.dtype = dtype
|
||||
return init_obj
|
||||
raise TypeError("Unsupported init type '{}'.".format(type(init)))
|
||||
|
||||
__all__ = [
|
||||
'Initializer',
|
||||
|
|
|
@ -14,9 +14,10 @@
|
|||
# ============================================================================
|
||||
|
||||
"""Parameter for cell."""
|
||||
import numbers
|
||||
from copy import copy, deepcopy
|
||||
from .initializer import initializer
|
||||
from .tensor import Tensor
|
||||
from .initializer import initializer, Initializer
|
||||
from .tensor import Tensor, MetaTensor
|
||||
from .._checkparam import _check_str_by_regular
|
||||
from ..parallel._utils import _set_clone_info, _CloneInfo
|
||||
|
||||
|
@ -41,7 +42,8 @@ class Parameter:
|
|||
Each parameter of Cell is represented by Parameter class.
|
||||
|
||||
Args:
|
||||
default_input (Tensor): A parameter tensor.
|
||||
default_input (Union[Tensor, Initializer]): Parameter data, when `default_input` is` Initializer`,
|
||||
the data stored by Parameter is `MetaTensor`, otherwise it is `Tensor`.
|
||||
name (str): Name of the child parameter.
|
||||
requires_grad (bool): True if the parameter requires gradient. Default: True.
|
||||
layerwise_parallel (bool): A kind of model parallel mode. When layerwise_parallel is true in paralle mode,
|
||||
|
@ -123,7 +125,11 @@ class Parameter:
|
|||
if init != 'same':
|
||||
shape = self.default_input.shape()
|
||||
dtype = self.default_input.dtype()
|
||||
x.default_input = initializer(init, shape=shape, dtype=dtype)
|
||||
if isinstance(init, (str, Initializer, numbers.Number)):
|
||||
x.init_mode = initializer(init, shape=shape, dtype=dtype)
|
||||
x.default_input = MetaTensor(dtype, shape)
|
||||
else:
|
||||
x.default_input = initializer(init, shape=shape, dtype=dtype)
|
||||
|
||||
x.clone_info = copy(self.clone_info)
|
||||
_set_clone_info(self.clone_info, x.clone_info)
|
||||
|
@ -181,11 +187,21 @@ class Parameter:
|
|||
if isinstance(data, Tensor):
|
||||
# make a copy of Tensor to init the parameter
|
||||
data = Tensor(data.asnumpy().copy())
|
||||
elif isinstance(data, Initializer):
|
||||
self.init_mode = data
|
||||
data = MetaTensor(self.init_mode.dtype, self.init_mode.shape)
|
||||
else:
|
||||
data = Tensor(data)
|
||||
self.default_input = data
|
||||
|
||||
|
||||
def init_data(self):
|
||||
if not isinstance(self.default_input, MetaTensor):
|
||||
return
|
||||
self.default_input = self.init_mode.to_tensor()
|
||||
self.init_mode = None
|
||||
|
||||
|
||||
class ParameterTuple(tuple):
|
||||
"""
|
||||
Class for storing tuple of parameters.
|
||||
|
|
|
@ -92,7 +92,7 @@ class GetMaskedLMOutput(nn.Cell):
|
|||
config.hidden_size,
|
||||
weight_init=weight_init,
|
||||
activation=config.hidden_act).to_float(config.compute_type)
|
||||
self.layernorm = nn.LayerNorm(config.hidden_size).to_float(config.compute_type)
|
||||
self.layernorm = nn.LayerNorm((config.hidden_size,)).to_float(config.compute_type)
|
||||
self.output_bias = Parameter(
|
||||
initializer(
|
||||
'zero',
|
||||
|
|
|
@ -190,7 +190,7 @@ class EmbeddingPostprocessor(nn.Cell):
|
|||
self.array_mul = P.MatMul()
|
||||
self.reshape = P.Reshape()
|
||||
self.shape = tuple(embedding_shape)
|
||||
self.layernorm = nn.LayerNorm(embedding_size)
|
||||
self.layernorm = nn.LayerNorm((embedding_size,))
|
||||
self.dropout = nn.Dropout(1 - dropout_prob)
|
||||
self.gather = P.GatherV2()
|
||||
self.use_relative_positions = use_relative_positions
|
||||
|
@ -246,7 +246,7 @@ class BertOutput(nn.Cell):
|
|||
weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
|
||||
self.dropout = nn.Dropout(1 - dropout_prob)
|
||||
self.add = P.TensorAdd()
|
||||
self.layernorm = nn.LayerNorm(out_channels).to_float(compute_type)
|
||||
self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type)
|
||||
self.cast = P.Cast()
|
||||
|
||||
def construct(self, hidden_status, input_tensor):
|
||||
|
@ -802,13 +802,13 @@ class CreateAttentionMaskFromInputMask(nn.Cell):
|
|||
|
||||
if not self.input_mask_from_dataset:
|
||||
self.input_mask = initializer(
|
||||
"ones", [config.batch_size, config.seq_length], mstype.int32)
|
||||
"ones", [config.batch_size, config.seq_length], mstype.int32).to_tensor()
|
||||
|
||||
self.cast = P.Cast()
|
||||
self.reshape = P.Reshape()
|
||||
self.shape = (config.batch_size, 1, config.seq_length)
|
||||
self.broadcast_ones = initializer(
|
||||
"ones", [config.batch_size, config.seq_length, 1], mstype.float32)
|
||||
"ones", [config.batch_size, config.seq_length, 1], mstype.float32).to_tensor()
|
||||
self.batch_matmul = P.BatchMatMul()
|
||||
|
||||
def construct(self, input_mask):
|
||||
|
@ -854,7 +854,7 @@ class BertModel(nn.Cell):
|
|||
|
||||
if not self.token_type_ids_from_dataset:
|
||||
self.token_type_ids = initializer(
|
||||
"zeros", [self.batch_size, self.seq_length], mstype.int32)
|
||||
"zeros", [self.batch_size, self.seq_length], mstype.int32).to_tensor()
|
||||
|
||||
self.bert_embedding_lookup = EmbeddingLookup(
|
||||
vocab_size=config.vocab_size,
|
||||
|
|
|
@ -29,7 +29,7 @@ from .mobilenet import InvertedResidual, ConvBNReLU
|
|||
|
||||
def _conv2d(in_channel, out_channel, kernel_size=3, stride=1, pad_mod='same'):
|
||||
weight_shape = (out_channel, in_channel, kernel_size, kernel_size)
|
||||
weight = initializer('XavierUniform', shape=weight_shape, dtype=mstype.float32)
|
||||
weight = initializer('XavierUniform', shape=weight_shape, dtype=mstype.float32).to_tensor()
|
||||
return nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size, stride=stride,
|
||||
padding=0, pad_mode=pad_mod, weight_init=weight)
|
||||
|
||||
|
|
|
@ -26,7 +26,7 @@ def _make_layer(base, batch_norm):
|
|||
layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
|
||||
else:
|
||||
weight_shape = (v, in_channels, 3, 3)
|
||||
weight = initializer('XavierUniform', shape=weight_shape, dtype=mstype.float32)
|
||||
weight = initializer('XavierUniform', shape=weight_shape, dtype=mstype.float32).to_tensor()
|
||||
conv2d = nn.Conv2d(in_channels=in_channels,
|
||||
out_channels=v,
|
||||
kernel_size=3,
|
||||
|
|
|
@ -163,6 +163,7 @@ class Cell:
|
|||
if context.get_context("mode") == context.GRAPH_MODE:
|
||||
out = self.compile_and_run(*inputs)
|
||||
return out
|
||||
self.init_parameters_data()
|
||||
output = self.construct(*inputs)
|
||||
if isinstance(output, Parameter):
|
||||
output = output.data
|
||||
|
@ -395,6 +396,10 @@ class Cell:
|
|||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def init_parameters_data(self, recurse=True):
|
||||
for param in self.get_parameters(expand=recurse):
|
||||
param.init_data()
|
||||
|
||||
def parameters_dict(self, recurse=True):
|
||||
"""
|
||||
Gets parameters dictionary.
|
||||
|
|
|
@ -471,6 +471,9 @@ class LayerNorm(Cell):
|
|||
beta_init='zeros',
|
||||
):
|
||||
super(LayerNorm, self).__init__()
|
||||
if not isinstance(normalized_shape, (tuple, list)):
|
||||
raise TypeError("The type of 'normalized_shape' should be tuple[int] or list[int], but '{}' type is {}."
|
||||
.format(normalized_shape, type(normalized_shape)))
|
||||
self.normalized_shape = normalized_shape
|
||||
self.begin_norm_axis = begin_norm_axis
|
||||
self.begin_params_axis = begin_params_axis
|
||||
|
|
|
@ -116,6 +116,8 @@ def save_checkpoint(parameter_list, ckpoint_file_name):
|
|||
param_value = checkpoint_list.value.add()
|
||||
param_value.tag = param["name"]
|
||||
param_tensor = param_value.tensor
|
||||
if isinstance(param["data"], Parameter):
|
||||
param["data"].init_data()
|
||||
param_data = param["data"].asnumpy().reshape(-1)
|
||||
param_tensor.tensor_content = param_data.tostring()
|
||||
param_tensor.tensor_type = str(param["data"].dtype())
|
||||
|
@ -238,6 +240,7 @@ def load_param_into_net(net, parameter_dict):
|
|||
logger.error("Failed to combine the net and the parameters.")
|
||||
msg = ("Argument parameter_dict element should be a Parameter, but got {}.".format(type(new_param)))
|
||||
raise TypeError(msg)
|
||||
param.init_data()
|
||||
_update_param(param, new_param)
|
||||
else:
|
||||
param_not_load.append(param.name)
|
||||
|
@ -311,6 +314,7 @@ def _exec_save_checkpoint(train_network, ckpoint_file_name, integrated_save=True
|
|||
param_list = []
|
||||
for (key, value) in param_dict.items():
|
||||
each_param = {"name": key}
|
||||
value.init_data()
|
||||
if isinstance(value.data, Tensor):
|
||||
param_data = value.data
|
||||
else:
|
||||
|
@ -371,6 +375,8 @@ def _fill_param_into_net(net, parameter_list):
|
|||
parameter_dict = {}
|
||||
for each_param in parameter_list:
|
||||
param_name = each_param["name"]
|
||||
if isinstance(each_param["data"], Parameter):
|
||||
each_param["data"].init_data()
|
||||
np_val = each_param["data"].asnumpy()
|
||||
if np_val.shape == (1,):
|
||||
parameter_dict[param_name] = Parameter(np_val, name=param_name)
|
||||
|
|
|
@ -35,6 +35,7 @@ def get_uniform_with_shape(shape):
|
|||
def set_block_param_with_rand(net, rand_func=None):
|
||||
if not isinstance(net, nn.Cell) or rand_func is None:
|
||||
return
|
||||
net.init_parameters_data()
|
||||
for param in net.trainable_params():
|
||||
param.default_input = Tensor(rand_func(param.default_input.asnumpy().shape))
|
||||
|
||||
|
|
|
@ -143,6 +143,7 @@ def test_bert_tdt():
|
|||
callback = ModelCallback()
|
||||
params = netwithloss.trainable_params()
|
||||
for param in params:
|
||||
param.init_data()
|
||||
value = param.default_input
|
||||
name = param.name
|
||||
if isinstance(value, Tensor):
|
||||
|
|
|
@ -223,6 +223,7 @@ def test_div():
|
|||
@non_graph_engine
|
||||
def test_parameter():
|
||||
x = Parameter(initializer(1, [1], ms.float32), name="beta1_power")
|
||||
x.init_data()
|
||||
z = x / 2
|
||||
print(z)
|
||||
|
||||
|
|
|
@ -34,7 +34,7 @@ def test_dense_str_activation():
|
|||
assert isinstance(dense.activation, nn.ReLU)
|
||||
|
||||
input_data = Tensor(np.random.randint(0, 255, [1, 1]).astype(np.float32))
|
||||
dense.construct(input_data)
|
||||
dense(input_data)
|
||||
|
||||
|
||||
def test_dense_weight_error():
|
||||
|
|
|
@ -40,8 +40,10 @@ class ParameterNet(nn.Cell):
|
|||
def test_using_same_seed_for_initializer():
|
||||
np.random.seed(0)
|
||||
net1 = ParameterNet()
|
||||
net1.init_parameters_data()
|
||||
np.random.seed(0)
|
||||
net2 = ParameterNet()
|
||||
net2.init_parameters_data()
|
||||
for key in net1.parameters_dict():
|
||||
if key not in net2.parameters_dict():
|
||||
assert False
|
||||
|
@ -52,8 +54,10 @@ def test_using_same_seed_for_initializer():
|
|||
def test_using_diffserent_seed_for_initializer():
|
||||
np.random.seed(0)
|
||||
net1 = ParameterNet()
|
||||
net1.init_parameters_data()
|
||||
np.random.seed(1)
|
||||
net2 = ParameterNet()
|
||||
net2.init_parameters_data()
|
||||
for key in net1.parameters_dict():
|
||||
if key not in net2.parameters_dict():
|
||||
assert False
|
||||
|
|
|
@ -59,7 +59,7 @@ def test_bn2d():
|
|||
|
||||
#3-channel RGB
|
||||
input_data = Tensor(np.random.randint(0, 1, [1, 3, 224, 224]).astype(np.float32))
|
||||
output = bn.construct(input_data)
|
||||
output = bn(input_data)
|
||||
output_np = output.asnumpy()
|
||||
assert isinstance(output_np[0][0][0][0], (np.float32, np.float64))
|
||||
|
||||
|
@ -68,7 +68,7 @@ def test_bn1d():
|
|||
"""ut of nn.BatchNorm1d"""
|
||||
bn = nn.BatchNorm1d(3)
|
||||
input_data = Tensor(np.random.randint(0, 1, [1, 3, 100, 100]).astype(np.float32))
|
||||
output = bn.construct(input_data)
|
||||
output = bn(input_data)
|
||||
output_np = output.asnumpy()
|
||||
assert isinstance(output_np[0][0][0][0], (np.float32, np.float64))
|
||||
|
||||
|
|
|
@ -27,7 +27,7 @@ kernel_size = 3
|
|||
|
||||
def test_check_conv2d_1():
|
||||
m = nn.Conv2d(3, 64, 3, bias_init='zeros')
|
||||
output = m.construct(Tensor(np.ones([1, 3, 16, 50], dtype=np.float32)))
|
||||
output = m(Tensor(np.ones([1, 3, 16, 50], dtype=np.float32)))
|
||||
output_np = output.asnumpy()
|
||||
assert isinstance(output_np[0][0][0][0], (np.float32, np.float64))
|
||||
|
||||
|
@ -35,7 +35,7 @@ def test_check_conv2d_1():
|
|||
def test_check_conv2d_2():
|
||||
Tensor(np.ones([2, 2]))
|
||||
m = nn.Conv2d(3, 64, 4, has_bias=False, weight_init='normal')
|
||||
output = m.construct(Tensor(np.ones([1, 3, 16, 50], dtype=np.float32)))
|
||||
output = m(Tensor(np.ones([1, 3, 16, 50], dtype=np.float32)))
|
||||
output_np = output.asnumpy()
|
||||
assert isinstance(output_np[0][0][0][0], (np.float32, np.float64))
|
||||
|
||||
|
@ -43,7 +43,7 @@ def test_check_conv2d_2():
|
|||
def test_check_conv2d_3():
|
||||
Tensor(np.ones([2, 2]))
|
||||
m = nn.Conv2d(3, 64, (3, 3))
|
||||
output = m.construct(Tensor(np.ones([1, 3, 16, 50], dtype=np.float32)))
|
||||
output = m(Tensor(np.ones([1, 3, 16, 50], dtype=np.float32)))
|
||||
output_np = output.asnumpy()
|
||||
assert isinstance(output_np[0][0][0][0], (np.float32, np.float64))
|
||||
|
||||
|
@ -51,13 +51,13 @@ def test_check_conv2d_3():
|
|||
def test_check_conv2d_4():
|
||||
Tensor(np.ones([2, 2]))
|
||||
m = nn.Conv2d(3, 64, (3, 3), stride=2, pad_mode='pad', padding=4)
|
||||
output = m.construct(Tensor(np.ones([1, 3, 16, 50], dtype=np.float32)))
|
||||
output = m(Tensor(np.ones([1, 3, 16, 50], dtype=np.float32)))
|
||||
output_np = output.asnumpy()
|
||||
assert isinstance(output_np[0][0][0][0], (np.float32, np.float64))
|
||||
|
||||
|
||||
def test_check_conv2d_bias():
|
||||
m = nn.Conv2d(3, 64, 3, bias_init='zeros')
|
||||
output = m.construct(Tensor(np.ones([1, 3, 16, 50], dtype=np.float32)))
|
||||
output = m(Tensor(np.ones([1, 3, 16, 50], dtype=np.float32)))
|
||||
output_np = output.asnumpy()
|
||||
assert isinstance(output_np[0][0][0][0], (np.float32, np.float64))
|
||||
|
|
|
@ -27,7 +27,7 @@ def test_dense_defaultbias_noactivation():
|
|||
assert dense.activation is None
|
||||
|
||||
input_data = Tensor(np.random.randint(0, 255, [1, 3]).astype(np.float32))
|
||||
output = dense.construct(input_data)
|
||||
output = dense(input_data)
|
||||
output_np = output.asnumpy()
|
||||
assert isinstance(output_np[0][0], (np.float32, np.float64))
|
||||
|
||||
|
@ -37,7 +37,7 @@ def test_dense_defaultweight():
|
|||
dense = nn.Dense(3, 2, bias_init=bias)
|
||||
#batch_size 1 && 3-channel RGB
|
||||
input_data = Tensor(np.random.randint(0, 255, [1, 3]).astype(np.float32))
|
||||
output = dense.construct(input_data)
|
||||
output = dense(input_data)
|
||||
output_np = output.asnumpy()
|
||||
assert isinstance(output_np[0][0], (np.float32, np.float64))
|
||||
|
||||
|
@ -48,7 +48,7 @@ def test_dense_bias():
|
|||
dense = nn.Dense(3, 2, weight, bias)
|
||||
|
||||
input_data = Tensor(np.random.randint(0, 255, [2, 3]).astype(np.float32))
|
||||
output = dense.construct(input_data)
|
||||
output = dense(input_data)
|
||||
output_np = output.asnumpy()
|
||||
assert isinstance(output_np[0][0], (np.float32, np.float64))
|
||||
|
||||
|
@ -58,7 +58,7 @@ def test_dense_nobias():
|
|||
dense = nn.Dense(3, 2, weight, has_bias=False)
|
||||
|
||||
input_data = Tensor(np.random.randint(0, 255, [2, 3]).astype(np.float32))
|
||||
output = dense.construct(input_data)
|
||||
output = dense(input_data)
|
||||
output_np = output.asnumpy()
|
||||
assert isinstance(output_np[0][0], (np.float32, np.float64))
|
||||
|
||||
|
@ -73,7 +73,7 @@ def test_dense_str_activation():
|
|||
assert isinstance(dense.activation, nn.ReLU)
|
||||
|
||||
input_data = Tensor(np.random.randint(0, 255, [1, 1]).astype(np.float32))
|
||||
output = dense.construct(input_data)
|
||||
output = dense(input_data)
|
||||
output_np = output.asnumpy()
|
||||
assert isinstance(output_np[0][0], np.float32)
|
||||
|
||||
|
|
|
@ -264,6 +264,7 @@ def test_grad_inline_bprop_multi_input():
|
|||
net = InlineMutilTwoInputParameterCell()
|
||||
input1 = Tensor(np.ones([2, 2]).astype(np.float32))
|
||||
input2 = Tensor(np.ones([2, 2]).astype(np.float32))
|
||||
net.init_parameters_data()
|
||||
grads = C.grad_all(net)(input1, input2)
|
||||
assert (grads[0].asnumpy() == np.array([[12, 12], [12, 12]]).astype(np.float32)).all()
|
||||
assert (grads[1].asnumpy() == np.array([[19, 19], [19, 19]]).astype(np.float32)).all()
|
||||
|
|
|
@ -133,6 +133,6 @@ def test_lenet_grad():
|
|||
print("fail to run optimizer")
|
||||
# verification
|
||||
if i == verification_step:
|
||||
fw_output = net.construct(input_data)
|
||||
loss_output = loss.construct(fw_output, label)
|
||||
fw_output = net(input_data)
|
||||
loss_output = loss(fw_output, label)
|
||||
print("The loss of %s-th iteration is %s" % (i, loss_output.asnumpy()))
|
||||
|
|
|
@ -151,7 +151,7 @@ def test_softmaxloss_grad():
|
|||
predict = Tensor(np.ones([1, 64]))
|
||||
label = Tensor(np.zeros([1, 10]).astype(np.float32))
|
||||
print("pynative run")
|
||||
out = net.construct(predict, label)
|
||||
out = net(predict, label)
|
||||
print("out:", out)
|
||||
|
||||
def test_stop_gradient_1():
|
||||
|
|
|
@ -22,6 +22,10 @@ from scipy import stats
|
|||
|
||||
import mindspore as ms
|
||||
import mindspore.common.initializer as init
|
||||
import mindspore.nn as nn
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore import context
|
||||
from mindspore.nn import Conv2d
|
||||
from ..ut_filter import non_graph_engine
|
||||
|
@ -55,8 +59,8 @@ def _check_uniform(tensor, boundary_a, boundary_b):
|
|||
|
||||
def test_init_Initializer():
|
||||
tensor = init.initializer(InitTwo(), [2, 2], ms.int32)
|
||||
assert tensor.shape() == (2, 2)
|
||||
_check_value(tensor, 2, 2)
|
||||
assert tensor.shape == (2, 2)
|
||||
_check_value(tensor.to_tensor(), 2, 2)
|
||||
|
||||
|
||||
def test_init_tensor():
|
||||
|
@ -67,71 +71,71 @@ def test_init_tensor():
|
|||
|
||||
def test_init_zero_default_dtype():
|
||||
tensor = init.initializer(init.Zero(), [2, 2])
|
||||
assert tensor.dtype() == ms.float32
|
||||
_check_value(tensor, 0, 0)
|
||||
assert tensor.dtype == ms.float32
|
||||
_check_value(tensor.to_tensor(), 0, 0)
|
||||
|
||||
|
||||
def test_init_zero():
|
||||
tensor = init.initializer(init.Zero(), [2, 2], ms.float32)
|
||||
_check_value(tensor, 0, 0)
|
||||
_check_value(tensor.to_tensor(), 0, 0)
|
||||
|
||||
|
||||
def test_init_zero_alias_default_dtype():
|
||||
tensor = init.initializer('zeros', [1, 2])
|
||||
assert tensor.dtype() == ms.float32
|
||||
_check_value(tensor, 0, 0)
|
||||
assert tensor.dtype == ms.float32
|
||||
_check_value(tensor.to_tensor(), 0, 0)
|
||||
|
||||
|
||||
def test_init_zero_alias():
|
||||
tensor = init.initializer('zeros', [1, 2], ms.float32)
|
||||
_check_value(tensor, 0, 0)
|
||||
_check_value(tensor.to_tensor(), 0, 0)
|
||||
|
||||
|
||||
def test_init_one():
|
||||
tensor = init.initializer(init.One(), [2, 2], ms.float32)
|
||||
_check_value(tensor, 1, 1)
|
||||
_check_value(tensor.to_tensor(), 1, 1)
|
||||
|
||||
|
||||
def test_init_one_alias():
|
||||
tensor = init.initializer('ones', [1, 2], ms.float32)
|
||||
_check_value(tensor, 1, 1)
|
||||
_check_value(tensor.to_tensor(), 1, 1)
|
||||
|
||||
|
||||
def test_init_constant():
|
||||
tensor = init.initializer(init.Constant(1), [2, 2], ms.float32)
|
||||
_check_value(tensor, 1, 1)
|
||||
_check_value(tensor.to_tensor(), 1, 1)
|
||||
|
||||
|
||||
def test_init_uniform():
|
||||
scale = 10
|
||||
tensor = init.initializer(init.Uniform(scale=scale), [5, 4], ms.float32)
|
||||
_check_value(tensor, -scale, scale)
|
||||
_check_value(tensor.to_tensor(), -scale, scale)
|
||||
|
||||
|
||||
def test_init_uniform_alias():
|
||||
scale = 100
|
||||
tensor = init.initializer('uniform', [5, 4], ms.float32)
|
||||
_check_value(tensor, -scale, scale)
|
||||
_check_value(tensor.to_tensor(), -scale, scale)
|
||||
|
||||
|
||||
def test_init_normal():
|
||||
tensor = init.initializer(init.Normal(), [5, 4], ms.float32)
|
||||
assert isinstance(tensor, ms.Tensor), 'tensor init failed!'
|
||||
assert isinstance(tensor, init.Normal), 'Normal init failed!'
|
||||
|
||||
|
||||
def test_init_truncated_normal():
|
||||
tensor = init.initializer(init.TruncatedNormal(), [5, 4], ms.float32)
|
||||
assert isinstance(tensor, ms.Tensor), 'tensor init failed!'
|
||||
assert isinstance(tensor, init.TruncatedNormal), 'TruncatedNormal init failed!'
|
||||
|
||||
|
||||
def test_init_normal_alias():
|
||||
tensor = init.initializer('normal', [5, 4], ms.float32)
|
||||
assert isinstance(tensor, ms.Tensor), 'tensor init failed!'
|
||||
assert isinstance(tensor, init.Normal), 'Normal init failed!'
|
||||
|
||||
|
||||
def test_init_truncatednormal_alias():
|
||||
tensor = init.initializer('truncatednormal', [5, 4], ms.float32)
|
||||
assert isinstance(tensor, ms.Tensor), 'tensor init failed!'
|
||||
assert isinstance(tensor, init.TruncatedNormal), 'TruncatedNormal init failed!'
|
||||
|
||||
|
||||
def test_init_abnormal():
|
||||
|
@ -142,12 +146,12 @@ def test_init_abnormal():
|
|||
def test_init_xavier_uniform():
|
||||
""" test_init_xavier_uniform """
|
||||
gain = 1.2
|
||||
tensor1 = init.initializer(init.XavierUniform(gain=gain), [20, 22], ms.float32)
|
||||
tensor2 = init.initializer(init.XavierUniform(), [20, 22], ms.float32)
|
||||
tensor3 = init.initializer(init.XavierUniform(gain=gain), [20, 22, 5, 5], ms.float32)
|
||||
tensor4 = init.initializer(init.XavierUniform(), [20, 22, 5, 5], ms.float32)
|
||||
tensor5 = init.initializer('xavier_uniform', [20, 22, 5, 5], ms.float32)
|
||||
tensor6 = init.initializer('xavier_uniform', [20, 22], ms.float32)
|
||||
tensor1 = init.initializer(init.XavierUniform(gain=gain), [20, 22], ms.float32).to_tensor()
|
||||
tensor2 = init.initializer(init.XavierUniform(), [20, 22], ms.float32).to_tensor()
|
||||
tensor3 = init.initializer(init.XavierUniform(gain=gain), [20, 22, 5, 5], ms.float32).to_tensor()
|
||||
tensor4 = init.initializer(init.XavierUniform(), [20, 22, 5, 5], ms.float32).to_tensor()
|
||||
tensor5 = init.initializer('xavier_uniform', [20, 22, 5, 5], ms.float32).to_tensor()
|
||||
tensor6 = init.initializer('xavier_uniform', [20, 22], ms.float32).to_tensor()
|
||||
tensor_dict = {tensor1: gain, tensor2: None, tensor3: gain, tensor4: None, tensor5: None, tensor6: None}
|
||||
|
||||
for tensor, gain_value in tensor_dict.items():
|
||||
|
@ -167,7 +171,7 @@ def test_init_xavier_uniform():
|
|||
|
||||
def test_init_xavier_uniform_error():
|
||||
with py.raises(ValueError):
|
||||
init.initializer(init.XavierUniform(), [6], ms.float32)
|
||||
init.initializer(init.XavierUniform(), [6], ms.float32).to_tensor()
|
||||
|
||||
|
||||
def test_init_he_uniform():
|
||||
|
@ -176,7 +180,7 @@ def test_init_he_uniform():
|
|||
tensor2 = init.initializer(init.HeUniform(), [20, 22, 5, 5], ms.float32)
|
||||
tensor3 = init.initializer('he_uniform', [20, 22, 5, 5], ms.float32)
|
||||
tensor4 = init.initializer('he_uniform', [20, 22], ms.float32)
|
||||
tensors = [tensor1, tensor2, tensor3, tensor4]
|
||||
tensors = [tensor1.to_tensor(), tensor2.to_tensor(), tensor3.to_tensor(), tensor4.to_tensor()]
|
||||
|
||||
for tensor in tensors:
|
||||
shape = tensor.asnumpy().shape
|
||||
|
@ -192,7 +196,7 @@ def test_init_he_uniform():
|
|||
|
||||
def test_init_he_uniform_error():
|
||||
with py.raises(ValueError):
|
||||
init.initializer(init.HeUniform(), [6], ms.float32)
|
||||
init.initializer(init.HeUniform(), [6], ms.float32).to_tensor()
|
||||
|
||||
|
||||
def test_conv2d_abnormal_kernel_negative():
|
||||
|
@ -216,9 +220,30 @@ def test_conv2d_abnormal_kernel_normal():
|
|||
|
||||
@non_graph_engine
|
||||
def test_conv2d_abnormal_kernel_truncated_normal():
|
||||
input_data = init.initializer(init.TruncatedNormal(), [64, 3, 7, 7], ms.float32)
|
||||
input_data = init.initializer(init.TruncatedNormal(), [64, 3, 7, 7], ms.float32).to_tensor()
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
model = ms.Model(
|
||||
Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=3,
|
||||
padding=0, weight_init="truncatednormal"))
|
||||
model.predict(input_data)
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.add = P.TensorAdd()
|
||||
self.t1 = Parameter(init.initializer('uniform', [5, 4], ms.float32), name="w1")
|
||||
self.t2 = Parameter(init.initializer(init.TruncatedNormal(), [5, 4], ms.float32), name="w2")
|
||||
|
||||
def construct(self, x):
|
||||
z = self.add(x, self.t1)
|
||||
z = self.add(z, self.t2)
|
||||
return z
|
||||
|
||||
def test_weight_shape():
|
||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
||||
a = np.arange(20).reshape(5, 4)
|
||||
t = Tensor(a, dtype=ms.float32)
|
||||
net = Net()
|
||||
out = net(t)
|
||||
print(out)
|
||||
|
|
|
@ -198,6 +198,7 @@ def test_load_param_into_net_error_dict():
|
|||
|
||||
def test_load_param_into_net_erro_dict_param():
|
||||
net = Net(10)
|
||||
net.init_parameters_data()
|
||||
assert net.conv1.weight.default_input.asnumpy()[0][0][0][0] == 0
|
||||
|
||||
parameter_dict = {}
|
||||
|
@ -210,6 +211,7 @@ def test_load_param_into_net_erro_dict_param():
|
|||
def test_load_param_into_net_has_more_param():
|
||||
""" test_load_param_into_net_has_more_param """
|
||||
net = Net(10)
|
||||
net.init_parameters_data()
|
||||
assert net.conv1.weight.default_input.asnumpy()[0][0][0][0] == 0
|
||||
|
||||
parameter_dict = {}
|
||||
|
@ -225,6 +227,7 @@ def test_load_param_into_net_has_more_param():
|
|||
|
||||
def test_load_param_into_net_param_type_and_shape_error():
|
||||
net = Net(10)
|
||||
net.init_parameters_data()
|
||||
assert net.conv1.weight.default_input.asnumpy()[0][0][0][0] == 0
|
||||
|
||||
parameter_dict = {}
|
||||
|
@ -236,6 +239,7 @@ def test_load_param_into_net_param_type_and_shape_error():
|
|||
|
||||
def test_load_param_into_net_param_type_error():
|
||||
net = Net(10)
|
||||
net.init_parameters_data()
|
||||
assert net.conv1.weight.default_input.asnumpy()[0][0][0][0] == 0
|
||||
|
||||
parameter_dict = {}
|
||||
|
@ -248,6 +252,7 @@ def test_load_param_into_net_param_type_error():
|
|||
|
||||
def test_load_param_into_net_param_shape_error():
|
||||
net = Net(10)
|
||||
net.init_parameters_data()
|
||||
assert net.conv1.weight.default_input.asnumpy()[0][0][0][0] == 0
|
||||
|
||||
parameter_dict = {}
|
||||
|
@ -260,6 +265,7 @@ def test_load_param_into_net_param_shape_error():
|
|||
|
||||
def test_load_param_into_net():
|
||||
net = Net(10)
|
||||
net.init_parameters_data()
|
||||
assert net.conv1.weight.default_input.asnumpy()[0][0][0][0] == 0
|
||||
|
||||
parameter_dict = {}
|
||||
|
|
Loading…
Reference in New Issue