!8697 Support list, dict, or scalar inputs for the outermost network layer
From: @zhangbuxue
Commit: e45f19adc8
@@ -115,7 +115,8 @@ py::tuple GenerateKey(const std::string &name, const std::unordered_map<std::str
     if (!parse::ConvertData(arg.second, &converted)) {
       MS_LOG(EXCEPTION) << "GenerateKey convert arg failed";
     }
-    args_spec.push_back(abstract::FromValue(converted, true));
+    bool broaden = converted->isa<Tensor>() || converted->isa<MetaTensor>();
+    args_spec.push_back(abstract::FromValue(converted, broaden));
   }
   if (g_args_cache.count(args_spec) == 0) {
     static int64_t key = 0;
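Note on the hunk above: broadening an argument's abstract value keeps only its type and shape, so with this change only Tensor and MetaTensor arguments share a compile-cache entry across different values, while scalars, tuples, lists, and dicts enter the cache key by concrete value. A minimal, self-contained Python sketch of that keying idea (illustrative only, not MindSpore source):

import numpy as np

def make_cache_key(arg):
    # Tensor-like args are broadened: only dtype and shape enter the key,
    # so tensors with equal dtype/shape reuse one compiled graph.
    if isinstance(arg, np.ndarray):
        return ("tensor", str(arg.dtype), arg.shape)
    # Non-tensor args stay concrete: a new value means a new key, hence a recompile.
    return ("const", arg)

assert make_cache_key(np.ones((2, 2), np.float32)) == make_cache_key(np.zeros((2, 2), np.float32))
assert make_cache_key(6) != make_cache_key(7)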
@@ -413,18 +413,54 @@ class _Executor:
             Str, the full phase of the cell.
             Bool, if the graph has been compiled before, return False, else return True.
         """
-        args_names, args_list = _generate_pip_args(obj, *args)
-        dic = dict(zip(args_names, args_list))
-        key = generate_key(phase, dic)
-        self.phase_prefix = str(key[1])
-        if 'export' in phase:
-            phase = phase + '.' + self.phase_prefix + '.' + str(obj.create_time)
-        else:
-            phase = self.phase_prefix + phase + '.' + str(obj.create_time)
+        from mindspore import nn

-        if phase in self.compile_cache.keys():
-            logger.debug("%r graph has existed.", phase)
-            return phase, False
+        class InputsToAttrCell(nn.Cell):
+            """The cell that converts non-tensor inputs to attr."""
+
+            def __init__(self, net, args_names, non_tensor_inputs):
+                super(InputsToAttrCell, self).__init__()
+                self.net = net
+                self.args_names = args_names
+                self.non_tensor_inputs = non_tensor_inputs
+                self.inputs_to_attr = True
+
+            def construct(self, *tensor_inputs):
+                real_inputs = ()
+                index = 0
+                for i in args_names:
+                    if i in self.non_tensor_inputs.keys():
+                        real_inputs += (self.non_tensor_inputs[i],)
+                    else:
+                        real_inputs += (tensor_inputs[index],)
+                        index += 1
+                return self.net(*real_inputs)
+
+        args_names, args_list = _generate_pip_args(obj, *args)
+        if not hasattr(obj, "inputs_to_attr"):
+            dic = dict(zip(args_names, args_list))
+            key = generate_key(phase, dic)
+            self.phase_prefix = str(key[1])
+            if 'export' in phase:
+                phase = phase + '.' + self.phase_prefix + '.' + str(obj.create_time)
+            else:
+                phase = self.phase_prefix + phase + '.' + str(obj.create_time)
+
+            if phase in self.compile_cache.keys():
+                logger.debug("%r graph has existed.", phase)
+                return phase, False
+
+            if getattr(obj, "support_non_tensor_inputs", None):
+                attrs = {}
+                inputs = []
+                for key, value in dic.items():
+                    if not isinstance(value, (Tensor, MetaTensor)):
+                        attrs[key] = value
+                    else:
+                        inputs.append(value)
+                if attrs:
+                    inputs_to_attr_cell = InputsToAttrCell(obj, args_names, attrs)
+                    return self.compile(inputs_to_attr_cell, *inputs, phase=phase)

         obj.check_names()
         _check_full_batch()
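The wrapper above captures every non-tensor argument as an attribute of InputsToAttrCell and re-interleaves those constants with the remaining tensor arguments at call time. (Note that construct reads args_names from the enclosing compile scope via late closure binding; this works because construct only runs after args_names has been assigned.) A standalone sketch of the re-interleaving, with hypothetical placeholder values:

def reassemble(args_names, non_tensor_inputs, tensor_inputs):
    # Rebuild the original positional argument list: constants come from the
    # captured dict, tensors are consumed left-to-right from the call site.
    real_inputs = []
    index = 0
    for name in args_names:
        if name in non_tensor_inputs:
            real_inputs.append(non_tensor_inputs[name])
        else:
            real_inputs.append(tensor_inputs[index])
            index += 1
    return tuple(real_inputs)

# construct(x, s, y) where s=6 was captured as an attribute:
assert reassemble(["x", "s", "y"], {"s": 6}, ("TX", "TY")) == ("TX", 6, "TY")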
@@ -30,7 +30,7 @@ from ..ops.primitive import Primitive
 from ..ops.operations import HookBackward
 from ..ops.functional import cast
 from ..parallel._tensor import _load_tensor_by_layout
-from ..common.tensor import Tensor
+from ..common.tensor import Tensor, MetaTensor


 class Cell(Cell_):
@@ -104,6 +104,7 @@ class Cell(Cell_):
         self._already_run = False
         self.cell_type = None
         self._auto_parallel_compile_and_run = False
+        self._support_non_tensor_inputs = False

     @property
     def already_run(self):
@@ -119,6 +120,23 @@ class Cell(Cell_):
         self.__dict__ = dict_
         self._attr_synced = False

+    @property
+    def support_non_tensor_inputs(self):
+        """
+        Whether the cell's `construct` method supports non-tensor inputs.
+        This property is only used for the forward net; it is not supported for the grad net.
+        """
+        return self._support_non_tensor_inputs
+
+    @support_non_tensor_inputs.setter
+    def support_non_tensor_inputs(self, value):
+        """
+        Set attr 'support_non_tensor_inputs'.
+        """
+        if not isinstance(value, bool):
+            raise ValueError("When setting 'support_non_tensor_inputs' for a cell, the value should be a bool.")
+        self._support_non_tensor_inputs = value
+
     @property
     def _cell_tag(self):
         # `<class 'xxxxxxx'>` to `xxxxxxx`
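With this property in place, a forward net opts in explicitly. A usage sketch under this patch's API (MyNet is a hypothetical example cell, not from the diff):

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context

context.set_context(mode=context.GRAPH_MODE)

class MyNet(nn.Cell):  # hypothetical example net
    def construct(self, x, scale):
        return x * scale

net = MyNet()
net.support_non_tensor_inputs = True  # must be a bool, otherwise ValueError
out = net(Tensor(np.ones((2, 2), np.float32)), 3)  # the scalar 3 is folded into an attr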
@@ -553,14 +571,19 @@ class Cell(Cell_):
         self._auto_parallel_compile_and_run = True
         self.compile(*inputs)

+        new_inputs = []
+        for i in inputs:
+            if isinstance(i, (Tensor, MetaTensor)):
+                new_inputs.append(i)
+
         if self._auto_parallel_mode:
-            if inputs and isinstance(inputs[0], Tensor) and inputs[0].virtual_flag:
+            if new_inputs and isinstance(new_inputs[0], Tensor) and inputs[0].virtual_flag:
                 # get parallel inputs in sink mode, parallel inputs set in _executor.compile
                 parallel_inputs_run = self._parallel_inputs_run
             else:
-                parallel_inputs_run = inputs
+                parallel_inputs_run = new_inputs
             return _executor(self, *parallel_inputs_run, phase=self.phase)
-        return _executor(self, *inputs, phase=self.phase)
+        return _executor(self, *new_inputs, phase=self.phase)

     def auto_parallel_compile_and_run(self):
         return self._auto_parallel_compile_and_run
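The new new_inputs list drops every non-tensor argument before dispatching to the executor: by this point compile() has already folded those values into the graph, so only tensor-like inputs are passed at run time. The filtering step in isolation (a standalone sketch, not the Cell class):

def filter_tensor_inputs(inputs, tensor_types):
    # Keep only tensor-like args; constants were baked into the compiled graph.
    return [i for i in inputs if isinstance(i, tensor_types)]

class FakeTensor:  # stand-in for Tensor/MetaTensor
    pass

t1, t2 = FakeTensor(), FakeTensor()
assert filter_tensor_inputs([t1, 6, t2], FakeTensor) == [t1, t2]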
@@ -94,7 +94,7 @@ def restrict_int_index(data_shape, tuple_indexes):
     for i, index in enumerate(tuple_indexes):
         if isinstance(index, mstype.Int):
             if index < -data_shape[i] or index >= data_shape[i]:
-                const_utils.raise_index_error("The index is out of the data's special dimension range.")
+                raise_index_error("The index is out of the data's special dimension range.")
             elif index < 0:
                 tuple_indexes_new += (tuple_indexes[i]+data_shape[i],)
             else:
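For reference, the surrounding function normalizes negative indices into the non-negative range after the bounds check. A self-contained sketch of that normalization (illustrative, with a plain IndexError standing in for raise_index_error):

def normalize_index(index, dim_size):
    # Reject out-of-range indices, then map negatives to their positive form.
    if index < -dim_size or index >= dim_size:
        raise IndexError("The index is out of the data's dimension range.")
    return index + dim_size if index < 0 else index

assert normalize_index(-1, 4) == 3
assert normalize_index(2, 4) == 2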
@@ -0,0 +1,63 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+""" test outermost net pass scalar tuple list dict"""
+import pytest
+import numpy as np
+
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore import context
+from mindspore.ops import composite as C
+
+context.set_context(mode=context.GRAPH_MODE)
+
+
+def test_outermost_net_pass_scalar_tuple_list_dict():
+    class TestNet(nn.Cell):
+        def __init__(self):
+            super(TestNet, self).__init__()
+            self.support_non_tensor_inputs = True
+
+        def construct(self, tuple_a, z, list_m, w, s, dict_n):
+            return z - tuple_a[2] + list_m[1][1]["x"] - w + s - dict_n["y"]
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.forward_net = net
+            self.sens = Tensor(np.ones((2, 2), np.float32) * 5)
+            self.grad_all = C.GradOperation(get_all=True)
+
+        def construct(self, tuple_a, z, list_m, w, s, dict_n):
+            return self.grad_all(self.forward_net)(tuple_a, z, list_m, w, s, dict_n)
+
+    x = Tensor(np.ones((2, 2), np.float32))
+    y = Tensor(np.ones((2, 2), np.float32) * 2)
+    z = Tensor(np.ones((2, 2), np.float32) * 3)
+    w = Tensor(np.ones((2, 2), np.float32) * 4)
+    arg_t0 = (x, y, z, w)
+    arg_t1 = (w, y, z, w)
+    arg_l0 = [[x, x], [[x, y], {"x": x, "y": y, "z": x, "p": y}]]
+    arg_l1 = [[x, x], [[x, y], {"x": x, "y": y, "z": x, "p": y}]]
+    args_d0 = {"x": x, "y": y}
+    args_d1 = {"x": x, "y": y}
+    forward_net = TestNet()
+    forward_net(arg_t0, z, arg_l0, w, 6, args_d0)
+    forward_net(arg_t1, z, arg_l1, x, 6, args_d1)
+
+    grad_net = GradNet(forward_net)
+    with pytest.raises(TypeError) as err:
+        grad_net(arg_t0, z, arg_l0, w, 6, args_d0)
+    assert "For 'graph mode', the 0th arg" in str(err.value)
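The test exercises both halves of the feature: the two forward calls succeed with tuple, list, dict, and scalar arguments at the outermost layer (the second call changes tensor values and one tensor argument, exercising the new cache keying), while the grad net is expected to raise TypeError, since support_non_tensor_inputs is forward-only per the property's docstring.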