!8697 support the outermost layer network inputs are list or dict or scalar

From: @zhangbuxue
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2020-12-01 11:11:43 +08:00 committed by Gitee
commit e45f19adc8
5 changed files with 140 additions and 17 deletions

View File

@ -115,7 +115,8 @@ py::tuple GenerateKey(const std::string &name, const std::unordered_map<std::str
if (!parse::ConvertData(arg.second, &converted)) {
MS_LOG(EXCEPTION) << "GenerateKey convert arg failed";
}
args_spec.push_back(abstract::FromValue(converted, true));
bool broaden = converted->isa<Tensor>() || converted->isa<MetaTensor>();
args_spec.push_back(abstract::FromValue(converted, broaden));
}
if (g_args_cache.count(args_spec) == 0) {
static int64_t key = 0;

View File

@ -413,18 +413,54 @@ class _Executor:
Str, the full phase of the cell.
Bool, if the graph has been compiled before, return False, else return True.
"""
args_names, args_list = _generate_pip_args(obj, *args)
dic = dict(zip(args_names, args_list))
key = generate_key(phase, dic)
self.phase_prefix = str(key[1])
if 'export' in phase:
phase = phase + '.' + self.phase_prefix + '.' + str(obj.create_time)
else:
phase = self.phase_prefix + phase + '.' + str(obj.create_time)
from mindspore import nn
if phase in self.compile_cache.keys():
logger.debug("%r graph has existed.", phase)
return phase, False
class InputsToAttrCell(nn.Cell):
"""The cell that converts non-tensor inputs to attr."""
def __init__(self, net, args_names, non_tensor_inputs):
super(InputsToAttrCell, self).__init__()
self.net = net
self.args_names = args_names
self.non_tensor_inputs = non_tensor_inputs
self.inputs_to_attr = True
def construct(self, *tensor_inputs):
real_inputs = ()
index = 0
for i in args_names:
if i in self.non_tensor_inputs.keys():
real_inputs += (self.non_tensor_inputs[i],)
else:
real_inputs += (tensor_inputs[index],)
index += 1
return self.net(*real_inputs)
args_names, args_list = _generate_pip_args(obj, *args)
if not hasattr(obj, "inputs_to_attr"):
dic = dict(zip(args_names, args_list))
key = generate_key(phase, dic)
self.phase_prefix = str(key[1])
if 'export' in phase:
phase = phase + '.' + self.phase_prefix + '.' + str(obj.create_time)
else:
phase = self.phase_prefix + phase + '.' + str(obj.create_time)
if phase in self.compile_cache.keys():
logger.debug("%r graph has existed.", phase)
return phase, False
if getattr(obj, "support_non_tensor_inputs", None):
attrs = {}
inputs = []
for key, value in dic.items():
if not isinstance(value, (Tensor, MetaTensor)):
attrs[key] = value
else:
inputs.append(value)
if attrs:
inputs_to_attr_cell = InputsToAttrCell(obj, args_names, attrs)
return self.compile(inputs_to_attr_cell, *inputs, phase=phase)
obj.check_names()
_check_full_batch()

View File

@ -30,7 +30,7 @@ from ..ops.primitive import Primitive
from ..ops.operations import HookBackward
from ..ops.functional import cast
from ..parallel._tensor import _load_tensor_by_layout
from ..common.tensor import Tensor
from ..common.tensor import Tensor, MetaTensor
class Cell(Cell_):
@ -104,6 +104,7 @@ class Cell(Cell_):
self._already_run = False
self.cell_type = None
self._auto_parallel_compile_and_run = False
self._support_non_tensor_inputs = False
@property
def already_run(self):
@ -119,6 +120,23 @@ class Cell(Cell_):
self.__dict__ = dict_
self._attr_synced = False
@property
def support_non_tensor_inputs(self):
"""
Whether support non tensor inputs in cell `construct` method.
This property only used in forward net, is not supported in grad net.
"""
return self._support_non_tensor_inputs
@support_non_tensor_inputs.setter
def support_non_tensor_inputs(self, value):
"""
Set attr 'support_non_tensor_inputs'.
"""
if not isinstance(value, bool):
raise ValueError("When set 'support_non_tensor_inputs' for cell, the value should be bool.")
self._support_non_tensor_inputs = value
@property
def _cell_tag(self):
# `<class 'xxxxxxx'>` to `xxxxxxx`
@ -553,14 +571,19 @@ class Cell(Cell_):
self._auto_parallel_compile_and_run = True
self.compile(*inputs)
new_inputs = []
for i in inputs:
if isinstance(i, (Tensor, MetaTensor)):
new_inputs.append(i)
if self._auto_parallel_mode:
if inputs and isinstance(inputs[0], Tensor) and inputs[0].virtual_flag:
if new_inputs and isinstance(new_inputs[0], Tensor) and inputs[0].virtual_flag:
# get parallel inputs in sink mode, parallel inputs set in _executor.compile
parallel_inputs_run = self._parallel_inputs_run
else:
parallel_inputs_run = inputs
parallel_inputs_run = new_inputs
return _executor(self, *parallel_inputs_run, phase=self.phase)
return _executor(self, *inputs, phase=self.phase)
return _executor(self, *new_inputs, phase=self.phase)
def auto_parallel_compile_and_run(self):
return self._auto_parallel_compile_and_run

View File

@ -94,7 +94,7 @@ def restrict_int_index(data_shape, tuple_indexes):
for i, index in enumerate(tuple_indexes):
if isinstance(index, mstype.Int):
if index < -data_shape[i] or index >= data_shape[i]:
const_utils.raise_index_error("The index is out of the data's special dimension range.")
raise_index_error("The index is out of the data's special dimension range.")
elif index < 0:
tuple_indexes_new += (tuple_indexes[i]+data_shape[i],)
else:

View File

@ -0,0 +1,63 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test outermost net pass scalar tuple list dict"""
import pytest
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.ops import composite as C
context.set_context(mode=context.GRAPH_MODE)
def test_outermost_net_pass_scalar_tuple_list_dict():
class TestNet(nn.Cell):
def __init__(self):
super(TestNet, self).__init__()
self.support_non_tensor_inputs = True
def construct(self, tuple_a, z, list_m, w, s, dict_n):
return z - tuple_a[2] + list_m[1][1]["x"] - w + s - dict_n["y"]
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.forward_net = net
self.sens = Tensor(np.ones((2, 2), np.float32) * 5)
self.grad_all = C.GradOperation(get_all=True)
def construct(self, tuple_a, z, list_m, w, s, dict_n):
return self.grad_all(self.forward_net)(tuple_a, z, list_m, w, s, dict_n)
x = Tensor(np.ones((2, 2), np.float32))
y = Tensor(np.ones((2, 2), np.float32) * 2)
z = Tensor(np.ones((2, 2), np.float32) * 3)
w = Tensor(np.ones((2, 2), np.float32) * 4)
arg_t0 = (x, y, z, w)
arg_t1 = (w, y, z, w)
arg_l0 = [[x, x], [[x, y], {"x": x, "y": y, "z": x, "p": y}]]
arg_l1 = [[x, x], [[x, y], {"x": x, "y": y, "z": x, "p": y}]]
args_d0 = {"x": x, "y": y}
args_d1 = {"x": x, "y": y}
forward_net = TestNet()
forward_net(arg_t0, z, arg_l0, w, 6, args_d0)
forward_net(arg_t1, z, arg_l1, x, 6, args_d1)
grad_net = GradNet(forward_net)
with pytest.raises(TypeError) as err:
grad_net(arg_t0, z, arg_l0, w, 6, args_d0)
assert "For 'graph mode', the 0th arg" in str(err.value)