forked from mindspore-Ecosystem/mindspore
support sparse userdefined bprop
This commit is contained in:
parent
f0142dce53
commit
027fac9b3c
|
@ -61,6 +61,7 @@ AnfNodePtr ConvertMakeTupleInputToPlantInputs(const FuncGraphPtr &graph, const C
|
||||||
common::AnfAlgo::CheckPrimitiveType(cnode_ptr, prim::kPrimResizeNearestNeighborGrad)) {
|
common::AnfAlgo::CheckPrimitiveType(cnode_ptr, prim::kPrimResizeNearestNeighborGrad)) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
bool is_bprop_cut = common::AnfAlgo::CheckPrimitiveType(cnode_ptr, prim::kPrimBpropCut);
|
||||||
std::vector<AnfNodePtr> plant_inputs;
|
std::vector<AnfNodePtr> plant_inputs;
|
||||||
std::vector<int64_t> dyn_input_sizes;
|
std::vector<int64_t> dyn_input_sizes;
|
||||||
plant_inputs.push_back(common::AnfAlgo::GetCNodePrimitiveNode(cnode_ptr));
|
plant_inputs.push_back(common::AnfAlgo::GetCNodePrimitiveNode(cnode_ptr));
|
||||||
|
@ -68,7 +69,8 @@ AnfNodePtr ConvertMakeTupleInputToPlantInputs(const FuncGraphPtr &graph, const C
|
||||||
for (size_t i = 0; i < input_num; ++i) {
|
for (size_t i = 0; i < input_num; ++i) {
|
||||||
auto input_node = common::AnfAlgo::GetInputNode(cnode_ptr, i);
|
auto input_node = common::AnfAlgo::GetInputNode(cnode_ptr, i);
|
||||||
MS_EXCEPTION_IF_NULL(input_node);
|
MS_EXCEPTION_IF_NULL(input_node);
|
||||||
if (common::AnfAlgo::IsTupleOutput(input_node)) {
|
bool skip = (is_bprop_cut && input_node->abstract()->isa<abstract::AbstractSparseTensor>());
|
||||||
|
if (common::AnfAlgo::IsTupleOutput(input_node) && !skip) {
|
||||||
(void)dyn_input_sizes.emplace_back(SplitTupleInputs(graph, input_node, &plant_inputs));
|
(void)dyn_input_sizes.emplace_back(SplitTupleInputs(graph, input_node, &plant_inputs));
|
||||||
} else {
|
} else {
|
||||||
dyn_input_sizes.push_back(-1);
|
dyn_input_sizes.push_back(-1);
|
||||||
|
|
|
@ -558,6 +558,40 @@ bool UseParamInitInServer(const FuncGraphPtr &kernel_graph, const AnfNodePtr &pa
|
||||||
[](const AnfNodePtr &node) { return AnfUtils::IsRealKernel(node); });
|
[](const AnfNodePtr &node) { return AnfUtils::IsRealKernel(node); });
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
void IterateFindTensor(std::vector<ValuePtr> *msTensors, const VectorRef &ref_list) {
|
||||||
|
for (size_t i = 0; i < ref_list.size(); ++i) {
|
||||||
|
if (utils::isa<tensor::TensorPtr>(ref_list[i])) {
|
||||||
|
auto tensor_ptr = utils::cast<std::shared_ptr<tensor::Tensor>>(ref_list[i]);
|
||||||
|
MS_EXCEPTION_IF_NULL(tensor_ptr);
|
||||||
|
msTensors->emplace_back(tensor_ptr);
|
||||||
|
} else if (utils::isa<VectorRef>(ref_list[i])) {
|
||||||
|
auto ref_iter = utils::cast<VectorRef>(ref_list[i]);
|
||||||
|
IterateFindTensor(msTensors, ref_iter);
|
||||||
|
} else if (utils::isa<tensor::CSRTensorPtr>(ref_list[i])) {
|
||||||
|
auto csr_tensor = utils::cast<tensor::CSRTensorPtr>(ref_list[i]);
|
||||||
|
MS_EXCEPTION_IF_NULL(csr_tensor);
|
||||||
|
msTensors->emplace_back(csr_tensor);
|
||||||
|
} else {
|
||||||
|
MS_LOG(EXCEPTION) << "The output is not a tensor/sparse tensor";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<ValuePtr> TransformVectorRefToMultiValue(const VectorRef &base_ref) {
|
||||||
|
std::vector<ValuePtr> msTensors;
|
||||||
|
if (utils::isa<VectorRef>(base_ref)) {
|
||||||
|
auto ref_list = utils::cast<VectorRef>(base_ref);
|
||||||
|
IterateFindTensor(&msTensors, ref_list);
|
||||||
|
} else if (utils::isa<tensor::Tensor>(base_ref)) {
|
||||||
|
auto tensor_ptr = utils::cast<std::shared_ptr<tensor::Tensor>>(base_ref);
|
||||||
|
MS_EXCEPTION_IF_NULL(tensor_ptr);
|
||||||
|
msTensors.emplace_back(tensor_ptr);
|
||||||
|
} else {
|
||||||
|
MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
|
||||||
|
}
|
||||||
|
return msTensors;
|
||||||
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
GraphId SessionBasic::graph_sum_ = 0;
|
GraphId SessionBasic::graph_sum_ = 0;
|
||||||
|
@ -1566,14 +1600,16 @@ void SessionBasic::HandleOpOutputs(const AnfNodePtr &kernel, const VectorRef &op
|
||||||
MS_EXCEPTION_IF_NULL(op_output_map);
|
MS_EXCEPTION_IF_NULL(op_output_map);
|
||||||
MS_EXCEPTION_IF_NULL(graph_output_info);
|
MS_EXCEPTION_IF_NULL(graph_output_info);
|
||||||
MS_EXCEPTION_IF_NULL(graph_output_info->graph_outputs);
|
MS_EXCEPTION_IF_NULL(graph_output_info->graph_outputs);
|
||||||
auto output_tensors = TransformVectorRefToMultiTensor(op_outputs);
|
auto output_values = TransformVectorRefToMultiValue(op_outputs);
|
||||||
if (output_tensors.size() > op_outputs.size()) {
|
if (output_values.size() > op_outputs.size()) {
|
||||||
MS_LOG(EXCEPTION) << "Op output contains tuple, node = " << kernel->DebugString();
|
MS_LOG(EXCEPTION) << "Op output contains tuple, node = " << kernel->DebugString();
|
||||||
}
|
}
|
||||||
size_t out_index = 0;
|
size_t out_index = 0;
|
||||||
for (const auto &output_tensor : output_tensors) {
|
for (const auto &output_value : output_values) {
|
||||||
auto kernel_with_index = make_pair(kernel, out_index++);
|
auto kernel_with_index = make_pair(kernel, out_index++);
|
||||||
if (ref_count.find(kernel_with_index) != ref_count.end()) {
|
auto output_tensor = output_value->cast<tensor::TensorPtr>();
|
||||||
|
bool value_is_tensor = (output_tensor != nullptr);
|
||||||
|
if (ref_count.find(kernel_with_index) != ref_count.end() && value_is_tensor) {
|
||||||
(*op_output_map)[kernel_with_index] = output_tensor;
|
(*op_output_map)[kernel_with_index] = output_tensor;
|
||||||
}
|
}
|
||||||
const auto &iter = graph_output_info->output_indexes.find(kernel_with_index);
|
const auto &iter = graph_output_info->output_indexes.find(kernel_with_index);
|
||||||
|
@ -1597,8 +1633,10 @@ void SessionBasic::HandleOpOutputs(const AnfNodePtr &kernel, const VectorRef &op
|
||||||
cur_vector_ref = &utils::cast<VectorRef>(base_ref);
|
cur_vector_ref = &utils::cast<VectorRef>(base_ref);
|
||||||
}
|
}
|
||||||
BaseRef &tensor_ref = (*const_cast<VectorRef *>(cur_vector_ref))[ref_indexes.at(n)];
|
BaseRef &tensor_ref = (*const_cast<VectorRef *>(cur_vector_ref))[ref_indexes.at(n)];
|
||||||
tensor_ref = output_tensor;
|
tensor_ref = output_value;
|
||||||
graph_output_info->graph_output_tensors.emplace_back(output_tensor);
|
if (value_is_tensor) {
|
||||||
|
graph_output_info->graph_output_tensors.emplace_back(output_tensor);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -762,8 +762,12 @@ void GetControlOpInput(const std::shared_ptr<GraphCompiler> &graph_compiler, con
|
||||||
MS_EXCEPTION_IF_NULL(real_input);
|
MS_EXCEPTION_IF_NULL(real_input);
|
||||||
ValuePtr value = nullptr;
|
ValuePtr value = nullptr;
|
||||||
if (!real_input->isa<ValueNode>()) {
|
if (!real_input->isa<ValueNode>()) {
|
||||||
value = graph_compiler->GetSingleOpInputTensorByIndex(backend_cnode, op_output_map, parameter_index, graph_inputs,
|
if (real_input->abstract() != nullptr && real_input->abstract()->isa<abstract::AbstractSparseTensor>()) {
|
||||||
input_tensor_info, back_index);
|
value = TensorListToSparseTensor(real_input->abstract(), graph_inputs);
|
||||||
|
} else {
|
||||||
|
value = graph_compiler->GetSingleOpInputTensorByIndex(backend_cnode, op_output_map, parameter_index,
|
||||||
|
graph_inputs, input_tensor_info, back_index);
|
||||||
|
}
|
||||||
MS_EXCEPTION_IF_NULL(value);
|
MS_EXCEPTION_IF_NULL(value);
|
||||||
++back_index;
|
++back_index;
|
||||||
} else {
|
} else {
|
||||||
|
@ -794,9 +798,9 @@ void GetControlOpInput(const std::shared_ptr<GraphCompiler> &graph_compiler, con
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void ConvertPyObjectToTensor(const py::object &input_object, std::vector<tensor::TensorPtr> *tensors) {
|
void ConvertPyObjectToTensor(const py::object &input_object, std::vector<ValuePtr> *tensors) {
|
||||||
MS_EXCEPTION_IF_NULL(tensors);
|
MS_EXCEPTION_IF_NULL(tensors);
|
||||||
tensor::TensorPtr tensor_ptr = nullptr;
|
ValuePtr tensor_ptr = nullptr;
|
||||||
if (py::isinstance<tensor::Tensor>(input_object)) {
|
if (py::isinstance<tensor::Tensor>(input_object)) {
|
||||||
tensor_ptr = py::cast<tensor::TensorPtr>(input_object);
|
tensor_ptr = py::cast<tensor::TensorPtr>(input_object);
|
||||||
} else if (py::isinstance<py::float_>(input_object)) {
|
} else if (py::isinstance<py::float_>(input_object)) {
|
||||||
|
@ -816,6 +820,10 @@ void ConvertPyObjectToTensor(const py::object &input_object, std::vector<tensor:
|
||||||
ConvertPyObjectToTensor(tuple_inputs[i], tensors);
|
ConvertPyObjectToTensor(tuple_inputs[i], tensors);
|
||||||
}
|
}
|
||||||
return;
|
return;
|
||||||
|
} else if (py::isinstance<tensor::CSRTensor>(input_object)) {
|
||||||
|
tensor_ptr = py::cast<tensor::CSRTensorPtr>(input_object);
|
||||||
|
} else if (py::isinstance<tensor::COOTensor>(input_object)) {
|
||||||
|
tensor_ptr = py::cast<tensor::COOTensorPtr>(input_object);
|
||||||
} else {
|
} else {
|
||||||
MS_EXCEPTION(TypeError) << "Unreasonable data type: " << input_object.get_type() << ".";
|
MS_EXCEPTION(TypeError) << "Unreasonable data type: " << input_object.get_type() << ".";
|
||||||
}
|
}
|
||||||
|
@ -860,10 +868,10 @@ void RunControlOperator(const std::shared_ptr<GraphCompiler> &graph_compiler, co
|
||||||
if (utils::isa<PyObjectRef>(out)) {
|
if (utils::isa<PyObjectRef>(out)) {
|
||||||
PyObjectRef py_ref = utils::cast<PyObjectRef>(out);
|
PyObjectRef py_ref = utils::cast<PyObjectRef>(out);
|
||||||
auto out_py_tuple = py_ref.object_;
|
auto out_py_tuple = py_ref.object_;
|
||||||
std::vector<tensor::TensorPtr> output_tensors;
|
std::vector<ValuePtr> output_tensors;
|
||||||
ConvertPyObjectToTensor(out_py_tuple, &output_tensors);
|
ConvertPyObjectToTensor(out_py_tuple, &output_tensors);
|
||||||
(void)std::transform(output_tensors.begin(), output_tensors.end(), std::back_inserter(op_outputs->elements_),
|
(void)std::transform(output_tensors.begin(), output_tensors.end(), std::back_inserter(op_outputs->elements_),
|
||||||
[](tensor::TensorPtr &tensor) { return std::move(tensor); });
|
[](ValuePtr &tensor) { return std::move(tensor); });
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -82,6 +82,9 @@ COMMON_EXPORT ValuePtr ShallowCopyTensorValue(const ValuePtr &value);
|
||||||
COMMON_EXPORT size_t CountValueNum(const ValueTuplePtr &value_tuple);
|
COMMON_EXPORT size_t CountValueNum(const ValueTuplePtr &value_tuple);
|
||||||
|
|
||||||
COMMON_EXPORT bool IsAKGSparseOP(const AnfNodePtr &cnode);
|
COMMON_EXPORT bool IsAKGSparseOP(const AnfNodePtr &cnode);
|
||||||
|
|
||||||
|
COMMON_EXPORT tensor::MetaSparseTensorPtr TensorListToSparseTensor(const abstract::AbstractBasePtr &abs_sparse,
|
||||||
|
const tensor::TensorPtrList &tensor_list);
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
#endif // MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_CONVERT_UTILS_H_
|
#endif // MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_CONVERT_UTILS_H_
|
||||||
|
|
|
@ -74,6 +74,8 @@ const char PYTHON_MOD_GET_PARSE_METHOD[] = "get_parse_method_of_class";
|
||||||
const char PYTHON_MOD_GET_BPROP_METHOD[] = "get_bprop_method_of_class";
|
const char PYTHON_MOD_GET_BPROP_METHOD[] = "get_bprop_method_of_class";
|
||||||
const char PYTHON_MOD_GET_OBJECT_DESCRIPTION[] = "get_object_description";
|
const char PYTHON_MOD_GET_OBJECT_DESCRIPTION[] = "get_object_description";
|
||||||
const char PYTHON_MOD_CONVERT_TO_MS_TENSOR[] = "convert_to_ms_tensor";
|
const char PYTHON_MOD_CONVERT_TO_MS_TENSOR[] = "convert_to_ms_tensor";
|
||||||
|
const char PYTHON_MOD_CONVERT_TO_MS_CSRTENSOR[] = "convert_to_ms_csrtensor";
|
||||||
|
const char PYTHON_MOD_CONVERT_TO_MS_COOTENSOR[] = "convert_to_ms_cootensor";
|
||||||
const char PYTHON_MOD_EVAL_PY_SCRIPT[] = "eval_script";
|
const char PYTHON_MOD_EVAL_PY_SCRIPT[] = "eval_script";
|
||||||
const char PYTHON_MOD_GET_SCRIPT_IDS[] = "get_script_ids";
|
const char PYTHON_MOD_GET_SCRIPT_IDS[] = "get_script_ids";
|
||||||
const char PYTHON_MOD_PYTHON_ISINSTANCE[] = "python_isinstance";
|
const char PYTHON_MOD_PYTHON_ISINSTANCE[] = "python_isinstance";
|
||||||
|
|
|
@ -63,6 +63,12 @@ void ConvertCTensorToPyTensor(const py::tuple &input_args, py::tuple *convert_ar
|
||||||
if (py::isinstance<tensor::Tensor>(input_args[i])) {
|
if (py::isinstance<tensor::Tensor>(input_args[i])) {
|
||||||
(*convert_args)[i] =
|
(*convert_args)[i] =
|
||||||
python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_MOD_CONVERT_TO_MS_TENSOR, input_args[i]);
|
python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_MOD_CONVERT_TO_MS_TENSOR, input_args[i]);
|
||||||
|
} else if (py::isinstance<tensor::CSRTensor>(input_args[i])) {
|
||||||
|
(*convert_args)[i] = python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE,
|
||||||
|
parse::PYTHON_MOD_CONVERT_TO_MS_CSRTENSOR, input_args[i]);
|
||||||
|
} else if (py::isinstance<tensor::COOTensor>(input_args[i])) {
|
||||||
|
(*convert_args)[i] = python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE,
|
||||||
|
parse::PYTHON_MOD_CONVERT_TO_MS_COOTENSOR, input_args[i]);
|
||||||
} else if (py::isinstance<py::tuple>(input_args[i])) {
|
} else if (py::isinstance<py::tuple>(input_args[i])) {
|
||||||
auto tuple_inp_arg = py::cast<py::tuple>(input_args[i]);
|
auto tuple_inp_arg = py::cast<py::tuple>(input_args[i]);
|
||||||
py::tuple convert_tuple_arg(tuple_inp_arg.size());
|
py::tuple convert_tuple_arg(tuple_inp_arg.size());
|
||||||
|
|
|
@ -348,4 +348,52 @@ bool IsAKGSparseOP(const AnfNodePtr &cnode) {
|
||||||
prim::kPrimCSR2COO, prim::kPrimCOO2CSR, prim::kPrimCSRDiv, prim::kPrimCSRMM};
|
prim::kPrimCSR2COO, prim::kPrimCOO2CSR, prim::kPrimCSRDiv, prim::kPrimCSRMM};
|
||||||
return IsOneOfPrimitiveCNode(cnode, prims);
|
return IsOneOfPrimitiveCNode(cnode, prims);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
ShapeVector ConvertTensorListToShapeVector(const tensor::TensorPtrList &tensor_list, size_t index) {
|
||||||
|
ShapeVector shape;
|
||||||
|
if (index >= tensor_list.size()) {
|
||||||
|
MS_LOG(EXCEPTION) << "Index " << index << " is out of range of " << tensor_list.size();
|
||||||
|
return shape;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto converter = [](tensor::TensorPtr tensorptr) {
|
||||||
|
MS_EXCEPTION_IF_NULL(tensorptr);
|
||||||
|
if (tensorptr->DataDim() != 0) {
|
||||||
|
MS_LOG(EXCEPTION) << "Element must be scalar!";
|
||||||
|
}
|
||||||
|
tensorptr->data_sync(false);
|
||||||
|
return *(static_cast<int64_t *>(tensorptr->data_c()));
|
||||||
|
};
|
||||||
|
std::transform(tensor_list.begin() + index, tensor_list.end(), std::back_inserter(shape), converter);
|
||||||
|
if (shape.empty()) {
|
||||||
|
MS_LOG(ERROR) << "ShapeVector is empty!";
|
||||||
|
}
|
||||||
|
return shape;
|
||||||
|
}
|
||||||
|
tensor::CSRTensorPtr TensorListToCSRTensor(const tensor::TensorPtrList &tensor_list) {
|
||||||
|
tensor::TensorPtr indptr = utils::cast<tensor::TensorPtr>(tensor_list[tensor::CSRTensor::kIndptrIdx]);
|
||||||
|
tensor::TensorPtr indices = utils::cast<tensor::TensorPtr>(tensor_list[tensor::CSRTensor::kIndicesIdx]);
|
||||||
|
tensor::TensorPtr values = utils::cast<tensor::TensorPtr>(tensor_list[tensor::CSRTensor::kValuesIdx]);
|
||||||
|
ShapeVector shape = ConvertTensorListToShapeVector(tensor_list, tensor::CSRTensor::kShapeIdx);
|
||||||
|
auto csr_tensor_ptr = std::make_shared<tensor::CSRTensor>(indptr, indices, values, shape);
|
||||||
|
return csr_tensor_ptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
tensor::COOTensorPtr TensorListToCOOTensor(const tensor::TensorPtrList &tensor_list) {
|
||||||
|
tensor::TensorPtr indices = utils::cast<tensor::TensorPtr>(tensor_list[tensor::COOTensor::kIndicesIdx]);
|
||||||
|
tensor::TensorPtr values = utils::cast<tensor::TensorPtr>(tensor_list[tensor::COOTensor::kValuesIdx]);
|
||||||
|
ShapeVector shape = ConvertTensorListToShapeVector(tensor_list, tensor::COOTensor::kShapeIdx);
|
||||||
|
auto coo_tensor_ptr = std::make_shared<tensor::COOTensor>(indices, values, shape);
|
||||||
|
return coo_tensor_ptr;
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
tensor::MetaSparseTensorPtr TensorListToSparseTensor(const abstract::AbstractBasePtr &abs_sparse,
|
||||||
|
const tensor::TensorPtrList &tensor_list) {
|
||||||
|
if (abs_sparse->isa<abstract::AbstractCOOTensor>()) {
|
||||||
|
return TensorListToCOOTensor(tensor_list);
|
||||||
|
}
|
||||||
|
return TensorListToCSRTensor(tensor_list);
|
||||||
|
}
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -1579,6 +1579,8 @@ AbstractBasePtr AbstractCOOTensor::Broaden() const {
|
||||||
return std::make_shared<abstract::AbstractCOOTensor>(ElementsBroaden());
|
return std::make_shared<abstract::AbstractCOOTensor>(ElementsBroaden());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
AbstractBasePtr AbstractCOOTensor::PartialBroaden() const { return Broaden(); }
|
||||||
|
|
||||||
std::string AbstractCOOTensor::ToString() const {
|
std::string AbstractCOOTensor::ToString() const {
|
||||||
std::ostringstream buffer;
|
std::ostringstream buffer;
|
||||||
buffer << type_name() << "("
|
buffer << type_name() << "("
|
||||||
|
@ -1630,6 +1632,8 @@ AbstractBasePtr AbstractCSRTensor::Broaden() const {
|
||||||
return std::make_shared<abstract::AbstractCSRTensor>(ElementsBroaden());
|
return std::make_shared<abstract::AbstractCSRTensor>(ElementsBroaden());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
AbstractBasePtr AbstractCSRTensor::PartialBroaden() const { return Broaden(); }
|
||||||
|
|
||||||
std::string AbstractCSRTensor::ToString() const {
|
std::string AbstractCSRTensor::ToString() const {
|
||||||
std::ostringstream buffer;
|
std::ostringstream buffer;
|
||||||
buffer << type_name() << "("
|
buffer << type_name() << "("
|
||||||
|
|
|
@ -1461,6 +1461,7 @@ class MS_CORE_API AbstractCOOTensor : public AbstractSparseTensor {
|
||||||
TypePtr BuildType() const override;
|
TypePtr BuildType() const override;
|
||||||
AbstractBasePtr Clone() const override;
|
AbstractBasePtr Clone() const override;
|
||||||
AbstractBasePtr Broaden() const override;
|
AbstractBasePtr Broaden() const override;
|
||||||
|
AbstractBasePtr PartialBroaden() const override;
|
||||||
std::string ToString() const override;
|
std::string ToString() const override;
|
||||||
|
|
||||||
static constexpr size_t kIndicesIdx = 0;
|
static constexpr size_t kIndicesIdx = 0;
|
||||||
|
@ -1489,6 +1490,7 @@ class MS_CORE_API AbstractCSRTensor : public AbstractSparseTensor {
|
||||||
TypePtr BuildType() const override;
|
TypePtr BuildType() const override;
|
||||||
AbstractBasePtr Clone() const override;
|
AbstractBasePtr Clone() const override;
|
||||||
AbstractBasePtr Broaden() const override;
|
AbstractBasePtr Broaden() const override;
|
||||||
|
AbstractBasePtr PartialBroaden() const override;
|
||||||
std::string ToString() const override;
|
std::string ToString() const override;
|
||||||
|
|
||||||
static constexpr size_t kIndptrIdx = 0;
|
static constexpr size_t kIndptrIdx = 0;
|
||||||
|
|
|
@ -24,7 +24,8 @@ from .parser import (Parser, create_instance, is_supported_create_instance_type,
|
||||||
get_operation_symbol, get_operation_namespace_symbol, get_parse_method_of_class, get_scope_name,
|
get_operation_symbol, get_operation_namespace_symbol, get_parse_method_of_class, get_scope_name,
|
||||||
eval_script, get_script_ids, expand_expr_statement, is_class_member, parse_cb, resolve_symbol,
|
eval_script, get_script_ids, expand_expr_statement, is_class_member, parse_cb, resolve_symbol,
|
||||||
convert_to_ms_tensor, get_object_description, get_class_attr_namespace_symbol, get_ms_class_name,
|
convert_to_ms_tensor, get_object_description, get_class_attr_namespace_symbol, get_ms_class_name,
|
||||||
is_class_type, check_obj_bool, python_isinstance, ms_isinstance)
|
is_class_type, check_obj_bool, python_isinstance, ms_isinstance, convert_to_ms_csrtensor,
|
||||||
|
convert_to_ms_cootensor)
|
||||||
|
|
||||||
__all__ = ['Parser', 'create_instance', 'is_supported_create_instance_type', 'generate_scope',
|
__all__ = ['Parser', 'create_instance', 'is_supported_create_instance_type', 'generate_scope',
|
||||||
'get_bprop_method_of_class', 'get_class_instance_type', 'get_class_member_namespace_symbol',
|
'get_bprop_method_of_class', 'get_class_instance_type', 'get_class_member_namespace_symbol',
|
||||||
|
@ -33,4 +34,5 @@ __all__ = ['Parser', 'create_instance', 'is_supported_create_instance_type', 'ge
|
||||||
'get_operation_symbol', 'get_operation_namespace_symbol', 'get_parse_method_of_class', 'get_scope_name',
|
'get_operation_symbol', 'get_operation_namespace_symbol', 'get_parse_method_of_class', 'get_scope_name',
|
||||||
'eval_script', 'get_script_ids', 'expand_expr_statement', 'is_class_member', 'parse_cb', 'resolve_symbol',
|
'eval_script', 'get_script_ids', 'expand_expr_statement', 'is_class_member', 'parse_cb', 'resolve_symbol',
|
||||||
'convert_to_ms_tensor', 'get_object_description', 'get_class_attr_namespace_symbol', 'get_ms_class_name',
|
'convert_to_ms_tensor', 'get_object_description', 'get_class_attr_namespace_symbol', 'get_ms_class_name',
|
||||||
'is_class_type', 'check_obj_bool', 'python_isinstance', 'ms_isinstance']
|
'is_class_type', 'check_obj_bool', 'python_isinstance', 'ms_isinstance', 'convert_to_ms_csrtensor',
|
||||||
|
'convert_to_ms_cootensor']
|
||||||
|
|
|
@ -31,7 +31,7 @@ import numpy
|
||||||
import asttokens
|
import asttokens
|
||||||
import astunparse
|
import astunparse
|
||||||
|
|
||||||
from mindspore import Tensor
|
from mindspore import Tensor, CSRTensor, COOTensor
|
||||||
from mindspore import log as logger
|
from mindspore import log as logger
|
||||||
from mindspore import nn
|
from mindspore import nn
|
||||||
from mindspore import ops
|
from mindspore import ops
|
||||||
|
@ -497,6 +497,16 @@ def convert_to_ms_tensor(data):
|
||||||
return Tensor(data)
|
return Tensor(data)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_to_ms_csrtensor(data):
|
||||||
|
"""Convert C++ csrtensor to mindspore csrtensor."""
|
||||||
|
return CSRTensor(csr_tensor=data)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_to_ms_cootensor(data):
|
||||||
|
"""Convert C++ cootensor to mindspore cootensor."""
|
||||||
|
return COOTensor(coo_tensor=data)
|
||||||
|
|
||||||
|
|
||||||
def get_object_description(obj, fname, fline):
|
def get_object_description(obj, fname, fline):
|
||||||
"""Return method or funcition description for error report, include location, class name, etc."""
|
"""Return method or funcition description for error report, include location, class name, etc."""
|
||||||
if isinstance(obj, types.MethodType):
|
if isinstance(obj, types.MethodType):
|
||||||
|
|
|
@ -243,7 +243,8 @@ def test_cg_grad(flatten, tensor_type, dtype, tol, a, b, grad_a, grad_b):
|
||||||
@pytest.mark.platform_x86_cpu
|
@pytest.mark.platform_x86_cpu
|
||||||
@pytest.mark.platform_x86_gpu_training
|
@pytest.mark.platform_x86_gpu_training
|
||||||
@pytest.mark.env_onecard
|
@pytest.mark.env_onecard
|
||||||
@pytest.mark.parametrize('tensor_type, dtype, tol', [('Tensor', onp.float32, 1e-5), ('Tensor', onp.float64, 1e-8)])
|
@pytest.mark.parametrize('tensor_type, dtype, tol', [('Tensor', onp.float32, 1e-5), ('Tensor', onp.float64, 1e-8),
|
||||||
|
('CSRTensor', onp.float32, 1e-5)])
|
||||||
@pytest.mark.parametrize('a, b, grad_a, grad_b', [
|
@pytest.mark.parametrize('a, b, grad_a, grad_b', [
|
||||||
([[1.96822833, 0.82204467, 1.03749232, 0.88915326, 0.44986806, 1.11167143],
|
([[1.96822833, 0.82204467, 1.03749232, 0.88915326, 0.44986806, 1.11167143],
|
||||||
[0.82204467, 2.25216591, 1.40235719, 0.70838919, 0.81377919, 1.06000368],
|
[0.82204467, 2.25216591, 1.40235719, 0.70838919, 0.81377919, 1.06000368],
|
||||||
|
@ -278,6 +279,8 @@ def test_cg_grad_pynative(tensor_type, dtype, tol, a, b, grad_a, grad_b):
|
||||||
Description: test cases for grad implementation of cg in pynative mode
|
Description: test cases for grad implementation of cg in pynative mode
|
||||||
Expectation: the result match expectation
|
Expectation: the result match expectation
|
||||||
"""
|
"""
|
||||||
|
if tensor_type == "CSRTensor" and get_platform() != "linux":
|
||||||
|
return
|
||||||
context.set_context(mode=context.PYNATIVE_MODE)
|
context.set_context(mode=context.PYNATIVE_MODE)
|
||||||
|
|
||||||
a = to_tensor((a, tensor_type), dtype)
|
a = to_tensor((a, tensor_type), dtype)
|
||||||
|
|
Loading…
Reference in New Issue