forked from mindspore-Ecosystem/mindspore
!2040 fix paramter is metatensor bug in pynative mode
Merge pull request !2040 from flywind/fix_pynative_bug
This commit is contained in:
commit
89fce0e41f
|
@ -16,6 +16,7 @@
|
|||
"""Parameter for cell."""
|
||||
import numbers
|
||||
from copy import copy, deepcopy
|
||||
from mindspore import context
|
||||
from . import dtype as mstype
|
||||
from .initializer import initializer, Initializer
|
||||
from .tensor import Tensor, MetaTensor
|
||||
|
@ -61,6 +62,8 @@ class Parameter:
|
|||
self._is_init = False
|
||||
self._sliced = False
|
||||
self.clone_info = _CloneInfo()
|
||||
if context.get_context("mode") == context.PYNATIVE_MODE:
|
||||
self.init_data()
|
||||
|
||||
def __repr__(self):
|
||||
format_str = 'Parameter (name={name})'
|
||||
|
@ -142,6 +145,8 @@ class Parameter:
|
|||
if isinstance(init, (str, Initializer, numbers.Number)):
|
||||
x.init_mode = initializer(init, shape=shape, dtype=dtype)
|
||||
x.default_input = MetaTensor(dtype, shape)
|
||||
if context.get_context("mode") == context.PYNATIVE_MODE:
|
||||
x.init_data()
|
||||
else:
|
||||
x.default_input = initializer(init, shape=shape, dtype=dtype)
|
||||
|
||||
|
|
|
@ -202,7 +202,6 @@ class Cell:
|
|||
if context.get_context("mode") == context.GRAPH_MODE:
|
||||
out = self.compile_and_run(*inputs)
|
||||
return out
|
||||
self.init_parameters_data()
|
||||
orign_grad = []
|
||||
if self.requires_grad is True:
|
||||
_pynative_exec.set_grad_flag(True)
|
||||
|
|
Loading…
Reference in New Issue