diff --git a/mindspore/ccsrc/backend/common/session/session_basic.cc b/mindspore/ccsrc/backend/common/session/session_basic.cc
index da6cadeff7f..12e280318ae 100644
--- a/mindspore/ccsrc/backend/common/session/session_basic.cc
+++ b/mindspore/ccsrc/backend/common/session/session_basic.cc
@@ -445,7 +445,7 @@ void SessionBasic::GetSingleOpGraphInfo(const CNodePtr &kernel, const InputTenso
   auto prim = common::AnfAlgo::GetCNodePrimitive(kernel);
   MS_EXCEPTION_IF_NULL(prim);
   buf << GetOpRunDeviceTarget(prim) << "_";
-  buf << prim->id() << "_";
+  buf << prim->name() << "_";
   bool has_const_input = false;
   for (size_t i = 0; i < input_tensors.size(); ++i) {
     auto &tensor = input_tensors[i];
diff --git a/mindspore/ccsrc/pipeline/pynative/base.h b/mindspore/ccsrc/pipeline/pynative/base.h
index e20aeb17aaf..79d35d31a58 100644
--- a/mindspore/ccsrc/pipeline/pynative/base.h
+++ b/mindspore/ccsrc/pipeline/pynative/base.h
@@ -89,7 +89,6 @@ struct InputArgsInfo {
   size_t input_size;
   std::string obj_id;
   bool has_sens{false};
-  bool is_run_cell{false};
   bool use_dynamic_shape_process = false;
   PrimitivePyPtr custom_bprp_prim{nullptr};
   ValuePtr out_value{nullptr};
diff --git a/mindspore/ccsrc/pipeline/pynative/grad/grad.cc b/mindspore/ccsrc/pipeline/pynative/grad/grad.cc
index 58615ae2166..73c88cd557a 100644
--- a/mindspore/ccsrc/pipeline/pynative/grad/grad.cc
+++ b/mindspore/ccsrc/pipeline/pynative/grad/grad.cc
@@ -63,10 +63,24 @@ std::string GetCellId(const py::object &obj, const py::args &args, const InputAr
   return cell_id;
 }
 
+std::string GetFnInfoByPyObj(const py::object &obj) {
+  auto module_name = obj.attr("__module__").cast<std::string>();
+  auto fn_name = obj.attr("__name__").cast<std::string>();
+  auto filename = obj.attr("__code__").attr("co_filename").cast<std::string>();
+  auto code_lineno = py::str(obj.attr("__code__").attr("co_firstlineno")).cast<std::string>();
+  return (module_name + "_" + fn_name + "_" + filename + "_" + code_lineno);
+}
+
 InputArgsInfoPtr ParsePyArgsToInputArgsInfo(const py::object &obj, const py::args &args, bool is_grad_top_cell,
                                             bool is_high_order_top_cell) {
   bool has_custom_bprop = py::hasattr(obj, parse::CUSTOM_BPROP_NAME);
-  const auto &obj_id = PyNativeAlgo::PyParser::GetIdByPyObj(obj);
+  std::string obj_id;
+  if (!py::isinstance<Cell>(obj) && (is_grad_top_cell || is_high_order_top_cell)) {
+    obj_id = GetFnInfoByPyObj(obj);
+  } else {
+    obj_id = PyNativeAlgo::PyParser::GetIdByPyObj(obj);
+  }
+
   const auto &input_args_info = std::make_shared<InputArgsInfo>(is_grad_top_cell, is_high_order_top_cell,
                                                                 has_custom_bprop, args.size(), obj_id);
   for (size_t i = 0; i < args.size(); i++) {
@@ -82,7 +96,6 @@ InputArgsInfoPtr ParsePyArgsToInputArgsInfo(const py::object &obj, const py::arg
     }
     pipeline::CheckArgsValid(obj, args);
   }
-  input_args_info->is_run_cell = py::isinstance<Cell>(obj);
   input_args_info->cell_id = GetCellId(obj, args, input_args_info);
   MS_LOG(DEBUG) << "cell_id is " << obj_id << ", is grad top cell " << (is_grad_top_cell || is_high_order_top_cell);
   return input_args_info;
 }
@@ -522,7 +535,6 @@ void GradExecutor::MakeNewTopGraph(const InputArgsInfoPtr &input_args_info) {
     std::make_shared<TopCellInfo>(input_args_info->is_high_order_top_cell, input_args_info->grad_order,
                                   input_args_info->obj_id, input_args_info->cell_id, already_run_cell_id, resource, fg);
   top_cell_->set_forward_already_run(true);
-  top_cell_->set_is_run_cell(input_args_info->is_run_cell);
  top_cell_->set_input_args_id(input_args_info->input_args_id);
   PushHighOrderGraphStack(top_cell_);
   (void)top_cell_list_.emplace_back(top_cell_);
@@ -727,7 +739,7 @@ void GradExecutor::CheckNeedCompileGraph(const InputArgsInfoPtr &input_args_info
   auto pre_top_cell = already_run_top_cell_.at(already_top_cell_id);
   MS_EXCEPTION_IF_NULL(pre_top_cell);
-  if (input_args_info->use_dynamic_shape_process || !input_args_info->is_run_cell) {
+  if (input_args_info->use_dynamic_shape_process) {
     // Function need compile every time.
     MS_LOG(DEBUG) << "The graph is dynamic, need to compile graph again";
     EraseTopCellFromTopCellList(pre_top_cell);
@@ -1798,7 +1810,7 @@ bool GradExecutor::IsGraphDynamic(const CNodePtr &cnode, const size_t &node_idx,
 
 void GradExecutor::CheckGraphDynamic(const CNodePtr &cnode, const size_t &node_idx, bool is_ms_function_node,
                                      const std::string &graph_phase) const {
-  if (!top_cell()->is_run_cell() || use_dynamic_shape_process_) {
+  if (use_dynamic_shape_process_) {
     return;
   }
diff --git a/mindspore/ccsrc/pipeline/pynative/grad/top_cell.h b/mindspore/ccsrc/pipeline/pynative/grad/top_cell.h
index e1b918152d5..996ea8327cd 100644
--- a/mindspore/ccsrc/pipeline/pynative/grad/top_cell.h
+++ b/mindspore/ccsrc/pipeline/pynative/grad/top_cell.h
@@ -101,8 +101,6 @@ class TopCellInfo {
   inline void SetGraphInfoMap(const FuncGraphPtr &fg, const GraphInfoPtr &graph_info) {
     graph_info_map_[fg] = graph_info;
   }
-  inline void set_is_run_cell(bool is_run_cell) { is_run_cell_ = is_run_cell; }
-  inline bool is_run_cell() const { return is_run_cell_; }
   inline const OrderedMap<FuncGraphPtr, GraphInfoPtr> &graph_info_map() const { return graph_info_map_; }
   inline ad::AutoGradCellImplPtr auto_grad_cell_ptr() const {
     MS_EXCEPTION_IF_NULL(auto_grad_cell_ptr_);
@@ -152,7 +150,6 @@ class TopCellInfo {
   bool is_init_kpynative_{false};
   bool forward_already_run_{false};
   bool need_compile_graph_{false};
-  bool is_run_cell_{false};
   size_t op_index_{0};
   bool is_high_order_top_cell_{false};
   bool need_do_final_opt_{false};
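
Editorial note (not part of the patch): the new GetFnInfoByPyObj helper above builds a top-cell id for a plain Python function from its module, name, source file and first line number. A minimal Python sketch of the same idea follows; the helper name fn_info is hypothetical and only mirrors the attributes read in the C++ code.

def fn_info(fn):
    # Mirrors GetFnInfoByPyObj: identify a free function by attributes that
    # stay stable across calls instead of by a transient object id.
    code = fn.__code__
    return "_".join([fn.__module__, fn.__name__, code.co_filename, str(code.co_firstlineno)])


def loss(x):
    return x * x


print(fn_info(loss))  # module, name, file and first line joined with "_"

Because that string is the same on every call to the same def but differs between different defs, a function graded at the top level can presumably reuse its compiled top cell rather than being rebuilt each run, which appears to be why the is_run_cell flag and the forced-recompile branch in CheckNeedCompileGraph are removed above.
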
diff --git a/tests/st/scipy_st/sparse/test_linalg.py b/tests/st/scipy_st/sparse/test_linalg.py
index d6cb0e6b99f..9a3fe870407 100644
--- a/tests/st/scipy_st/sparse/test_linalg.py
+++ b/tests/st/scipy_st/sparse/test_linalg.py
@@ -243,8 +243,7 @@ def test_cg_grad(flatten, tensor_type, dtype, tol, a, b, grad_a, grad_b):
 @pytest.mark.platform_x86_cpu
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
-@pytest.mark.parametrize('tensor_type, dtype, tol', [('Tensor', onp.float32, 1e-5), ('Tensor', onp.float64, 1e-8),
-                                                     ('CSRTensor', onp.float32, 1e-5)])
+@pytest.mark.parametrize('tensor_type, dtype, tol', [('Tensor', onp.float32, 1e-5), ('Tensor', onp.float64, 1e-8)])
 @pytest.mark.parametrize('a, b, grad_a, grad_b', [
     ([[1.96822833, 0.82204467, 1.03749232, 0.88915326, 0.44986806, 1.11167143],
       [0.82204467, 2.25216591, 1.40235719, 0.70838919, 0.81377919, 1.06000368],
@@ -273,7 +272,123 @@
      [-0.14053766, 0.00313851, 0.02536103, 0.01889718, -0.07065797]],
      [0.23398106, 0.31016481, 0.29870068, -0.09782316, 0.43852141]),
 ])
-def test_cg_grad_pynative(tensor_type, dtype, tol, a, b, grad_a, grad_b):
+def test_cg_grad_pynative_tensor(tensor_type, dtype, tol, a, b, grad_a, grad_b):
+    """
+    Feature: ALL TO ALL
+    Description: test cases for grad implementation of cg in pynative mode
+    Expectation: the result match expectation
+    """
+    if tensor_type == "CSRTensor" and get_platform() != "linux":
+        return
+    context.set_context(mode=context.PYNATIVE_MODE)
+
+    a = to_tensor((a, tensor_type), dtype)
+    b = Tensor(onp.array(b, dtype=dtype))
+    expect_grad_a = onp.array(grad_a, dtype=dtype)
+    expect_grad_b = onp.array(grad_b, dtype=dtype)
+
+    # Function
+    grad_net = ops.GradOperation(get_all=True)(msp.sparse.linalg.cg)
+    grad_a, grad_b = grad_net(a, b)[:2]
+    onp.testing.assert_allclose(expect_grad_a, to_ndarray(grad_a), rtol=tol, atol=tol)
+    onp.testing.assert_allclose(expect_grad_b, to_ndarray(grad_b), rtol=tol, atol=tol)
+
+    # Cell
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+            self.sum = ops.ReduceSum()
+            self.cg = msp.sparse.linalg.cg
+
+        def construct(self, a, b):
+            x, _ = self.cg(a, b)
+            return self.sum(x)
+
+    grad_net = ops.GradOperation(get_all=True)(Net())
+    grad_a, grad_b = grad_net(a, b)[:2]
+    onp.testing.assert_allclose(expect_grad_a, to_ndarray(grad_a), rtol=tol, atol=tol)
+    onp.testing.assert_allclose(expect_grad_b, to_ndarray(grad_b), rtol=tol, atol=tol)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('tensor_type, dtype, tol', [('CSRTensor', onp.float32, 1e-5)])
+@pytest.mark.parametrize('a, b, grad_a, grad_b', [
+    ([[1.85910724, 0.73233206, 0.65960803, 1.03821349, 0.55277616],
+      [0.73233206, 1.69548841, 0.59992146, 1.01518264, 0.50824059],
+      [0.65960803, 0.59992146, 1.98169091, 1.45565213, 0.47901749],
+      [1.03821349, 1.01518264, 1.45565213, 3.3133049, 0.75598147],
+      [0.55277616, 0.50824059, 0.47901749, 0.75598147, 1.46831254]],
+     [0.59674531, 0.226012, 0.10694568, 0.22030621, 0.34982629],
+     [[-0.07498642, 0.00167461, 0.01353184, 0.01008293, -0.03770084],
+      [-0.09940184, 0.00221986, 0.01793778, 0.01336592, -0.04997616],
+      [-0.09572781, 0.00213781, 0.01727477, 0.01287189, -0.04812897],
+      [0.03135044, -0.00070012, -0.00565741, -0.00421549, 0.01576203],
+      [-0.14053766, 0.00313851, 0.02536103, 0.01889718, -0.07065797]],
+     [0.23398106, 0.31016481, 0.29870068, -0.09782316, 0.43852141]),
+])
+def test_cg_grad_pynative_csrtensor_data1(tensor_type, dtype, tol, a, b, grad_a, grad_b):
+    """
+    Feature: ALL TO ALL
+    Description: test cases for grad implementation of cg in pynative mode
+    Expectation: the result match expectation
+    """
+    if tensor_type == "CSRTensor" and get_platform() != "linux":
+        return
+    context.set_context(mode=context.PYNATIVE_MODE)
+
+    a = to_tensor((a, tensor_type), dtype)
+    b = Tensor(onp.array(b, dtype=dtype))
+    expect_grad_a = onp.array(grad_a, dtype=dtype)
+    expect_grad_b = onp.array(grad_b, dtype=dtype)
+
+    # Function
+    grad_net = ops.GradOperation(get_all=True)(msp.sparse.linalg.cg)
+    grad_a, grad_b = grad_net(a, b)[:2]
+    onp.testing.assert_allclose(expect_grad_a, to_ndarray(grad_a), rtol=tol, atol=tol)
+    onp.testing.assert_allclose(expect_grad_b, to_ndarray(grad_b), rtol=tol, atol=tol)
+
+    # Cell
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+            self.sum = ops.ReduceSum()
+            self.cg = msp.sparse.linalg.cg
+
+        def construct(self, a, b):
+            x, _ = self.cg(a, b)
+            return self.sum(x)
+
+    grad_net = ops.GradOperation(get_all=True)(Net())
+    grad_a, grad_b = grad_net(a, b)[:2]
+    onp.testing.assert_allclose(expect_grad_a, to_ndarray(grad_a), rtol=tol, atol=tol)
+    onp.testing.assert_allclose(expect_grad_b, to_ndarray(grad_b), rtol=tol, atol=tol)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('tensor_type, dtype, tol', [('CSRTensor', onp.float32, 1e-5)])
+@pytest.mark.parametrize('a, b, grad_a, grad_b', [
+    ([[1.96822833, 0.82204467, 1.03749232, 0.88915326, 0.44986806, 1.11167143],
+      [0.82204467, 2.25216591, 1.40235719, 0.70838919, 0.81377919, 1.06000368],
+      [1.03749232, 1.40235719, 2.90618746, 0.7126087, 0.81029544, 1.28673025],
+      [0.88915326, 0.70838919, 0.7126087, 2.17515263, 0.40443765, 1.02082996],
+      [0.44986806, 0.81377919, 0.81029544, 0.40443765, 1.60570668, 0.62292701],
+      [1.11167143, 1.06000368, 1.28673025, 1.02082996, 0.62292701, 2.30795277]],
+     [0.79363745, 0.58000418, 0.1622986, 0.70075235, 0.96455108, 0.50000836],
+     [[-0.07867674, -0.01521201, 0.06394698, -0.03854052, -0.13523701, 0.01326866],
+      [-0.03508505, -0.00678363, 0.02851647, -0.01718673, -0.06030749, 0.00591702],
+      [-0.00586019, -0.00113306, 0.00476305, -0.00287067, -0.01007304, 0.00098831],
+      [-0.07704304, -0.01489613, 0.06261914, -0.03774023, -0.13242886, 0.01299314],
+      [-0.14497008, -0.02802971, 0.11782896, -0.07101491, -0.24918826, 0.02444888],
+      [-0.01868565, -0.00361284, 0.01518735, -0.00915334, -0.03211867, 0.00315129]],
+     [0.22853142, 0.10191113, 0.01702201, 0.22378603, 0.42109291, 0.054276]),
+])
+def test_cg_grad_pynative_csrtensor_data2(tensor_type, dtype, tol, a, b, grad_a, grad_b):
     """
     Feature: ALL TO ALL
     Description: test cases for grad implementation of cg in pynative mode
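
For reference, below is a minimal standalone sketch (editorial illustration, not part of the patch) of the pattern the new tests exercise: grading mindspore.scipy.sparse.linalg.cg in PyNative mode both as a free function and wrapped in a Cell. The 2x2 system and the CgNet name are placeholders chosen here, not values taken from the patch.

import numpy as onp
import mindspore.nn as nn
import mindspore.ops as ops
import mindspore.scipy as msp
from mindspore import Tensor, context

context.set_context(mode=context.PYNATIVE_MODE)

# A small symmetric positive-definite system so cg converges.
a = Tensor(onp.array([[2.0, 1.0], [1.0, 3.0]], dtype=onp.float32))
b = Tensor(onp.array([1.0, 2.0], dtype=onp.float32))

# Grad of a free function: with this change the PyNative top cell is keyed by
# the function's module/name/file/line rather than a transient object id.
grad_fn = ops.GradOperation(get_all=True)(msp.sparse.linalg.cg)
grad_a, grad_b = grad_fn(a, b)[:2]

# The same computation wrapped in a Cell, as in the original tests.
class CgNet(nn.Cell):
    def __init__(self):
        super(CgNet, self).__init__()
        self.sum = ops.ReduceSum()
        self.cg = msp.sparse.linalg.cg

    def construct(self, a, b):
        x, _ = self.cg(a, b)
        return self.sum(x)

grad_a, grad_b = ops.GradOperation(get_all=True)(CgNet())(a, b)[:2]
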