!1087 add param_perfix to cell

Merge pull request !1087 from SanjayChan/03cell
This commit is contained in:
mindspore-ci-bot 2020-05-12 16:27:16 +08:00 committed by Gitee
commit 00b7877ec4
1 changed files with 19 additions and 2 deletions

View File

@ -60,6 +60,7 @@ class Cell:
self._cells = OrderedDict() self._cells = OrderedDict()
self.training = False self.training = False
self.pynative = False self.pynative = False
self._param_perfix = ''
self._auto_prefix = auto_prefix self._auto_prefix = auto_prefix
self._scope = None self._scope = None
self._phase = 'train' self._phase = 'train'
@ -83,6 +84,24 @@ class Cell:
def cell_init_args(self): def cell_init_args(self):
return self._cell_init_args return self._cell_init_args
@property
def param_perfix(self):
"""
Param perfix is the prfix of curent cell's direct child parameter.
"""
return self._param_perfix
def update_cell_prefix(self):
"""
Update the all child cells' self.param_prefix.
After invoked, can get all the cell's children's name perfix by '_param_perfix'.
"""
cells = self.cells_and_names
for cell_name, cell in cells:
cell._param_perfix = cell_name
@cell_init_args.setter @cell_init_args.setter
def cell_init_args(self, value): def cell_init_args(self, value):
if not isinstance(value, str): if not isinstance(value, str):
@ -223,7 +242,6 @@ class Cell:
Args: Args:
params (dict): The parameters dictionary used for init data graph. params (dict): The parameters dictionary used for init data graph.
""" """
if params is None: if params is None:
for key in self.parameters_dict(): for key in self.parameters_dict():
tensor = self.parameters_dict()[key].data tensor = self.parameters_dict()[key].data
@ -253,7 +271,6 @@ class Cell:
Args: Args:
inputs (Function or Cell): inputs of construct method. inputs (Function or Cell): inputs of construct method.
""" """
parallel_inputs_run = [] parallel_inputs_run = []
if len(inputs) > self._construct_inputs_num: if len(inputs) > self._construct_inputs_num:
raise ValueError('Len of inputs: {} is bigger than self._construct_inputs_num: {}.'. raise ValueError('Len of inputs: {} is bigger than self._construct_inputs_num: {}.'.