forked from mindspore-Ecosystem/mindspore
remove attr support_non_tensor_input of cell
This commit is contained in:
parent
02adaa7528
commit
7eaf84d07a
|
@ -440,59 +440,19 @@ class _Executor:
|
|||
Str, the full phase of the cell.
|
||||
Bool, if the graph has been compiled before, return False, else return True.
|
||||
"""
|
||||
from mindspore import nn
|
||||
from mindspore.ops.composite import GradOperation
|
||||
|
||||
class InputsToAttrCell(nn.Cell):
|
||||
"""The cell that converts non-tensor inputs to attr."""
|
||||
|
||||
def __init__(self, net, args_names, non_tensor_inputs):
|
||||
super(InputsToAttrCell, self).__init__()
|
||||
self.net = net
|
||||
self.args_names = args_names
|
||||
self.non_tensor_inputs = non_tensor_inputs
|
||||
self.inputs_to_attr = True
|
||||
|
||||
def construct(self, *tensor_inputs):
|
||||
real_inputs = ()
|
||||
index = 0
|
||||
for i in args_names:
|
||||
if i in self.non_tensor_inputs.keys():
|
||||
real_inputs += (self.non_tensor_inputs[i],)
|
||||
else:
|
||||
real_inputs += (tensor_inputs[index],)
|
||||
index += 1
|
||||
return self.net(*real_inputs)
|
||||
|
||||
args_names, args_list = _generate_pip_args(obj, *args)
|
||||
if not hasattr(obj, "inputs_to_attr"):
|
||||
dic = dict(zip(args_names, args_list))
|
||||
key = generate_key(phase, dic)
|
||||
obj.phase_prefix = str(key[1])
|
||||
if 'export' in phase:
|
||||
phase = phase + '.' + obj.phase_prefix + '.' + str(obj.create_time)
|
||||
else:
|
||||
phase = obj.phase_prefix + phase + '.' + str(obj.create_time)
|
||||
dic = dict(zip(args_names, args_list))
|
||||
key = generate_key(phase, dic)
|
||||
obj.phase_prefix = str(key[1])
|
||||
if 'export' in phase:
|
||||
phase = phase + '.' + obj.phase_prefix + '.' + str(obj.create_time)
|
||||
else:
|
||||
phase = obj.phase_prefix + phase + '.' + str(obj.create_time)
|
||||
|
||||
if phase in self.compile_cache.keys():
|
||||
logger.debug("%r graph has existed.", phase)
|
||||
return phase, False
|
||||
|
||||
if getattr(obj, "support_non_tensor_inputs", None):
|
||||
for i in obj.__dict__.values():
|
||||
if isinstance(i, GradOperation):
|
||||
raise ValueError("Not support set 'support_non_tensor_inputs' to the 'True' for grad net, "
|
||||
"only support forward net.")
|
||||
attrs = {}
|
||||
inputs = []
|
||||
for key, value in dic.items():
|
||||
if not isinstance(value, (Tensor, MetaTensor)):
|
||||
attrs[key] = value
|
||||
else:
|
||||
inputs.append(value)
|
||||
if attrs:
|
||||
inputs_to_attr_cell = InputsToAttrCell(obj, args_names, attrs)
|
||||
return self.compile(inputs_to_attr_cell, *inputs, phase=phase)
|
||||
if phase in self.compile_cache.keys():
|
||||
logger.debug("%r graph has existed.", phase)
|
||||
return phase, False
|
||||
|
||||
obj.check_names()
|
||||
_check_full_batch()
|
||||
|
|
|
@ -107,7 +107,6 @@ class Cell(Cell_):
|
|||
self._bprop_debug = False
|
||||
self.cell_type = None
|
||||
self._auto_parallel_compile_and_run = False
|
||||
self._support_non_tensor_inputs = False
|
||||
|
||||
def __getstate__(self):
|
||||
base = Cell_.__getstate__(self)
|
||||
|
@ -119,27 +118,6 @@ class Cell(Cell_):
|
|||
self.__dict__ = dict_
|
||||
self._attr_synced = False
|
||||
|
||||
@property
|
||||
def support_non_tensor_inputs(self):
|
||||
"""
|
||||
Whether support non tensor inputs in outermost net in GRAPH MODE.
|
||||
This property only used in forward net, and is not supported in grad net.
|
||||
The default value of the property is the `False`, that is,
|
||||
it does not support passing non tensor inputs to the outermost net.
|
||||
If you want to support, set the property to the `True`.
|
||||
|
||||
"""
|
||||
return self._support_non_tensor_inputs
|
||||
|
||||
@support_non_tensor_inputs.setter
|
||||
def support_non_tensor_inputs(self, value):
|
||||
"""
|
||||
Set attr 'support_non_tensor_inputs'.
|
||||
"""
|
||||
if not isinstance(value, bool):
|
||||
raise ValueError("When set 'support_non_tensor_inputs' for cell, the value should be bool.")
|
||||
self._support_non_tensor_inputs = value
|
||||
|
||||
@property
|
||||
def _cell_tag(self):
|
||||
# `<class 'xxxxxxx'>` to `xxxxxxx`
|
||||
|
@ -666,11 +644,6 @@ class Cell(Cell_):
|
|||
"""
|
||||
Defines the computation to be performed. This method must be overridden by all subclasses.
|
||||
|
||||
Note:
|
||||
The outermost net only supports tensor inputs by default.
|
||||
If want to support non tensor inputs, set the property `support_non_tensor_inputs` to the `True`.
|
||||
Refer to the property `support_non_tensor_inputs` description.
|
||||
|
||||
Returns:
|
||||
Tensor, returns the computed result.
|
||||
"""
|
||||
|
|
|
@ -27,7 +27,6 @@ def test_outermost_net_pass_scalar_tuple_list_dict():
|
|||
class TestNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super(TestNet, self).__init__()
|
||||
self.support_non_tensor_inputs = False
|
||||
|
||||
def construct(self, tuple_a, z, list_m, w, s, dict_n):
|
||||
return z - tuple_a[2] + list_m[1][1]["x"] - w + s - dict_n["y"]
|
||||
|
|
Loading…
Reference in New Issue