fix pynative paramter is metatensor bug

This commit is contained in:
kpy 2020-06-12 10:16:54 +08:00
parent cedfc7fac0
commit 0b1559b8d0
2 changed files with 5 additions and 1 deletions

View File

@ -16,6 +16,7 @@
"""Parameter for cell."""
import numbers
from copy import copy, deepcopy
from mindspore import context
from . import dtype as mstype
from .initializer import initializer, Initializer
from .tensor import Tensor, MetaTensor
@ -61,6 +62,8 @@ class Parameter:
self._is_init = False
self._sliced = False
self.clone_info = _CloneInfo()
if context.get_context("mode") == context.PYNATIVE_MODE:
self.init_data()
def __repr__(self):
format_str = 'Parameter (name={name})'
@ -142,6 +145,8 @@ class Parameter:
if isinstance(init, (str, Initializer, numbers.Number)):
x.init_mode = initializer(init, shape=shape, dtype=dtype)
x.default_input = MetaTensor(dtype, shape)
if context.get_context("mode") == context.PYNATIVE_MODE:
x.init_data()
else:
x.default_input = initializer(init, shape=shape, dtype=dtype)

View File

@ -201,7 +201,6 @@ class Cell:
if context.get_context("mode") == context.GRAPH_MODE:
out = self.compile_and_run(*inputs)
return out
self.init_parameters_data()
orign_grad = []
if self.requires_grad is True:
_pynative_exec.set_grad_flag(True)