Support sparse user-defined bprop

This commit is contained in:
yanglf1121 2022-07-15 16:18:32 +08:00
parent f0142dce53
commit 027fac9b3c
12 changed files with 145 additions and 17 deletions

View File

@ -61,6 +61,7 @@ AnfNodePtr ConvertMakeTupleInputToPlantInputs(const FuncGraphPtr &graph, const C
common::AnfAlgo::CheckPrimitiveType(cnode_ptr, prim::kPrimResizeNearestNeighborGrad)) {
return nullptr;
}
bool is_bprop_cut = common::AnfAlgo::CheckPrimitiveType(cnode_ptr, prim::kPrimBpropCut);
std::vector<AnfNodePtr> plant_inputs;
std::vector<int64_t> dyn_input_sizes;
plant_inputs.push_back(common::AnfAlgo::GetCNodePrimitiveNode(cnode_ptr));
@ -68,7 +69,8 @@ AnfNodePtr ConvertMakeTupleInputToPlantInputs(const FuncGraphPtr &graph, const C
for (size_t i = 0; i < input_num; ++i) {
auto input_node = common::AnfAlgo::GetInputNode(cnode_ptr, i);
MS_EXCEPTION_IF_NULL(input_node);
if (common::AnfAlgo::IsTupleOutput(input_node)) {
bool skip = (is_bprop_cut && input_node->abstract()->isa<abstract::AbstractSparseTensor>());
if (common::AnfAlgo::IsTupleOutput(input_node) && !skip) {
(void)dyn_input_sizes.emplace_back(SplitTupleInputs(graph, input_node, &plant_inputs));
} else {
dyn_input_sizes.push_back(-1);

View File

@ -558,6 +558,40 @@ bool UseParamInitInServer(const FuncGraphPtr &kernel_graph, const AnfNodePtr &pa
[](const AnfNodePtr &node) { return AnfUtils::IsRealKernel(node); });
}
#endif
// Recursively flattens a (possibly nested) VectorRef into a flat list of values.
// Dense tensors and CSR/COO sparse tensors are appended to *msTensors; nested
// VectorRefs are descended into; any other element kind is a hard error.
void IterateFindTensor(std::vector<ValuePtr> *msTensors, const VectorRef &ref_list) {
  MS_EXCEPTION_IF_NULL(msTensors);
  for (size_t i = 0; i < ref_list.size(); ++i) {
    if (utils::isa<tensor::TensorPtr>(ref_list[i])) {
      auto tensor_ptr = utils::cast<std::shared_ptr<tensor::Tensor>>(ref_list[i]);
      MS_EXCEPTION_IF_NULL(tensor_ptr);
      (void)msTensors->emplace_back(tensor_ptr);
    } else if (utils::isa<VectorRef>(ref_list[i])) {
      auto ref_iter = utils::cast<VectorRef>(ref_list[i]);
      IterateFindTensor(msTensors, ref_iter);
    } else if (utils::isa<tensor::CSRTensorPtr>(ref_list[i])) {
      auto csr_tensor = utils::cast<tensor::CSRTensorPtr>(ref_list[i]);
      MS_EXCEPTION_IF_NULL(csr_tensor);
      (void)msTensors->emplace_back(csr_tensor);
    } else if (utils::isa<tensor::COOTensorPtr>(ref_list[i])) {
      // COO outputs can appear here for the same reason CSR ones do; handle both
      // sparse formats symmetrically.
      auto coo_tensor = utils::cast<tensor::COOTensorPtr>(ref_list[i]);
      MS_EXCEPTION_IF_NULL(coo_tensor);
      (void)msTensors->emplace_back(coo_tensor);
    } else {
      MS_LOG(EXCEPTION) << "The output is not a tensor/sparse tensor";
    }
  }
}
std::vector<ValuePtr> TransformVectorRefToMultiValue(const VectorRef &base_ref) {
std::vector<ValuePtr> msTensors;
if (utils::isa<VectorRef>(base_ref)) {
auto ref_list = utils::cast<VectorRef>(base_ref);
IterateFindTensor(&msTensors, ref_list);
} else if (utils::isa<tensor::Tensor>(base_ref)) {
auto tensor_ptr = utils::cast<std::shared_ptr<tensor::Tensor>>(base_ref);
MS_EXCEPTION_IF_NULL(tensor_ptr);
msTensors.emplace_back(tensor_ptr);
} else {
MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
}
return msTensors;
}
} // namespace
GraphId SessionBasic::graph_sum_ = 0;
@ -1566,14 +1600,16 @@ void SessionBasic::HandleOpOutputs(const AnfNodePtr &kernel, const VectorRef &op
MS_EXCEPTION_IF_NULL(op_output_map);
MS_EXCEPTION_IF_NULL(graph_output_info);
MS_EXCEPTION_IF_NULL(graph_output_info->graph_outputs);
auto output_tensors = TransformVectorRefToMultiTensor(op_outputs);
if (output_tensors.size() > op_outputs.size()) {
auto output_values = TransformVectorRefToMultiValue(op_outputs);
if (output_values.size() > op_outputs.size()) {
MS_LOG(EXCEPTION) << "Op output contains tuple, node = " << kernel->DebugString();
}
size_t out_index = 0;
for (const auto &output_tensor : output_tensors) {
for (const auto &output_value : output_values) {
auto kernel_with_index = make_pair(kernel, out_index++);
if (ref_count.find(kernel_with_index) != ref_count.end()) {
auto output_tensor = output_value->cast<tensor::TensorPtr>();
bool value_is_tensor = (output_tensor != nullptr);
if (ref_count.find(kernel_with_index) != ref_count.end() && value_is_tensor) {
(*op_output_map)[kernel_with_index] = output_tensor;
}
const auto &iter = graph_output_info->output_indexes.find(kernel_with_index);
@ -1597,8 +1633,10 @@ void SessionBasic::HandleOpOutputs(const AnfNodePtr &kernel, const VectorRef &op
cur_vector_ref = &utils::cast<VectorRef>(base_ref);
}
BaseRef &tensor_ref = (*const_cast<VectorRef *>(cur_vector_ref))[ref_indexes.at(n)];
tensor_ref = output_tensor;
graph_output_info->graph_output_tensors.emplace_back(output_tensor);
tensor_ref = output_value;
if (value_is_tensor) {
graph_output_info->graph_output_tensors.emplace_back(output_tensor);
}
}
}
}

View File

@ -762,8 +762,12 @@ void GetControlOpInput(const std::shared_ptr<GraphCompiler> &graph_compiler, con
MS_EXCEPTION_IF_NULL(real_input);
ValuePtr value = nullptr;
if (!real_input->isa<ValueNode>()) {
value = graph_compiler->GetSingleOpInputTensorByIndex(backend_cnode, op_output_map, parameter_index, graph_inputs,
input_tensor_info, back_index);
if (real_input->abstract() != nullptr && real_input->abstract()->isa<abstract::AbstractSparseTensor>()) {
value = TensorListToSparseTensor(real_input->abstract(), graph_inputs);
} else {
value = graph_compiler->GetSingleOpInputTensorByIndex(backend_cnode, op_output_map, parameter_index,
graph_inputs, input_tensor_info, back_index);
}
MS_EXCEPTION_IF_NULL(value);
++back_index;
} else {
@ -794,9 +798,9 @@ void GetControlOpInput(const std::shared_ptr<GraphCompiler> &graph_compiler, con
}
}
void ConvertPyObjectToTensor(const py::object &input_object, std::vector<tensor::TensorPtr> *tensors) {
void ConvertPyObjectToTensor(const py::object &input_object, std::vector<ValuePtr> *tensors) {
MS_EXCEPTION_IF_NULL(tensors);
tensor::TensorPtr tensor_ptr = nullptr;
ValuePtr tensor_ptr = nullptr;
if (py::isinstance<tensor::Tensor>(input_object)) {
tensor_ptr = py::cast<tensor::TensorPtr>(input_object);
} else if (py::isinstance<py::float_>(input_object)) {
@ -816,6 +820,10 @@ void ConvertPyObjectToTensor(const py::object &input_object, std::vector<tensor:
ConvertPyObjectToTensor(tuple_inputs[i], tensors);
}
return;
} else if (py::isinstance<tensor::CSRTensor>(input_object)) {
tensor_ptr = py::cast<tensor::CSRTensorPtr>(input_object);
} else if (py::isinstance<tensor::COOTensor>(input_object)) {
tensor_ptr = py::cast<tensor::COOTensorPtr>(input_object);
} else {
MS_EXCEPTION(TypeError) << "Unreasonable data type: " << input_object.get_type() << ".";
}
@ -860,10 +868,10 @@ void RunControlOperator(const std::shared_ptr<GraphCompiler> &graph_compiler, co
if (utils::isa<PyObjectRef>(out)) {
PyObjectRef py_ref = utils::cast<PyObjectRef>(out);
auto out_py_tuple = py_ref.object_;
std::vector<tensor::TensorPtr> output_tensors;
std::vector<ValuePtr> output_tensors;
ConvertPyObjectToTensor(out_py_tuple, &output_tensors);
(void)std::transform(output_tensors.begin(), output_tensors.end(), std::back_inserter(op_outputs->elements_),
[](tensor::TensorPtr &tensor) { return std::move(tensor); });
[](ValuePtr &tensor) { return std::move(tensor); });
}
}
}

View File

@ -82,6 +82,9 @@ COMMON_EXPORT ValuePtr ShallowCopyTensorValue(const ValuePtr &value);
COMMON_EXPORT size_t CountValueNum(const ValueTuplePtr &value_tuple);
COMMON_EXPORT bool IsAKGSparseOP(const AnfNodePtr &cnode);
COMMON_EXPORT tensor::MetaSparseTensorPtr TensorListToSparseTensor(const abstract::AbstractBasePtr &abs_sparse,
const tensor::TensorPtrList &tensor_list);
} // namespace mindspore
#endif // MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_CONVERT_UTILS_H_

View File

@ -74,6 +74,8 @@ const char PYTHON_MOD_GET_PARSE_METHOD[] = "get_parse_method_of_class";
const char PYTHON_MOD_GET_BPROP_METHOD[] = "get_bprop_method_of_class";
const char PYTHON_MOD_GET_OBJECT_DESCRIPTION[] = "get_object_description";
const char PYTHON_MOD_CONVERT_TO_MS_TENSOR[] = "convert_to_ms_tensor";
const char PYTHON_MOD_CONVERT_TO_MS_CSRTENSOR[] = "convert_to_ms_csrtensor";
const char PYTHON_MOD_CONVERT_TO_MS_COOTENSOR[] = "convert_to_ms_cootensor";
const char PYTHON_MOD_EVAL_PY_SCRIPT[] = "eval_script";
const char PYTHON_MOD_GET_SCRIPT_IDS[] = "get_script_ids";
const char PYTHON_MOD_PYTHON_ISINSTANCE[] = "python_isinstance";

View File

@ -63,6 +63,12 @@ void ConvertCTensorToPyTensor(const py::tuple &input_args, py::tuple *convert_ar
if (py::isinstance<tensor::Tensor>(input_args[i])) {
(*convert_args)[i] =
python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_MOD_CONVERT_TO_MS_TENSOR, input_args[i]);
} else if (py::isinstance<tensor::CSRTensor>(input_args[i])) {
(*convert_args)[i] = python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE,
parse::PYTHON_MOD_CONVERT_TO_MS_CSRTENSOR, input_args[i]);
} else if (py::isinstance<tensor::COOTensor>(input_args[i])) {
(*convert_args)[i] = python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE,
parse::PYTHON_MOD_CONVERT_TO_MS_COOTENSOR, input_args[i]);
} else if (py::isinstance<py::tuple>(input_args[i])) {
auto tuple_inp_arg = py::cast<py::tuple>(input_args[i]);
py::tuple convert_tuple_arg(tuple_inp_arg.size());

View File

@ -348,4 +348,52 @@ bool IsAKGSparseOP(const AnfNodePtr &cnode) {
prim::kPrimCSR2COO, prim::kPrimCOO2CSR, prim::kPrimCSRDiv, prim::kPrimCSRMM};
return IsOneOfPrimitiveCNode(cnode, prims);
}
namespace {
// Builds a ShapeVector from the scalar shape tensors stored at the tail of
// tensor_list, starting at `index`: each element from `index` onward must be a
// 0-dim int64 tensor holding one dimension of the sparse tensor's dense shape.
// Throws if `index` is out of range or any element is not a scalar.
ShapeVector ConvertTensorListToShapeVector(const tensor::TensorPtrList &tensor_list, size_t index) {
  ShapeVector shape;
  if (index >= tensor_list.size()) {
    // MS_LOG(EXCEPTION) throws, so no unreachable return is needed here.
    MS_LOG(EXCEPTION) << "Index " << index << " is out of range of " << tensor_list.size();
  }
  // Take the shared_ptr by const reference to avoid per-element refcount churn.
  auto converter = [](const tensor::TensorPtr &tensorptr) {
    MS_EXCEPTION_IF_NULL(tensorptr);
    if (tensorptr->DataDim() != 0) {
      MS_LOG(EXCEPTION) << "Element must be scalar!";
    }
    // Ensure device-side data is synced to host before reading it.
    tensorptr->data_sync(false);
    return *(static_cast<int64_t *>(tensorptr->data_c()));
  };
  (void)std::transform(tensor_list.begin() + index, tensor_list.end(), std::back_inserter(shape), converter);
  if (shape.empty()) {
    MS_LOG(ERROR) << "ShapeVector is empty!";
  }
  return shape;
}
// Assembles a CSRTensor from its flattened component list:
// [indptr, indices, values, shape_dim_0, ..., shape_dim_n].
tensor::CSRTensorPtr TensorListToCSRTensor(const tensor::TensorPtrList &tensor_list) {
  // Guard against a truncated component list before indexing into it.
  if (tensor_list.size() <= tensor::CSRTensor::kShapeIdx) {
    MS_LOG(EXCEPTION) << "CSRTensor expects more than " << tensor::CSRTensor::kShapeIdx
                      << " tensors, but got " << tensor_list.size();
  }
  tensor::TensorPtr indptr = utils::cast<tensor::TensorPtr>(tensor_list[tensor::CSRTensor::kIndptrIdx]);
  tensor::TensorPtr indices = utils::cast<tensor::TensorPtr>(tensor_list[tensor::CSRTensor::kIndicesIdx]);
  tensor::TensorPtr values = utils::cast<tensor::TensorPtr>(tensor_list[tensor::CSRTensor::kValuesIdx]);
  // All remaining tensors, from kShapeIdx on, encode the dense shape.
  ShapeVector shape = ConvertTensorListToShapeVector(tensor_list, tensor::CSRTensor::kShapeIdx);
  auto csr_tensor_ptr = std::make_shared<tensor::CSRTensor>(indptr, indices, values, shape);
  return csr_tensor_ptr;
}
// Assembles a COOTensor from its flattened component list:
// [indices, values, shape_dim_0, ..., shape_dim_n].
tensor::COOTensorPtr TensorListToCOOTensor(const tensor::TensorPtrList &tensor_list) {
  // Guard against a truncated component list before indexing into it.
  if (tensor_list.size() <= tensor::COOTensor::kShapeIdx) {
    MS_LOG(EXCEPTION) << "COOTensor expects more than " << tensor::COOTensor::kShapeIdx
                      << " tensors, but got " << tensor_list.size();
  }
  tensor::TensorPtr indices = utils::cast<tensor::TensorPtr>(tensor_list[tensor::COOTensor::kIndicesIdx]);
  tensor::TensorPtr values = utils::cast<tensor::TensorPtr>(tensor_list[tensor::COOTensor::kValuesIdx]);
  // All remaining tensors, from kShapeIdx on, encode the dense shape.
  ShapeVector shape = ConvertTensorListToShapeVector(tensor_list, tensor::COOTensor::kShapeIdx);
  auto coo_tensor_ptr = std::make_shared<tensor::COOTensor>(indices, values, shape);
  return coo_tensor_ptr;
}
} // namespace
// Rebuilds a sparse tensor from its flattened tensor components, using the
// abstract to decide which sparse format is expected. A COO abstract yields a
// COOTensor; any other sparse abstract is treated as CSR (currently the only
// other supported format).
tensor::MetaSparseTensorPtr TensorListToSparseTensor(const abstract::AbstractBasePtr &abs_sparse,
                                                     const tensor::TensorPtrList &tensor_list) {
  MS_EXCEPTION_IF_NULL(abs_sparse);
  if (abs_sparse->isa<abstract::AbstractCOOTensor>()) {
    return TensorListToCOOTensor(tensor_list);
  }
  return TensorListToCSRTensor(tensor_list);
}
} // namespace mindspore

View File

@ -1579,6 +1579,8 @@ AbstractBasePtr AbstractCOOTensor::Broaden() const {
return std::make_shared<abstract::AbstractCOOTensor>(ElementsBroaden());
}
AbstractBasePtr AbstractCOOTensor::PartialBroaden() const { return Broaden(); }
std::string AbstractCOOTensor::ToString() const {
std::ostringstream buffer;
buffer << type_name() << "("
@ -1630,6 +1632,8 @@ AbstractBasePtr AbstractCSRTensor::Broaden() const {
return std::make_shared<abstract::AbstractCSRTensor>(ElementsBroaden());
}
AbstractBasePtr AbstractCSRTensor::PartialBroaden() const { return Broaden(); }
std::string AbstractCSRTensor::ToString() const {
std::ostringstream buffer;
buffer << type_name() << "("

View File

@ -1461,6 +1461,7 @@ class MS_CORE_API AbstractCOOTensor : public AbstractSparseTensor {
TypePtr BuildType() const override;
AbstractBasePtr Clone() const override;
AbstractBasePtr Broaden() const override;
AbstractBasePtr PartialBroaden() const override;
std::string ToString() const override;
static constexpr size_t kIndicesIdx = 0;
@ -1489,6 +1490,7 @@ class MS_CORE_API AbstractCSRTensor : public AbstractSparseTensor {
TypePtr BuildType() const override;
AbstractBasePtr Clone() const override;
AbstractBasePtr Broaden() const override;
AbstractBasePtr PartialBroaden() const override;
std::string ToString() const override;
static constexpr size_t kIndptrIdx = 0;

View File

@ -24,7 +24,8 @@ from .parser import (Parser, create_instance, is_supported_create_instance_type,
get_operation_symbol, get_operation_namespace_symbol, get_parse_method_of_class, get_scope_name,
eval_script, get_script_ids, expand_expr_statement, is_class_member, parse_cb, resolve_symbol,
convert_to_ms_tensor, get_object_description, get_class_attr_namespace_symbol, get_ms_class_name,
is_class_type, check_obj_bool, python_isinstance, ms_isinstance)
is_class_type, check_obj_bool, python_isinstance, ms_isinstance, convert_to_ms_csrtensor,
convert_to_ms_cootensor)
__all__ = ['Parser', 'create_instance', 'is_supported_create_instance_type', 'generate_scope',
'get_bprop_method_of_class', 'get_class_instance_type', 'get_class_member_namespace_symbol',
@ -33,4 +34,5 @@ __all__ = ['Parser', 'create_instance', 'is_supported_create_instance_type', 'ge
'get_operation_symbol', 'get_operation_namespace_symbol', 'get_parse_method_of_class', 'get_scope_name',
'eval_script', 'get_script_ids', 'expand_expr_statement', 'is_class_member', 'parse_cb', 'resolve_symbol',
'convert_to_ms_tensor', 'get_object_description', 'get_class_attr_namespace_symbol', 'get_ms_class_name',
'is_class_type', 'check_obj_bool', 'python_isinstance', 'ms_isinstance']
'is_class_type', 'check_obj_bool', 'python_isinstance', 'ms_isinstance', 'convert_to_ms_csrtensor',
'convert_to_ms_cootensor']

View File

@ -31,7 +31,7 @@ import numpy
import asttokens
import astunparse
from mindspore import Tensor
from mindspore import Tensor, CSRTensor, COOTensor
from mindspore import log as logger
from mindspore import nn
from mindspore import ops
@ -497,6 +497,16 @@ def convert_to_ms_tensor(data):
return Tensor(data)
def convert_to_ms_csrtensor(data):
    """
    Convert C++ csrtensor to mindspore csrtensor.

    Args:
        data: The backend CSRTensor object to wrap.

    Returns:
        CSRTensor, the Python-level wrapper constructed from `data`.
    """
    return CSRTensor(csr_tensor=data)
def convert_to_ms_cootensor(data):
    """
    Convert C++ cootensor to mindspore cootensor.

    Args:
        data: The backend COOTensor object to wrap.

    Returns:
        COOTensor, the Python-level wrapper constructed from `data`.
    """
    return COOTensor(coo_tensor=data)
def get_object_description(obj, fname, fline):
"""Return method or funcition description for error report, include location, class name, etc."""
if isinstance(obj, types.MethodType):

View File

@ -243,7 +243,8 @@ def test_cg_grad(flatten, tensor_type, dtype, tol, a, b, grad_a, grad_b):
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('tensor_type, dtype, tol', [('Tensor', onp.float32, 1e-5), ('Tensor', onp.float64, 1e-8)])
@pytest.mark.parametrize('tensor_type, dtype, tol', [('Tensor', onp.float32, 1e-5), ('Tensor', onp.float64, 1e-8),
('CSRTensor', onp.float32, 1e-5)])
@pytest.mark.parametrize('a, b, grad_a, grad_b', [
([[1.96822833, 0.82204467, 1.03749232, 0.88915326, 0.44986806, 1.11167143],
[0.82204467, 2.25216591, 1.40235719, 0.70838919, 0.81377919, 1.06000368],
@ -278,6 +279,8 @@ def test_cg_grad_pynative(tensor_type, dtype, tol, a, b, grad_a, grad_b):
Description: test cases for grad implementation of cg in pynative mode
Expectation: the result match expectation
"""
if tensor_type == "CSRTensor" and get_platform() != "linux":
return
context.set_context(mode=context.PYNATIVE_MODE)
a = to_tensor((a, tensor_type), dtype)