From 362b193bdca82847fcfc057fc0626feb721fc91b Mon Sep 17 00:00:00 2001 From: lilei Date: Sat, 19 Sep 2020 10:10:27 +0800 Subject: [PATCH] modify bud of memory error is thrown --- mindspore/common/initializer.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/mindspore/common/initializer.py b/mindspore/common/initializer.py index 1da57645105..24faa74ac06 100644 --- a/mindspore/common/initializer.py +++ b/mindspore/common/initializer.py @@ -15,6 +15,7 @@ """Initializer for cell parameters.""" import numbers import math +import copy from functools import reduce import numpy as np @@ -82,7 +83,7 @@ class Initializer: shape = self.shape try: - arr = np.ndarray(shape) + arr = np.ndarray(shape, dtype=mstype.dtype_to_nptype(self.dtype)) except ValueError: msg = "Error shape={}".format(shape) logger.error(msg) @@ -478,9 +479,10 @@ def initializer(init, shape=None, dtype=mstype.float32): raise ValueError(f"shape is invalid, shape value must be positive integer, shape:{shape}") if isinstance(init, Initializer): - init.shape = init.shape if init.shape is not None else shape - init.dtype = init.dtype if init.dtype is not None else dtype - return init + init_copy = copy.deepcopy(init) + init_copy.shape = shape if shape is not None else init.shape + init_copy.dtype = init.dtype if init.dtype is not None else dtype + return init_copy if isinstance(init, str): init_obj = _INITIALIZER_ALIAS[init.lower()]()