forked from mindspore-Ecosystem/mindspore
!461 Add interface to get attributes of network.
Merge pull request !461 from wsc/master
This commit is contained in:
commit
507b63ea20
|
@ -56,7 +56,7 @@ class Cell:
|
|||
>>> def construct(self, x):
|
||||
>>> return self.relu(x)
|
||||
"""
|
||||
def __init__(self, auto_prefix=True):
|
||||
def __init__(self, auto_prefix=True, flags=None):
|
||||
self._params = OrderedDict()
|
||||
self._cells = OrderedDict()
|
||||
self.training = False
|
||||
|
@ -74,6 +74,8 @@ class Cell:
|
|||
if _get_parallel_mode() in ["auto_parallel", "semi_auto_parallel"]:
|
||||
self._get_construct_inputs_number_and_name()
|
||||
self._parallel_inputs_run = None
|
||||
if flags:
|
||||
self.add_flags(**flags)
|
||||
|
||||
@property
|
||||
def create_time(self):
|
||||
|
@ -607,6 +609,11 @@ class Cell:
|
|||
cell.add_flags_recursive(**flags)
|
||||
return self
|
||||
|
||||
def get_flags(self):
|
||||
if not hasattr(self, "_mindspore_flags"):
|
||||
self._mindspore_flags = {}
|
||||
return self._mindspore_flags
|
||||
|
||||
def to_float(self, dst_type):
|
||||
"""
|
||||
Add cast on all inputs of cell and child cells to run with certain float type.
|
||||
|
|
|
@ -219,7 +219,7 @@ class DataWrapper(Cell):
|
|||
"""
|
||||
|
||||
def __init__(self, network, dataset_types, dataset_shapes, queue_name):
|
||||
super(DataWrapper, self).__init__(auto_prefix=False)
|
||||
super(DataWrapper, self).__init__(auto_prefix=False, flags=network.get_flags())
|
||||
|
||||
self.get_next = P.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name)
|
||||
self.network = network
|
||||
|
|
Loading…
Reference in New Issue