forked from mindspore-Ecosystem/mindspore
!1087 add param_perfix to cell
Merge pull request !1087 from SanjayChan/03cell
This commit is contained in:
commit
00b7877ec4
|
@ -60,6 +60,7 @@ class Cell:
|
||||||
self._cells = OrderedDict()
|
self._cells = OrderedDict()
|
||||||
self.training = False
|
self.training = False
|
||||||
self.pynative = False
|
self.pynative = False
|
||||||
|
self._param_perfix = ''
|
||||||
self._auto_prefix = auto_prefix
|
self._auto_prefix = auto_prefix
|
||||||
self._scope = None
|
self._scope = None
|
||||||
self._phase = 'train'
|
self._phase = 'train'
|
||||||
|
@ -83,6 +84,24 @@ class Cell:
|
||||||
def cell_init_args(self):
|
def cell_init_args(self):
|
||||||
return self._cell_init_args
|
return self._cell_init_args
|
||||||
|
|
||||||
|
@property
|
||||||
|
def param_perfix(self):
|
||||||
|
"""
|
||||||
|
Param perfix is the prfix of curent cell's direct child parameter.
|
||||||
|
"""
|
||||||
|
return self._param_perfix
|
||||||
|
|
||||||
|
def update_cell_prefix(self):
|
||||||
|
"""
|
||||||
|
Update the all child cells' self.param_prefix.
|
||||||
|
|
||||||
|
After invoked, can get all the cell's children's name perfix by '_param_perfix'.
|
||||||
|
"""
|
||||||
|
cells = self.cells_and_names
|
||||||
|
|
||||||
|
for cell_name, cell in cells:
|
||||||
|
cell._param_perfix = cell_name
|
||||||
|
|
||||||
@cell_init_args.setter
|
@cell_init_args.setter
|
||||||
def cell_init_args(self, value):
|
def cell_init_args(self, value):
|
||||||
if not isinstance(value, str):
|
if not isinstance(value, str):
|
||||||
|
@ -223,7 +242,6 @@ class Cell:
|
||||||
Args:
|
Args:
|
||||||
params (dict): The parameters dictionary used for init data graph.
|
params (dict): The parameters dictionary used for init data graph.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if params is None:
|
if params is None:
|
||||||
for key in self.parameters_dict():
|
for key in self.parameters_dict():
|
||||||
tensor = self.parameters_dict()[key].data
|
tensor = self.parameters_dict()[key].data
|
||||||
|
@ -253,7 +271,6 @@ class Cell:
|
||||||
Args:
|
Args:
|
||||||
inputs (Function or Cell): inputs of construct method.
|
inputs (Function or Cell): inputs of construct method.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
parallel_inputs_run = []
|
parallel_inputs_run = []
|
||||||
if len(inputs) > self._construct_inputs_num:
|
if len(inputs) > self._construct_inputs_num:
|
||||||
raise ValueError('Len of inputs: {} is bigger than self._construct_inputs_num: {}.'.
|
raise ValueError('Len of inputs: {} is bigger than self._construct_inputs_num: {}.'.
|
||||||
|
|
Loading…
Reference in New Issue