remove attr support_non_tensor_inputs of cell

This commit is contained in:
buxue 2021-01-14 16:19:23 +08:00
parent 02adaa7528
commit 7eaf84d07a
3 changed files with 10 additions and 78 deletions

View File

@@ -440,59 +440,19 @@ class _Executor:
Str, the full phase of the cell.
Bool, if the graph has been compiled before, return False, else return True.
"""
from mindspore import nn
from mindspore.ops.composite import GradOperation
class InputsToAttrCell(nn.Cell):
"""The cell that converts non-tensor inputs to attr."""
def __init__(self, net, args_names, non_tensor_inputs):
super(InputsToAttrCell, self).__init__()
self.net = net
self.args_names = args_names
self.non_tensor_inputs = non_tensor_inputs
self.inputs_to_attr = True
def construct(self, *tensor_inputs):
real_inputs = ()
index = 0
for i in args_names:
if i in self.non_tensor_inputs.keys():
real_inputs += (self.non_tensor_inputs[i],)
else:
real_inputs += (tensor_inputs[index],)
index += 1
return self.net(*real_inputs)
args_names, args_list = _generate_pip_args(obj, *args)
if not hasattr(obj, "inputs_to_attr"):
dic = dict(zip(args_names, args_list))
key = generate_key(phase, dic)
obj.phase_prefix = str(key[1])
if 'export' in phase:
phase = phase + '.' + obj.phase_prefix + '.' + str(obj.create_time)
else:
phase = obj.phase_prefix + phase + '.' + str(obj.create_time)
dic = dict(zip(args_names, args_list))
key = generate_key(phase, dic)
obj.phase_prefix = str(key[1])
if 'export' in phase:
phase = phase + '.' + obj.phase_prefix + '.' + str(obj.create_time)
else:
phase = obj.phase_prefix + phase + '.' + str(obj.create_time)
if phase in self.compile_cache.keys():
logger.debug("%r graph has existed.", phase)
return phase, False
if getattr(obj, "support_non_tensor_inputs", None):
for i in obj.__dict__.values():
if isinstance(i, GradOperation):
raise ValueError("Not support set 'support_non_tensor_inputs' to the 'True' for grad net, "
"only support forward net.")
attrs = {}
inputs = []
for key, value in dic.items():
if not isinstance(value, (Tensor, MetaTensor)):
attrs[key] = value
else:
inputs.append(value)
if attrs:
inputs_to_attr_cell = InputsToAttrCell(obj, args_names, attrs)
return self.compile(inputs_to_attr_cell, *inputs, phase=phase)
if phase in self.compile_cache.keys():
logger.debug("%r graph has existed.", phase)
return phase, False
obj.check_names()
_check_full_batch()

View File

@@ -107,7 +107,6 @@ class Cell(Cell_):
self._bprop_debug = False
self.cell_type = None
self._auto_parallel_compile_and_run = False
self._support_non_tensor_inputs = False
def __getstate__(self):
base = Cell_.__getstate__(self)
@@ -119,27 +118,6 @@ class Cell(Cell_):
self.__dict__ = dict_
self._attr_synced = False
@property
def support_non_tensor_inputs(self):
"""
Whether support non tensor inputs in outermost net in GRAPH MODE.
This property only used in forward net, and is not supported in grad net.
The default value of the property is the `False`, that is,
it does not support passing non tensor inputs to the outermost net.
If you want to support, set the property to the `True`.
"""
return self._support_non_tensor_inputs
@support_non_tensor_inputs.setter
def support_non_tensor_inputs(self, value):
"""
Set attr 'support_non_tensor_inputs'.
"""
if not isinstance(value, bool):
raise ValueError("When set 'support_non_tensor_inputs' for cell, the value should be bool.")
self._support_non_tensor_inputs = value
@property
def _cell_tag(self):
# `<class 'xxxxxxx'>` to `xxxxxxx`
@@ -666,11 +644,6 @@ class Cell(Cell_):
"""
Defines the computation to be performed. This method must be overridden by all subclasses.
Note:
The outermost net only supports tensor inputs by default.
If want to support non tensor inputs, set the property `support_non_tensor_inputs` to the `True`.
Refer to the property `support_non_tensor_inputs` description.
Returns:
Tensor, returns the computed result.
"""

View File

@@ -27,7 +27,6 @@ def test_outermost_net_pass_scalar_tuple_list_dict():
class TestNet(nn.Cell):
def __init__(self):
super(TestNet, self).__init__()
self.support_non_tensor_inputs = False
def construct(self, tuple_a, z, list_m, w, s, dict_n):
return z - tuple_a[2] + list_m[1][1]["x"] - w + s - dict_n["y"]