diff --git a/mindspore/ccsrc/backend/common/pass/convert_tuple_input_to_dynamic_input.cc b/mindspore/ccsrc/backend/common/pass/convert_tuple_input_to_dynamic_input.cc index 0154438e063..62b683041ee 100644 --- a/mindspore/ccsrc/backend/common/pass/convert_tuple_input_to_dynamic_input.cc +++ b/mindspore/ccsrc/backend/common/pass/convert_tuple_input_to_dynamic_input.cc @@ -61,6 +61,7 @@ AnfNodePtr ConvertMakeTupleInputToPlantInputs(const FuncGraphPtr &graph, const C common::AnfAlgo::CheckPrimitiveType(cnode_ptr, prim::kPrimResizeNearestNeighborGrad)) { return nullptr; } + bool is_bprop_cut = common::AnfAlgo::CheckPrimitiveType(cnode_ptr, prim::kPrimBpropCut); std::vector plant_inputs; std::vector dyn_input_sizes; plant_inputs.push_back(common::AnfAlgo::GetCNodePrimitiveNode(cnode_ptr)); @@ -68,7 +69,8 @@ AnfNodePtr ConvertMakeTupleInputToPlantInputs(const FuncGraphPtr &graph, const C for (size_t i = 0; i < input_num; ++i) { auto input_node = common::AnfAlgo::GetInputNode(cnode_ptr, i); MS_EXCEPTION_IF_NULL(input_node); - if (common::AnfAlgo::IsTupleOutput(input_node)) { + bool skip = (is_bprop_cut && input_node->abstract()->isa()); + if (common::AnfAlgo::IsTupleOutput(input_node) && !skip) { (void)dyn_input_sizes.emplace_back(SplitTupleInputs(graph, input_node, &plant_inputs)); } else { dyn_input_sizes.push_back(-1); diff --git a/mindspore/ccsrc/backend/common/session/session_basic.cc b/mindspore/ccsrc/backend/common/session/session_basic.cc index 4ffad828bb5..79bab00dfd9 100644 --- a/mindspore/ccsrc/backend/common/session/session_basic.cc +++ b/mindspore/ccsrc/backend/common/session/session_basic.cc @@ -558,6 +558,40 @@ bool UseParamInitInServer(const FuncGraphPtr &kernel_graph, const AnfNodePtr &pa [](const AnfNodePtr &node) { return AnfUtils::IsRealKernel(node); }); } #endif + +void IterateFindTensor(std::vector *msTensors, const VectorRef &ref_list) { + for (size_t i = 0; i < ref_list.size(); ++i) { + if (utils::isa(ref_list[i])) { + auto tensor_ptr = utils::cast>(ref_list[i]); + MS_EXCEPTION_IF_NULL(tensor_ptr); + msTensors->emplace_back(tensor_ptr); + } else if (utils::isa(ref_list[i])) { + auto ref_iter = utils::cast(ref_list[i]); + IterateFindTensor(msTensors, ref_iter); + } else if (utils::isa(ref_list[i])) { + auto csr_tensor = utils::cast(ref_list[i]); + MS_EXCEPTION_IF_NULL(csr_tensor); + msTensors->emplace_back(csr_tensor); + } else { + MS_LOG(EXCEPTION) << "The output is not a tensor/sparse tensor"; + } + } +} + +std::vector TransformVectorRefToMultiValue(const VectorRef &base_ref) { + std::vector msTensors; + if (utils::isa(base_ref)) { + auto ref_list = utils::cast(base_ref); + IterateFindTensor(&msTensors, ref_list); + } else if (utils::isa(base_ref)) { + auto tensor_ptr = utils::cast>(base_ref); + MS_EXCEPTION_IF_NULL(tensor_ptr); + msTensors.emplace_back(tensor_ptr); + } else { + MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!"; + } + return msTensors; +} } // namespace GraphId SessionBasic::graph_sum_ = 0; @@ -1566,14 +1600,16 @@ void SessionBasic::HandleOpOutputs(const AnfNodePtr &kernel, const VectorRef &op MS_EXCEPTION_IF_NULL(op_output_map); MS_EXCEPTION_IF_NULL(graph_output_info); MS_EXCEPTION_IF_NULL(graph_output_info->graph_outputs); - auto output_tensors = TransformVectorRefToMultiTensor(op_outputs); - if (output_tensors.size() > op_outputs.size()) { + auto output_values = TransformVectorRefToMultiValue(op_outputs); + if (output_values.size() > op_outputs.size()) { MS_LOG(EXCEPTION) << "Op output contains tuple, node = " << kernel->DebugString(); } size_t out_index = 0; - for (const auto &output_tensor : output_tensors) { + for (const auto &output_value : output_values) { auto kernel_with_index = make_pair(kernel, out_index++); - if (ref_count.find(kernel_with_index) != ref_count.end()) { + auto output_tensor = output_value->cast(); + bool value_is_tensor = (output_tensor != nullptr); + if (ref_count.find(kernel_with_index) != ref_count.end() && value_is_tensor) { (*op_output_map)[kernel_with_index] = output_tensor; } const auto &iter = graph_output_info->output_indexes.find(kernel_with_index); @@ -1597,8 +1633,10 @@ void SessionBasic::HandleOpOutputs(const AnfNodePtr &kernel, const VectorRef &op cur_vector_ref = &utils::cast(base_ref); } BaseRef &tensor_ref = (*const_cast(cur_vector_ref))[ref_indexes.at(n)]; - tensor_ref = output_tensor; - graph_output_info->graph_output_tensors.emplace_back(output_tensor); + tensor_ref = output_value; + if (value_is_tensor) { + graph_output_info->graph_output_tensors.emplace_back(output_tensor); + } } } } diff --git a/mindspore/ccsrc/backend/graph_compiler/backend.cc b/mindspore/ccsrc/backend/graph_compiler/backend.cc index bb89ec8e03d..1e35cbfebd3 100644 --- a/mindspore/ccsrc/backend/graph_compiler/backend.cc +++ b/mindspore/ccsrc/backend/graph_compiler/backend.cc @@ -762,8 +762,12 @@ void GetControlOpInput(const std::shared_ptr &graph_compiler, con MS_EXCEPTION_IF_NULL(real_input); ValuePtr value = nullptr; if (!real_input->isa()) { - value = graph_compiler->GetSingleOpInputTensorByIndex(backend_cnode, op_output_map, parameter_index, graph_inputs, - input_tensor_info, back_index); + if (real_input->abstract() != nullptr && real_input->abstract()->isa()) { + value = TensorListToSparseTensor(real_input->abstract(), graph_inputs); + } else { + value = graph_compiler->GetSingleOpInputTensorByIndex(backend_cnode, op_output_map, parameter_index, + graph_inputs, input_tensor_info, back_index); + } MS_EXCEPTION_IF_NULL(value); ++back_index; } else { @@ -794,9 +798,9 @@ void GetControlOpInput(const std::shared_ptr &graph_compiler, con } } -void ConvertPyObjectToTensor(const py::object &input_object, std::vector *tensors) { +void ConvertPyObjectToTensor(const py::object &input_object, std::vector *tensors) { MS_EXCEPTION_IF_NULL(tensors); - tensor::TensorPtr tensor_ptr = nullptr; + ValuePtr tensor_ptr = nullptr; if (py::isinstance(input_object)) { tensor_ptr = py::cast(input_object); } else if (py::isinstance(input_object)) { @@ -816,6 +820,10 @@ void ConvertPyObjectToTensor(const py::object &input_object, std::vector(input_object)) { + tensor_ptr = py::cast(input_object); + } else if (py::isinstance(input_object)) { + tensor_ptr = py::cast(input_object); } else { MS_EXCEPTION(TypeError) << "Unreasonable data type: " << input_object.get_type() << "."; } @@ -860,10 +868,10 @@ void RunControlOperator(const std::shared_ptr &graph_compiler, co if (utils::isa(out)) { PyObjectRef py_ref = utils::cast(out); auto out_py_tuple = py_ref.object_; - std::vector output_tensors; + std::vector output_tensors; ConvertPyObjectToTensor(out_py_tuple, &output_tensors); (void)std::transform(output_tensors.begin(), output_tensors.end(), std::back_inserter(op_outputs->elements_), - [](tensor::TensorPtr &tensor) { return std::move(tensor); }); + [](ValuePtr &tensor) { return std::move(tensor); }); } } } diff --git a/mindspore/ccsrc/include/common/utils/convert_utils.h b/mindspore/ccsrc/include/common/utils/convert_utils.h index f9fb8381606..bea03f82642 100644 --- a/mindspore/ccsrc/include/common/utils/convert_utils.h +++ b/mindspore/ccsrc/include/common/utils/convert_utils.h @@ -82,6 +82,9 @@ COMMON_EXPORT ValuePtr ShallowCopyTensorValue(const ValuePtr &value); COMMON_EXPORT size_t CountValueNum(const ValueTuplePtr &value_tuple); COMMON_EXPORT bool IsAKGSparseOP(const AnfNodePtr &cnode); + +COMMON_EXPORT tensor::MetaSparseTensorPtr TensorListToSparseTensor(const abstract::AbstractBasePtr &abs_sparse, + const tensor::TensorPtrList &tensor_list); } // namespace mindspore #endif // MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_CONVERT_UTILS_H_ diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse_base.h b/mindspore/ccsrc/pipeline/jit/parse/parse_base.h index e84da9abe58..18eb3f61c4a 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/parse_base.h +++ b/mindspore/ccsrc/pipeline/jit/parse/parse_base.h @@ -74,6 +74,8 @@ const char PYTHON_MOD_GET_PARSE_METHOD[] = "get_parse_method_of_class"; const char PYTHON_MOD_GET_BPROP_METHOD[] = "get_bprop_method_of_class"; const char PYTHON_MOD_GET_OBJECT_DESCRIPTION[] = "get_object_description"; const char PYTHON_MOD_CONVERT_TO_MS_TENSOR[] = "convert_to_ms_tensor"; +const char PYTHON_MOD_CONVERT_TO_MS_CSRTENSOR[] = "convert_to_ms_csrtensor"; +const char PYTHON_MOD_CONVERT_TO_MS_COOTENSOR[] = "convert_to_ms_cootensor"; const char PYTHON_MOD_EVAL_PY_SCRIPT[] = "eval_script"; const char PYTHON_MOD_GET_SCRIPT_IDS[] = "get_script_ids"; const char PYTHON_MOD_PYTHON_ISINSTANCE[] = "python_isinstance"; diff --git a/mindspore/ccsrc/pybind_api/ir/primitive_py.cc b/mindspore/ccsrc/pybind_api/ir/primitive_py.cc index c2e88f7eb6a..e9b9ba642ac 100644 --- a/mindspore/ccsrc/pybind_api/ir/primitive_py.cc +++ b/mindspore/ccsrc/pybind_api/ir/primitive_py.cc @@ -63,6 +63,12 @@ void ConvertCTensorToPyTensor(const py::tuple &input_args, py::tuple *convert_ar if (py::isinstance(input_args[i])) { (*convert_args)[i] = python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_MOD_CONVERT_TO_MS_TENSOR, input_args[i]); + } else if (py::isinstance(input_args[i])) { + (*convert_args)[i] = python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, + parse::PYTHON_MOD_CONVERT_TO_MS_CSRTENSOR, input_args[i]); + } else if (py::isinstance(input_args[i])) { + (*convert_args)[i] = python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, + parse::PYTHON_MOD_CONVERT_TO_MS_COOTENSOR, input_args[i]); } else if (py::isinstance(input_args[i])) { auto tuple_inp_arg = py::cast(input_args[i]); py::tuple convert_tuple_arg(tuple_inp_arg.size()); diff --git a/mindspore/ccsrc/utils/convert_utils.cc b/mindspore/ccsrc/utils/convert_utils.cc index b4e9b874379..232b8fd19e3 100644 --- a/mindspore/ccsrc/utils/convert_utils.cc +++ b/mindspore/ccsrc/utils/convert_utils.cc @@ -348,4 +348,52 @@ bool IsAKGSparseOP(const AnfNodePtr &cnode) { prim::kPrimCSR2COO, prim::kPrimCOO2CSR, prim::kPrimCSRDiv, prim::kPrimCSRMM}; return IsOneOfPrimitiveCNode(cnode, prims); } + +namespace { +ShapeVector ConvertTensorListToShapeVector(const tensor::TensorPtrList &tensor_list, size_t index) { + ShapeVector shape; + if (index >= tensor_list.size()) { + MS_LOG(EXCEPTION) << "Index " << index << " is out of range of " << tensor_list.size(); + return shape; + } + + auto converter = [](tensor::TensorPtr tensorptr) { + MS_EXCEPTION_IF_NULL(tensorptr); + if (tensorptr->DataDim() != 0) { + MS_LOG(EXCEPTION) << "Element must be scalar!"; + } + tensorptr->data_sync(false); + return *(static_cast(tensorptr->data_c())); + }; + std::transform(tensor_list.begin() + index, tensor_list.end(), std::back_inserter(shape), converter); + if (shape.empty()) { + MS_LOG(ERROR) << "ShapeVector is empty!"; + } + return shape; +} +tensor::CSRTensorPtr TensorListToCSRTensor(const tensor::TensorPtrList &tensor_list) { + tensor::TensorPtr indptr = utils::cast(tensor_list[tensor::CSRTensor::kIndptrIdx]); + tensor::TensorPtr indices = utils::cast(tensor_list[tensor::CSRTensor::kIndicesIdx]); + tensor::TensorPtr values = utils::cast(tensor_list[tensor::CSRTensor::kValuesIdx]); + ShapeVector shape = ConvertTensorListToShapeVector(tensor_list, tensor::CSRTensor::kShapeIdx); + auto csr_tensor_ptr = std::make_shared(indptr, indices, values, shape); + return csr_tensor_ptr; +} + +tensor::COOTensorPtr TensorListToCOOTensor(const tensor::TensorPtrList &tensor_list) { + tensor::TensorPtr indices = utils::cast(tensor_list[tensor::COOTensor::kIndicesIdx]); + tensor::TensorPtr values = utils::cast(tensor_list[tensor::COOTensor::kValuesIdx]); + ShapeVector shape = ConvertTensorListToShapeVector(tensor_list, tensor::COOTensor::kShapeIdx); + auto coo_tensor_ptr = std::make_shared(indices, values, shape); + return coo_tensor_ptr; +} +} // namespace + +tensor::MetaSparseTensorPtr TensorListToSparseTensor(const abstract::AbstractBasePtr &abs_sparse, + const tensor::TensorPtrList &tensor_list) { + if (abs_sparse->isa()) { + return TensorListToCOOTensor(tensor_list); + } + return TensorListToCSRTensor(tensor_list); +} } // namespace mindspore diff --git a/mindspore/core/abstract/abstract_value.cc b/mindspore/core/abstract/abstract_value.cc index d205c146bfe..478fc405970 100644 --- a/mindspore/core/abstract/abstract_value.cc +++ b/mindspore/core/abstract/abstract_value.cc @@ -1579,6 +1579,8 @@ AbstractBasePtr AbstractCOOTensor::Broaden() const { return std::make_shared(ElementsBroaden()); } +AbstractBasePtr AbstractCOOTensor::PartialBroaden() const { return Broaden(); } + std::string AbstractCOOTensor::ToString() const { std::ostringstream buffer; buffer << type_name() << "(" @@ -1630,6 +1632,8 @@ AbstractBasePtr AbstractCSRTensor::Broaden() const { return std::make_shared(ElementsBroaden()); } +AbstractBasePtr AbstractCSRTensor::PartialBroaden() const { return Broaden(); } + std::string AbstractCSRTensor::ToString() const { std::ostringstream buffer; buffer << type_name() << "(" diff --git a/mindspore/core/abstract/abstract_value.h b/mindspore/core/abstract/abstract_value.h index 3e4d388dbb5..d91c9012fcd 100644 --- a/mindspore/core/abstract/abstract_value.h +++ b/mindspore/core/abstract/abstract_value.h @@ -1461,6 +1461,7 @@ class MS_CORE_API AbstractCOOTensor : public AbstractSparseTensor { TypePtr BuildType() const override; AbstractBasePtr Clone() const override; AbstractBasePtr Broaden() const override; + AbstractBasePtr PartialBroaden() const override; std::string ToString() const override; static constexpr size_t kIndicesIdx = 0; @@ -1489,6 +1490,7 @@ class MS_CORE_API AbstractCSRTensor : public AbstractSparseTensor { TypePtr BuildType() const override; AbstractBasePtr Clone() const override; AbstractBasePtr Broaden() const override; + AbstractBasePtr PartialBroaden() const override; std::string ToString() const override; static constexpr size_t kIndptrIdx = 0; diff --git a/mindspore/python/mindspore/_extends/parse/__init__.py b/mindspore/python/mindspore/_extends/parse/__init__.py index e51cde38935..4f6a9431df2 100644 --- a/mindspore/python/mindspore/_extends/parse/__init__.py +++ b/mindspore/python/mindspore/_extends/parse/__init__.py @@ -24,7 +24,8 @@ from .parser import (Parser, create_instance, is_supported_create_instance_type, get_operation_symbol, get_operation_namespace_symbol, get_parse_method_of_class, get_scope_name, eval_script, get_script_ids, expand_expr_statement, is_class_member, parse_cb, resolve_symbol, convert_to_ms_tensor, get_object_description, get_class_attr_namespace_symbol, get_ms_class_name, - is_class_type, check_obj_bool, python_isinstance, ms_isinstance) + is_class_type, check_obj_bool, python_isinstance, ms_isinstance, convert_to_ms_csrtensor, + convert_to_ms_cootensor) __all__ = ['Parser', 'create_instance', 'is_supported_create_instance_type', 'generate_scope', 'get_bprop_method_of_class', 'get_class_instance_type', 'get_class_member_namespace_symbol', @@ -33,4 +34,5 @@ __all__ = ['Parser', 'create_instance', 'is_supported_create_instance_type', 'ge 'get_operation_symbol', 'get_operation_namespace_symbol', 'get_parse_method_of_class', 'get_scope_name', 'eval_script', 'get_script_ids', 'expand_expr_statement', 'is_class_member', 'parse_cb', 'resolve_symbol', 'convert_to_ms_tensor', 'get_object_description', 'get_class_attr_namespace_symbol', 'get_ms_class_name', - 'is_class_type', 'check_obj_bool', 'python_isinstance', 'ms_isinstance'] + 'is_class_type', 'check_obj_bool', 'python_isinstance', 'ms_isinstance', 'convert_to_ms_csrtensor', + 'convert_to_ms_cootensor'] diff --git a/mindspore/python/mindspore/_extends/parse/parser.py b/mindspore/python/mindspore/_extends/parse/parser.py index ad0ab413c9c..11813702659 100644 --- a/mindspore/python/mindspore/_extends/parse/parser.py +++ b/mindspore/python/mindspore/_extends/parse/parser.py @@ -31,7 +31,7 @@ import numpy import asttokens import astunparse -from mindspore import Tensor +from mindspore import Tensor, CSRTensor, COOTensor from mindspore import log as logger from mindspore import nn from mindspore import ops @@ -497,6 +497,16 @@ def convert_to_ms_tensor(data): return Tensor(data) +def convert_to_ms_csrtensor(data): + """Convert C++ csrtensor to mindspore csrtensor.""" + return CSRTensor(csr_tensor=data) + + +def convert_to_ms_cootensor(data): + """Convert C++ cootensor to mindspore cootensor.""" + return COOTensor(coo_tensor=data) + + def get_object_description(obj, fname, fline): """Return method or funcition description for error report, include location, class name, etc.""" if isinstance(obj, types.MethodType): diff --git a/tests/st/scipy_st/sparse/test_linalg.py b/tests/st/scipy_st/sparse/test_linalg.py index 7736b541c20..2433df84825 100644 --- a/tests/st/scipy_st/sparse/test_linalg.py +++ b/tests/st/scipy_st/sparse/test_linalg.py @@ -243,7 +243,8 @@ def test_cg_grad(flatten, tensor_type, dtype, tol, a, b, grad_a, grad_b): @pytest.mark.platform_x86_cpu @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard -@pytest.mark.parametrize('tensor_type, dtype, tol', [('Tensor', onp.float32, 1e-5), ('Tensor', onp.float64, 1e-8)]) +@pytest.mark.parametrize('tensor_type, dtype, tol', [('Tensor', onp.float32, 1e-5), ('Tensor', onp.float64, 1e-8), + ('CSRTensor', onp.float32, 1e-5)]) @pytest.mark.parametrize('a, b, grad_a, grad_b', [ ([[1.96822833, 0.82204467, 1.03749232, 0.88915326, 0.44986806, 1.11167143], [0.82204467, 2.25216591, 1.40235719, 0.70838919, 0.81377919, 1.06000368], @@ -278,6 +279,8 @@ def test_cg_grad_pynative(tensor_type, dtype, tol, a, b, grad_a, grad_b): Description: test cases for grad implementation of cg in pynative mode Expectation: the result match expectation """ + if tensor_type == "CSRTensor" and get_platform() != "linux": + return context.set_context(mode=context.PYNATIVE_MODE) a = to_tensor((a, tensor_type), dtype)