!2040 fix paramter is metatensor bug in pynative mode

Merge pull request !2040 from flywind/fix_pynative_bug
This commit is contained in:
mindspore-ci-bot 2020-06-12 15:28:45 +08:00 committed by Gitee
commit 89fce0e41f
2 changed files with 5 additions and 1 deletions

View File

@ -16,6 +16,7 @@
"""Parameter for cell."""
import numbers
from copy import copy, deepcopy
from mindspore import context
from . import dtype as mstype
from .initializer import initializer, Initializer
from .tensor import Tensor, MetaTensor
@ -61,6 +62,8 @@ class Parameter:
self._is_init = False
self._sliced = False
self.clone_info = _CloneInfo()
if context.get_context("mode") == context.PYNATIVE_MODE:
self.init_data()
def __repr__(self):
format_str = 'Parameter (name={name})'
@ -142,6 +145,8 @@ class Parameter:
if isinstance(init, (str, Initializer, numbers.Number)):
x.init_mode = initializer(init, shape=shape, dtype=dtype)
x.default_input = MetaTensor(dtype, shape)
if context.get_context("mode") == context.PYNATIVE_MODE:
x.init_data()
else:
x.default_input = initializer(init, shape=shape, dtype=dtype)

View File

@ -202,7 +202,6 @@ class Cell:
if context.get_context("mode") == context.GRAPH_MODE:
out = self.compile_and_run(*inputs)
return out
self.init_parameters_data()
orign_grad = []
if self.requires_grad is True:
_pynative_exec.set_grad_flag(True)