support sparse user-defined bprop

Author: yanglf1121
Date: 2022-07-15 16:18:32 +08:00
Parent: f0142dce53
Commit: 027fac9b3c
12 changed files with 145 additions and 17 deletions
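For orientation, here is a minimal PyNative sketch of the scenario this commit enables: a Cell with a user-defined bprop whose input and gradient are CSRTensors. The cell, shapes, and values are hypothetical and not part of the change.

# Hypothetical end-to-end sketch (not part of the diff): a user-defined bprop
# on a sparse (CSR) input, run in PyNative mode.
import mindspore as ms
from mindspore import nn, ops, context, Tensor, CSRTensor

class SparseNet(nn.Cell):
    def construct(self, csr):
        return csr.values.sum()          # any forward computation on the values

    def bprop(self, csr, out, dout):
        # Gradient w.r.t. a CSR input keeps the same sparsity pattern.
        grad_values = ops.ones_like(csr.values) * dout
        return (CSRTensor(csr.indptr, csr.indices, grad_values, csr.shape),)

context.set_context(mode=context.PYNATIVE_MODE)
indptr = Tensor([0, 1, 2], dtype=ms.int32)
indices = Tensor([0, 1], dtype=ms.int32)
values = Tensor([1.0, 2.0], dtype=ms.float32)
csr = CSRTensor(indptr, indices, values, (2, 4))
grads = ops.GradOperation()(SparseNet())(csr)   # gradient is itself a CSRTensor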

View File

@@ -61,6 +61,7 @@ AnfNodePtr ConvertMakeTupleInputToPlantInputs(const FuncGraphPtr &graph, const C
       common::AnfAlgo::CheckPrimitiveType(cnode_ptr, prim::kPrimResizeNearestNeighborGrad)) {
     return nullptr;
   }
+  bool is_bprop_cut = common::AnfAlgo::CheckPrimitiveType(cnode_ptr, prim::kPrimBpropCut);
   std::vector<AnfNodePtr> plant_inputs;
   std::vector<int64_t> dyn_input_sizes;
   plant_inputs.push_back(common::AnfAlgo::GetCNodePrimitiveNode(cnode_ptr));
@@ -68,7 +69,8 @@ AnfNodePtr ConvertMakeTupleInputToPlantInputs(const FuncGraphPtr &graph, const C
   for (size_t i = 0; i < input_num; ++i) {
     auto input_node = common::AnfAlgo::GetInputNode(cnode_ptr, i);
     MS_EXCEPTION_IF_NULL(input_node);
-    if (common::AnfAlgo::IsTupleOutput(input_node)) {
+    bool skip = (is_bprop_cut && input_node->abstract()->isa<abstract::AbstractSparseTensor>());
+    if (common::AnfAlgo::IsTupleOutput(input_node) && !skip) {
       (void)dyn_input_sizes.emplace_back(SplitTupleInputs(graph, input_node, &plant_inputs));
     } else {
       dyn_input_sizes.push_back(-1);

View File

@@ -558,6 +558,40 @@ bool UseParamInitInServer(const FuncGraphPtr &kernel_graph, const AnfNodePtr &pa
                      [](const AnfNodePtr &node) { return AnfUtils::IsRealKernel(node); });
 }
 #endif
+
+void IterateFindTensor(std::vector<ValuePtr> *msTensors, const VectorRef &ref_list) {
+  for (size_t i = 0; i < ref_list.size(); ++i) {
+    if (utils::isa<tensor::TensorPtr>(ref_list[i])) {
+      auto tensor_ptr = utils::cast<std::shared_ptr<tensor::Tensor>>(ref_list[i]);
+      MS_EXCEPTION_IF_NULL(tensor_ptr);
+      msTensors->emplace_back(tensor_ptr);
+    } else if (utils::isa<VectorRef>(ref_list[i])) {
+      auto ref_iter = utils::cast<VectorRef>(ref_list[i]);
+      IterateFindTensor(msTensors, ref_iter);
+    } else if (utils::isa<tensor::CSRTensorPtr>(ref_list[i])) {
+      auto csr_tensor = utils::cast<tensor::CSRTensorPtr>(ref_list[i]);
+      MS_EXCEPTION_IF_NULL(csr_tensor);
+      msTensors->emplace_back(csr_tensor);
+    } else {
+      MS_LOG(EXCEPTION) << "The output is not a tensor/sparse tensor";
+    }
+  }
+}
+
+std::vector<ValuePtr> TransformVectorRefToMultiValue(const VectorRef &base_ref) {
+  std::vector<ValuePtr> msTensors;
+  if (utils::isa<VectorRef>(base_ref)) {
+    auto ref_list = utils::cast<VectorRef>(base_ref);
+    IterateFindTensor(&msTensors, ref_list);
+  } else if (utils::isa<tensor::Tensor>(base_ref)) {
+    auto tensor_ptr = utils::cast<std::shared_ptr<tensor::Tensor>>(base_ref);
+    MS_EXCEPTION_IF_NULL(tensor_ptr);
+    msTensors.emplace_back(tensor_ptr);
+  } else {
+    MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
+  }
+  return msTensors;
+}
 }  // namespace

 GraphId SessionBasic::graph_sum_ = 0;

@@ -1566,14 +1600,16 @@ void SessionBasic::HandleOpOutputs(const AnfNodePtr &kernel, const VectorRef &op
   MS_EXCEPTION_IF_NULL(op_output_map);
   MS_EXCEPTION_IF_NULL(graph_output_info);
   MS_EXCEPTION_IF_NULL(graph_output_info->graph_outputs);
-  auto output_tensors = TransformVectorRefToMultiTensor(op_outputs);
-  if (output_tensors.size() > op_outputs.size()) {
+  auto output_values = TransformVectorRefToMultiValue(op_outputs);
+  if (output_values.size() > op_outputs.size()) {
     MS_LOG(EXCEPTION) << "Op output contains tuple, node = " << kernel->DebugString();
   }
   size_t out_index = 0;
-  for (const auto &output_tensor : output_tensors) {
+  for (const auto &output_value : output_values) {
     auto kernel_with_index = make_pair(kernel, out_index++);
-    if (ref_count.find(kernel_with_index) != ref_count.end()) {
+    auto output_tensor = output_value->cast<tensor::TensorPtr>();
+    bool value_is_tensor = (output_tensor != nullptr);
+    if (ref_count.find(kernel_with_index) != ref_count.end() && value_is_tensor) {
       (*op_output_map)[kernel_with_index] = output_tensor;
     }
     const auto &iter = graph_output_info->output_indexes.find(kernel_with_index);
@@ -1597,8 +1633,10 @@ void SessionBasic::HandleOpOutputs(const AnfNodePtr &kernel, const VectorRef &op
         cur_vector_ref = &utils::cast<VectorRef>(base_ref);
       }
       BaseRef &tensor_ref = (*const_cast<VectorRef *>(cur_vector_ref))[ref_indexes.at(n)];
-      tensor_ref = output_tensor;
-      graph_output_info->graph_output_tensors.emplace_back(output_tensor);
+      tensor_ref = output_value;
+      if (value_is_tensor) {
+        graph_output_info->graph_output_tensors.emplace_back(output_tensor);
+      }
     }
   }
 }
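The helpers added above flatten a nested VectorRef of op outputs while, new in this commit, keeping CSR tensors alongside dense ones; HandleOpOutputs then caches only the dense values in op_output_map. A rough Python analogue of the recursion, with stand-in classes for the C++ tensor types (illustration only, not MindSpore API):

# Rough analogue of TransformVectorRefToMultiValue / IterateFindTensor.
# DenseTensor and CSRHandle are hypothetical stand-ins for tensor::Tensor
# and tensor::CSRTensor.
class DenseTensor: pass
class CSRHandle: pass

def transform_to_multi_value(base_ref):
    values = []
    def iterate(ref_list):
        for item in ref_list:
            if isinstance(item, (DenseTensor, CSRHandle)):
                values.append(item)        # keep dense and sparse outputs alike
            elif isinstance(item, list):   # analogue of a nested VectorRef
                iterate(item)
            else:
                raise TypeError("The output is not a tensor/sparse tensor")
    if isinstance(base_ref, list):
        iterate(base_ref)
    elif isinstance(base_ref, DenseTensor):
        values.append(base_ref)
    else:
        raise TypeError("The output is not a base ref list or a tensor!")
    return values

# e.g. transform_to_multi_value([DenseTensor(), [CSRHandle(), DenseTensor()]])
# yields the three leaves in flattened order.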

View File

@@ -762,8 +762,12 @@ void GetControlOpInput(const std::shared_ptr<GraphCompiler> &graph_compiler, con
     MS_EXCEPTION_IF_NULL(real_input);
     ValuePtr value = nullptr;
     if (!real_input->isa<ValueNode>()) {
-      value = graph_compiler->GetSingleOpInputTensorByIndex(backend_cnode, op_output_map, parameter_index, graph_inputs,
-                                                            input_tensor_info, back_index);
+      if (real_input->abstract() != nullptr && real_input->abstract()->isa<abstract::AbstractSparseTensor>()) {
+        value = TensorListToSparseTensor(real_input->abstract(), graph_inputs);
+      } else {
+        value = graph_compiler->GetSingleOpInputTensorByIndex(backend_cnode, op_output_map, parameter_index,
+                                                              graph_inputs, input_tensor_info, back_index);
+      }
       MS_EXCEPTION_IF_NULL(value);
       ++back_index;
     } else {
@@ -794,9 +798,9 @@
     }
   }
 }
-void ConvertPyObjectToTensor(const py::object &input_object, std::vector<tensor::TensorPtr> *tensors) {
+void ConvertPyObjectToTensor(const py::object &input_object, std::vector<ValuePtr> *tensors) {
   MS_EXCEPTION_IF_NULL(tensors);
-  tensor::TensorPtr tensor_ptr = nullptr;
+  ValuePtr tensor_ptr = nullptr;
   if (py::isinstance<tensor::Tensor>(input_object)) {
     tensor_ptr = py::cast<tensor::TensorPtr>(input_object);
   } else if (py::isinstance<py::float_>(input_object)) {
@@ -816,6 +820,10 @@ void ConvertPyObjectToTensor(const py::object &input_object, std::vector<tensor:
       ConvertPyObjectToTensor(tuple_inputs[i], tensors);
     }
     return;
+  } else if (py::isinstance<tensor::CSRTensor>(input_object)) {
+    tensor_ptr = py::cast<tensor::CSRTensorPtr>(input_object);
+  } else if (py::isinstance<tensor::COOTensor>(input_object)) {
+    tensor_ptr = py::cast<tensor::COOTensorPtr>(input_object);
   } else {
     MS_EXCEPTION(TypeError) << "Unreasonable data type: " << input_object.get_type() << ".";
   }
@@ -860,10 +868,10 @@ void RunControlOperator(const std::shared_ptr<GraphCompiler> &graph_compiler, co
     if (utils::isa<PyObjectRef>(out)) {
       PyObjectRef py_ref = utils::cast<PyObjectRef>(out);
       auto out_py_tuple = py_ref.object_;
-      std::vector<tensor::TensorPtr> output_tensors;
+      std::vector<ValuePtr> output_tensors;
       ConvertPyObjectToTensor(out_py_tuple, &output_tensors);
       (void)std::transform(output_tensors.begin(), output_tensors.end(), std::back_inserter(op_outputs->elements_),
-                           [](tensor::TensorPtr &tensor) { return std::move(tensor); });
+                           [](ValuePtr &tensor) { return std::move(tensor); });
     }
   }
 }
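With ConvertPyObjectToTensor and RunControlOperator now trafficking in ValuePtr, a Python bprop may return dense tensors, Python scalars, CSR/COO tensors, or nested tuples of these. A hedged Python analogue of the widened conversion (the function name is illustrative):

# Rough analogue of the updated ConvertPyObjectToTensor: each leaf of a
# (possibly nested) bprop result is kept as a value; sparse tensors are no
# longer rejected as an "Unreasonable data type".
from mindspore import Tensor, CSRTensor, COOTensor

def convert_bprop_output(obj, out):
    if isinstance(obj, (Tensor, CSRTensor, COOTensor)):
        out.append(obj)                      # dense and sparse leaves kept as-is
    elif isinstance(obj, (float, int)):
        out.append(Tensor(obj))              # scalars become 0-d tensors
    elif isinstance(obj, tuple):
        for item in obj:
            convert_bprop_output(item, out)  # recurse into nested tuples
    else:
        raise TypeError(f"Unreasonable data type: {type(obj)}.")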

View File

@@ -82,6 +82,9 @@ COMMON_EXPORT ValuePtr ShallowCopyTensorValue(const ValuePtr &value);
 COMMON_EXPORT size_t CountValueNum(const ValueTuplePtr &value_tuple);
 COMMON_EXPORT bool IsAKGSparseOP(const AnfNodePtr &cnode);
+COMMON_EXPORT tensor::MetaSparseTensorPtr TensorListToSparseTensor(const abstract::AbstractBasePtr &abs_sparse,
+                                                                   const tensor::TensorPtrList &tensor_list);
 }  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_CONVERT_UTILS_H_

View File

@@ -74,6 +74,8 @@ const char PYTHON_MOD_GET_PARSE_METHOD[] = "get_parse_method_of_class";
 const char PYTHON_MOD_GET_BPROP_METHOD[] = "get_bprop_method_of_class";
 const char PYTHON_MOD_GET_OBJECT_DESCRIPTION[] = "get_object_description";
 const char PYTHON_MOD_CONVERT_TO_MS_TENSOR[] = "convert_to_ms_tensor";
+const char PYTHON_MOD_CONVERT_TO_MS_CSRTENSOR[] = "convert_to_ms_csrtensor";
+const char PYTHON_MOD_CONVERT_TO_MS_COOTENSOR[] = "convert_to_ms_cootensor";
 const char PYTHON_MOD_EVAL_PY_SCRIPT[] = "eval_script";
 const char PYTHON_MOD_GET_SCRIPT_IDS[] = "get_script_ids";
 const char PYTHON_MOD_PYTHON_ISINSTANCE[] = "python_isinstance";

View File

@@ -63,6 +63,12 @@ void ConvertCTensorToPyTensor(const py::tuple &input_args, py::tuple *convert_ar
     if (py::isinstance<tensor::Tensor>(input_args[i])) {
       (*convert_args)[i] =
         python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_MOD_CONVERT_TO_MS_TENSOR, input_args[i]);
+    } else if (py::isinstance<tensor::CSRTensor>(input_args[i])) {
+      (*convert_args)[i] = python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE,
+                                                    parse::PYTHON_MOD_CONVERT_TO_MS_CSRTENSOR, input_args[i]);
+    } else if (py::isinstance<tensor::COOTensor>(input_args[i])) {
+      (*convert_args)[i] = python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE,
+                                                    parse::PYTHON_MOD_CONVERT_TO_MS_COOTENSOR, input_args[i]);
     } else if (py::isinstance<py::tuple>(input_args[i])) {
       auto tuple_inp_arg = py::cast<py::tuple>(input_args[i]);
       py::tuple convert_tuple_arg(tuple_inp_arg.size());

View File

@@ -348,4 +348,52 @@ bool IsAKGSparseOP(const AnfNodePtr &cnode) {
                                        prim::kPrimCSR2COO, prim::kPrimCOO2CSR, prim::kPrimCSRDiv, prim::kPrimCSRMM};
   return IsOneOfPrimitiveCNode(cnode, prims);
 }
+
+namespace {
+ShapeVector ConvertTensorListToShapeVector(const tensor::TensorPtrList &tensor_list, size_t index) {
+  ShapeVector shape;
+  if (index >= tensor_list.size()) {
+    MS_LOG(EXCEPTION) << "Index " << index << " is out of range of " << tensor_list.size();
+    return shape;
+  }
+  auto converter = [](tensor::TensorPtr tensorptr) {
+    MS_EXCEPTION_IF_NULL(tensorptr);
+    if (tensorptr->DataDim() != 0) {
+      MS_LOG(EXCEPTION) << "Element must be scalar!";
+    }
+    tensorptr->data_sync(false);
+    return *(static_cast<int64_t *>(tensorptr->data_c()));
+  };
+  std::transform(tensor_list.begin() + index, tensor_list.end(), std::back_inserter(shape), converter);
+  if (shape.empty()) {
+    MS_LOG(ERROR) << "ShapeVector is empty!";
+  }
+  return shape;
+}
+
+tensor::CSRTensorPtr TensorListToCSRTensor(const tensor::TensorPtrList &tensor_list) {
+  tensor::TensorPtr indptr = utils::cast<tensor::TensorPtr>(tensor_list[tensor::CSRTensor::kIndptrIdx]);
+  tensor::TensorPtr indices = utils::cast<tensor::TensorPtr>(tensor_list[tensor::CSRTensor::kIndicesIdx]);
+  tensor::TensorPtr values = utils::cast<tensor::TensorPtr>(tensor_list[tensor::CSRTensor::kValuesIdx]);
+  ShapeVector shape = ConvertTensorListToShapeVector(tensor_list, tensor::CSRTensor::kShapeIdx);
+  auto csr_tensor_ptr = std::make_shared<tensor::CSRTensor>(indptr, indices, values, shape);
+  return csr_tensor_ptr;
+}
+
+tensor::COOTensorPtr TensorListToCOOTensor(const tensor::TensorPtrList &tensor_list) {
+  tensor::TensorPtr indices = utils::cast<tensor::TensorPtr>(tensor_list[tensor::COOTensor::kIndicesIdx]);
+  tensor::TensorPtr values = utils::cast<tensor::TensorPtr>(tensor_list[tensor::COOTensor::kValuesIdx]);
+  ShapeVector shape = ConvertTensorListToShapeVector(tensor_list, tensor::COOTensor::kShapeIdx);
+  auto coo_tensor_ptr = std::make_shared<tensor::COOTensor>(indices, values, shape);
+  return coo_tensor_ptr;
+}
+}  // namespace
+
+tensor::MetaSparseTensorPtr TensorListToSparseTensor(const abstract::AbstractBasePtr &abs_sparse,
+                                                     const tensor::TensorPtrList &tensor_list) {
+  if (abs_sparse->isa<abstract::AbstractCOOTensor>()) {
+    return TensorListToCOOTensor(tensor_list);
+  }
+  return TensorListToCSRTensor(tensor_list);
+}
 }  // namespace mindspore
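TensorListToSparseTensor reassembles a sparse tensor from the runtime's flattened layout: the component tensors first, then one scalar tensor per shape dimension (indptr, indices, values, *shape for CSR). An equivalent reconstruction sketched with the public Python API, assuming that layout:

# Sketch of the CSR reconstruction performed by TensorListToCSRTensor,
# assuming the flattened layout [indptr, indices, values, *shape_scalars].
from mindspore import CSRTensor

def tensor_list_to_csr(tensor_list):
    indptr, indices, values = tensor_list[:3]
    # Trailing entries are scalar tensors, one per output dimension.
    shape = tuple(int(dim.asnumpy()) for dim in tensor_list[3:])
    return CSRTensor(indptr, indices, values, shape)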

View File

@@ -1579,6 +1579,8 @@ AbstractBasePtr AbstractCOOTensor::Broaden() const {
   return std::make_shared<abstract::AbstractCOOTensor>(ElementsBroaden());
 }
+
+AbstractBasePtr AbstractCOOTensor::PartialBroaden() const { return Broaden(); }
 std::string AbstractCOOTensor::ToString() const {
   std::ostringstream buffer;
   buffer << type_name() << "("
@@ -1630,6 +1632,8 @@ AbstractBasePtr AbstractCSRTensor::Broaden() const {
   return std::make_shared<abstract::AbstractCSRTensor>(ElementsBroaden());
 }
+
+AbstractBasePtr AbstractCSRTensor::PartialBroaden() const { return Broaden(); }
 std::string AbstractCSRTensor::ToString() const {
   std::ostringstream buffer;
   buffer << type_name() << "("

View File

@@ -1461,6 +1461,7 @@ class MS_CORE_API AbstractCOOTensor : public AbstractSparseTensor {
   TypePtr BuildType() const override;
   AbstractBasePtr Clone() const override;
   AbstractBasePtr Broaden() const override;
+  AbstractBasePtr PartialBroaden() const override;
   std::string ToString() const override;

   static constexpr size_t kIndicesIdx = 0;
@@ -1489,6 +1490,7 @@ class MS_CORE_API AbstractCSRTensor : public AbstractSparseTensor {
   TypePtr BuildType() const override;
   AbstractBasePtr Clone() const override;
   AbstractBasePtr Broaden() const override;
+  AbstractBasePtr PartialBroaden() const override;
   std::string ToString() const override;

   static constexpr size_t kIndptrIdx = 0;

View File

@@ -24,7 +24,8 @@ from .parser import (Parser, create_instance, is_supported_create_instance_type,
                      get_operation_symbol, get_operation_namespace_symbol, get_parse_method_of_class, get_scope_name,
                      eval_script, get_script_ids, expand_expr_statement, is_class_member, parse_cb, resolve_symbol,
                      convert_to_ms_tensor, get_object_description, get_class_attr_namespace_symbol, get_ms_class_name,
-                     is_class_type, check_obj_bool, python_isinstance, ms_isinstance)
+                     is_class_type, check_obj_bool, python_isinstance, ms_isinstance, convert_to_ms_csrtensor,
+                     convert_to_ms_cootensor)

 __all__ = ['Parser', 'create_instance', 'is_supported_create_instance_type', 'generate_scope',
            'get_bprop_method_of_class', 'get_class_instance_type', 'get_class_member_namespace_symbol',
@@ -33,4 +34,5 @@ __all__ = ['Parser', 'create_instance', 'is_supported_create_instance_type', 'ge
            'get_operation_symbol', 'get_operation_namespace_symbol', 'get_parse_method_of_class', 'get_scope_name',
            'eval_script', 'get_script_ids', 'expand_expr_statement', 'is_class_member', 'parse_cb', 'resolve_symbol',
            'convert_to_ms_tensor', 'get_object_description', 'get_class_attr_namespace_symbol', 'get_ms_class_name',
-           'is_class_type', 'check_obj_bool', 'python_isinstance', 'ms_isinstance']
+           'is_class_type', 'check_obj_bool', 'python_isinstance', 'ms_isinstance', 'convert_to_ms_csrtensor',
+           'convert_to_ms_cootensor']

View File

@@ -31,7 +31,7 @@ import numpy
 import asttokens
 import astunparse

-from mindspore import Tensor
+from mindspore import Tensor, CSRTensor, COOTensor
 from mindspore import log as logger
 from mindspore import nn
 from mindspore import ops
@@ -497,6 +497,16 @@ def convert_to_ms_tensor(data):
     return Tensor(data)


+def convert_to_ms_csrtensor(data):
+    """Convert C++ csrtensor to mindspore csrtensor."""
+    return CSRTensor(csr_tensor=data)
+
+
+def convert_to_ms_cootensor(data):
+    """Convert C++ cootensor to mindspore cootensor."""
+    return COOTensor(coo_tensor=data)
+
+
 def get_object_description(obj, fname, fline):
     """Return method or function description for error report, including location, class name, etc."""
     if isinstance(obj, types.MethodType):
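These two converters are what the C++ side invokes (via PYTHON_MOD_CONVERT_TO_MS_CSRTENSOR / PYTHON_MOD_CONVERT_TO_MS_COOTENSOR in ConvertCTensorToPyTensor) to hand sparse arguments to a user's bprop. The csr_tensor=/coo_tensor= keyword forms wrap an existing sparse tensor object in a fresh Python wrapper; a small sketch (the source object is built in Python here, since a C++ handle is not available outside the runtime):

# Wrapping an existing CSRTensor the way convert_to_ms_csrtensor does.
import mindspore as ms
from mindspore import Tensor, CSRTensor

indptr = Tensor([0, 1, 2], dtype=ms.int32)
indices = Tensor([0, 1], dtype=ms.int32)
values = Tensor([1.0, 2.0], dtype=ms.float32)
src = CSRTensor(indptr, indices, values, (2, 4))
wrapped = CSRTensor(csr_tensor=src)   # new wrapper over the same data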

View File

@@ -243,7 +243,8 @@ def test_cg_grad(flatten, tensor_type, dtype, tol, a, b, grad_a, grad_b):
 @pytest.mark.platform_x86_cpu
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
-@pytest.mark.parametrize('tensor_type, dtype, tol', [('Tensor', onp.float32, 1e-5), ('Tensor', onp.float64, 1e-8)])
+@pytest.mark.parametrize('tensor_type, dtype, tol', [('Tensor', onp.float32, 1e-5), ('Tensor', onp.float64, 1e-8),
+                                                     ('CSRTensor', onp.float32, 1e-5)])
 @pytest.mark.parametrize('a, b, grad_a, grad_b', [
     ([[1.96822833, 0.82204467, 1.03749232, 0.88915326, 0.44986806, 1.11167143],
       [0.82204467, 2.25216591, 1.40235719, 0.70838919, 0.81377919, 1.06000368],
@@ -278,6 +279,8 @@ def test_cg_grad_pynative(tensor_type, dtype, tol, a, b, grad_a, grad_b):
     Description: test cases for grad implementation of cg in pynative mode
     Expectation: the result matches expectation
     """
+    if tensor_type == "CSRTensor" and get_platform() != "linux":
+        return
     context.set_context(mode=context.PYNATIVE_MODE)
     a = to_tensor((a, tensor_type), dtype)