forked from mindspore-Ecosystem/mindspore
eager mode sparse
This commit is contained in:
parent
8dec74908a
commit
9927e6eb5c
|
@ -132,7 +132,17 @@ AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitiveP
|
|||
// Inputs: index, branch
|
||||
const std::string op_name = primitive->name();
|
||||
abstract::CheckArgsSize(op_name, args_spec_list, 2);
|
||||
(void)CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
||||
auto index = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
||||
auto &input_shape = index->shape()->shape();
|
||||
if (input_shape.size() != 0) {
|
||||
MS_EXCEPTION(ValueError) << op_name << " index must be a 0 dimension tensor, but got a " << input_shape.size()
|
||||
<< " dimension tensor";
|
||||
}
|
||||
auto dtype = index->element()->BuildType();
|
||||
if (dtype->type_id() != kInt32->type_id()) {
|
||||
MS_EXCEPTION(ValueError) << op_name << " index must be a int32, but got " << dtype->ToString();
|
||||
}
|
||||
|
||||
AbstractTuplePtr branches_abs = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
|
||||
AbstractBasePtrList branches = branches_abs->elements();
|
||||
const size_t maximum_layer_num = 1000;
|
||||
|
|
|
@ -145,9 +145,11 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
|
|||
|
||||
using mindspore::parse::PyObjectWrapper;
|
||||
|
||||
std::unordered_set<std::string> prims_to_skip_undetermined_infer{"make_tuple", "make_list", "switch", "env_setitem",
|
||||
"env_getitem"};
|
||||
|
||||
EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
|
||||
if (prim_ != prim::kPrimMakeTuple && prim_ != prim::kPrimSwitch && prim_ != prim::kPrimEnvSetItem &&
|
||||
prim_ != prim::kPrimEnvGetItem) {
|
||||
if (prims_to_skip_undetermined_infer.find(prim_->name()) == prims_to_skip_undetermined_infer.end()) {
|
||||
auto ret_abstract = AbstractEval(args);
|
||||
if (ret_abstract != nullptr) {
|
||||
MS_LOG(DEBUG) << "StandardPrimEvaluator eval Undetermined";
|
||||
|
@ -167,17 +169,23 @@ EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
|
|||
AbstractBasePtrList args_spec_list;
|
||||
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
|
||||
[](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue()->abstract(); });
|
||||
auto ret_abstract = AbstractEval(args_spec_list);
|
||||
if (ret_abstract != nullptr) {
|
||||
MS_LOG(DEBUG) << "StandardPrimEvaluator eval Undetermined";
|
||||
return ret_abstract;
|
||||
auto do_signature = prim_->cast<prim::DoSignaturePrimitivePtr>();
|
||||
auto &func = do_signature->function();
|
||||
if (func->isa<Primitive>()) {
|
||||
auto sig_prim = func->cast<PrimitivePtr>();
|
||||
if (prims_to_skip_undetermined_infer.find(sig_prim->name()) == prims_to_skip_undetermined_infer.end()) {
|
||||
auto ret_abstract = AbstractEval(args_spec_list);
|
||||
if (ret_abstract != nullptr) {
|
||||
MS_LOG(DEBUG) << "DoSignatureEvaluator eval Undetermined";
|
||||
return ret_abstract;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
|
||||
MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
|
||||
}
|
||||
|
||||
auto do_signature = dyn_cast<prim::DoSignaturePrimitive>(prim_);
|
||||
auto out_node = dyn_cast<CNode>(out_conf->node());
|
||||
const auto &out_node_inputs = out_node->inputs();
|
||||
if (out_node->inputs().size() == 0 || (out_node_inputs.size() - 1) != args_conf_list.size()) {
|
||||
|
@ -447,6 +455,11 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
|
|||
dic["shape"] = py::none();
|
||||
dic["dtype"] = abs_base->BuildType();
|
||||
dic["value"] = py::none();
|
||||
} else if (abs_base->isa<AbstractUndetermined>()) {
|
||||
auto arg = dyn_cast<AbstractUndetermined>(abs_base);
|
||||
dic["shape"] = py::none();
|
||||
dic["dtype"] = arg->BuildType();
|
||||
dic["value"] = py::none();
|
||||
} else {
|
||||
auto value = abs_base->BuildValue();
|
||||
if ((*value == *kAnyValue)) {
|
||||
|
|
|
@ -32,6 +32,7 @@
|
|||
#include "ir/tensor.h"
|
||||
#include "ir/param_value.h"
|
||||
#include "utils/base_ref_extends.h"
|
||||
#include "utils/ms_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
py::object BuiltinsToPyData(const Any &value);
|
||||
|
@ -404,6 +405,13 @@ AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py
|
|||
auto abstract_none = std::make_shared<abstract::AbstractNone>();
|
||||
return abstract_none;
|
||||
} else {
|
||||
// When sparse enabled, the undetermined might be raised and eliminated in opt passes
|
||||
auto context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
bool enable_sparse = context->enable_sparse();
|
||||
if (enable_sparse) {
|
||||
return std::make_shared<abstract::AbstractUndetermined>();
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "Python evaluator return invalid shape or type. " << (std::string)py::str(type_obj);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -101,6 +101,7 @@ list_type = typing.List
|
|||
tuple_type = typing.Tuple
|
||||
index_slices = typing.IndexedSlicesType()
|
||||
sparse_tensor = typing.SparseTensorType()
|
||||
undetermined = typing.UndeterminedType()
|
||||
|
||||
number_type = (int8,
|
||||
int16,
|
||||
|
|
|
@ -290,7 +290,19 @@ class IndexedSlices:
|
|||
"""
|
||||
|
||||
def __init__(self, indices, values, dense_shape):
|
||||
raise NotImplementedError
|
||||
"Init IndexedSlices"
|
||||
self.__indices = indices
|
||||
self.__values = values
|
||||
self.__dense_shape = dense_shape
|
||||
|
||||
def indices(self):
|
||||
return self.__indices
|
||||
|
||||
def values(self):
|
||||
return self.__values
|
||||
|
||||
def dense_shape(self):
|
||||
return self.__dense_shape
|
||||
|
||||
|
||||
class SparseTensor:
|
||||
|
@ -331,4 +343,16 @@ class SparseTensor:
|
|||
"""
|
||||
|
||||
def __init__(self, indices, values, dense_shape):
|
||||
raise NotImplementedError
|
||||
"Init SparseTensor"
|
||||
self.__indices = indices
|
||||
self.__values = values
|
||||
self.__dense_shape = dense_shape
|
||||
|
||||
def indices(self):
|
||||
return self.__indices
|
||||
|
||||
def values(self):
|
||||
return self.__values
|
||||
|
||||
def dense_shape(self):
|
||||
return self.__dense_shape
|
||||
|
|
|
@ -814,9 +814,13 @@ class AddN(PrimitiveWithInfer):
|
|||
validator.check_value_type("inputs", inputs, [tuple, list], cls_name)
|
||||
validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name)
|
||||
args = {}
|
||||
contains_undetermined = False
|
||||
for i, dtype in enumerate(inputs):
|
||||
args[f"inputs[{i}]"] = dtype
|
||||
validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), cls_name)
|
||||
if dtype == mstype.undetermined:
|
||||
contains_undetermined = True
|
||||
if not contains_undetermined:
|
||||
validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), cls_name)
|
||||
return inputs[0]
|
||||
|
||||
def infer_value(self, inputs):
|
||||
|
|
|
@ -398,7 +398,7 @@ def test_switch_layer():
|
|||
ret = F.switch_layer(index, self.layers)(x) * self.z3
|
||||
return ret
|
||||
|
||||
index = Tensor(0)
|
||||
index = Tensor(0, dtype=mstype.int32)
|
||||
net = SwitchLayerCell()
|
||||
net(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
|
||||
C.grad_by_list(net, ParameterTuple(net.trainable_params()))(index,
|
||||
|
@ -436,7 +436,7 @@ def test_index_to_switch_layer():
|
|||
ret = self.layers[index](x) * self.z3
|
||||
return ret
|
||||
|
||||
index = Tensor(0)
|
||||
index = Tensor(0, dtype=mstype.int32)
|
||||
net = SwitchLayerCell()
|
||||
net(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
|
||||
C.grad_by_list(net, ParameterTuple(net.trainable_params()))(index,
|
||||
|
|
|
@ -0,0 +1,65 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
@File : test_sparse_pynative.py
|
||||
@Author:
|
||||
@Date : 2020-08-04
|
||||
@Desc : test mindspore sparse pynative
|
||||
"""
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
from mindspore import context, Tensor, IndexedSlices, SparseTensor
|
||||
from mindspore.ops import composite as C
|
||||
|
||||
context.set_context(mode=context.PYNATIVE_MODE, enable_sparse=True)
|
||||
|
||||
|
||||
grad_all = C.GradOperation('get_all', get_all=True)
|
||||
class GradWrap(nn.Cell):
|
||||
def __init__(self, network):
|
||||
super(GradWrap, self).__init__()
|
||||
self.network = network
|
||||
def construct(self, *args):
|
||||
grad = grad_all(self.network)(*args)
|
||||
return grad
|
||||
|
||||
|
||||
def test_indexed_slices_attr():
|
||||
class IndexedSlicesGetAttr(nn.Cell):
|
||||
def __init__(self, dense_shape):
|
||||
super(IndexedSlicesGetAttr, self).__init__()
|
||||
self.dense_shape = dense_shape
|
||||
def construct(self, indices, values):
|
||||
x = IndexedSlices(indices, values, self.dense_shape)
|
||||
return x.values(), x.indices(), x.dense_shape()
|
||||
indices = Tensor([0])
|
||||
values = Tensor([[1, 2]], dtype=ms.float32)
|
||||
IndexedSlicesGetAttr((3, 2))(indices, values)
|
||||
GradWrap(IndexedSlicesGetAttr((3, 2)))(indices, values)
|
||||
|
||||
|
||||
def test_sparse_tensor_attr():
|
||||
class SparseTensorGetAttr(nn.Cell):
|
||||
def __init__(self):
|
||||
super(SparseTensorGetAttr, self).__init__()
|
||||
self.dense_shape = (3, 4)
|
||||
def construct(self, indices, values):
|
||||
x = SparseTensor(indices, values, self.dense_shape)
|
||||
return x.values(), x.indices(), x.dense_shape()
|
||||
|
||||
indices = Tensor([[0, 1], [1, 2]])
|
||||
values = Tensor([1, 2], dtype=ms.float32)
|
||||
SparseTensorGetAttr()(indices, values)
|
||||
GradWrap(SparseTensorGetAttr())(indices, values)
|
Loading…
Reference in New Issue