diff --git a/mindspore/ccsrc/frontend/operator/composite/composite.cc b/mindspore/ccsrc/frontend/operator/composite/composite.cc
index ef76c4d4985..6c98c437ae8 100644
--- a/mindspore/ccsrc/frontend/operator/composite/composite.cc
+++ b/mindspore/ccsrc/frontend/operator/composite/composite.cc
@@ -34,6 +34,7 @@
 #include "pybind_api/api_register.h"
 #include "ir/signature.h"
 #include "debug/trace.h"
+#include "utils/ms_context.h"
 
 namespace mindspore {
 // namespace to support composite operators definition
@@ -403,7 +404,8 @@ FuncGraphPtr Tail::GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr &
   if (tail_type_ == kGradFirst) {
     if (sequeue->size() > 1 && (*sequeue)[1] != nullptr &&
         ((*sequeue)[1]->isa<abstract::AbstractUndetermined>() ||
-         ((*sequeue)[1]->BuildType() != nullptr && (*sequeue)[1]->BuildType()->isa<Number>()))) {
+         (MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && (*sequeue)[1]->BuildType() != nullptr &&
+          (*sequeue)[1]->BuildType()->isa<Number>()))) {
       ret->set_output(ret->NewCNode({NewValueNode(op), ptrTup, NewValueNode(SizeToLong(1))}));
     } else {
       ret->set_output(NewValueNode(std::make_shared<ValueTuple>(std::vector<ValuePtr>{})));
@@ -416,7 +418,8 @@ FuncGraphPtr Tail::GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr &
     if (tail_type_ == kGradAll) {
       MS_EXCEPTION_IF_NULL((*sequeue)[i]);
       if ((*sequeue)[i]->isa<abstract::AbstractUndetermined>() ||
-          ((*sequeue)[i]->BuildType() != nullptr && (*sequeue)[i]->BuildType()->isa<Number>())) {
+          (MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && (*sequeue)[i]->BuildType() != nullptr &&
+           (*sequeue)[i]->BuildType()->isa<Number>())) {
         elems.push_back(ret->NewCNodeInOrder({NewValueNode(op), ptrTup, NewValueNode(SizeToLong(i))}));
       }
     } else {
diff --git a/mindspore/ccsrc/pipeline/jit/pass.cc b/mindspore/ccsrc/pipeline/jit/pass.cc
index 09401c4b892..69361d6671e 100644
--- a/mindspore/ccsrc/pipeline/jit/pass.cc
+++ b/mindspore/ccsrc/pipeline/jit/pass.cc
@@ -490,7 +490,8 @@ void UpdateFuncGraphParameter(const FuncGraphPtr &func_graph) {
     }
     AbstractBasePtr par_abs = param_node->abstract();
     if (par_abs->isa<abstract::AbstractUndetermined>() ||
-        (par_abs->BuildType() != nullptr && par_abs->BuildType()->isa<Number>())) {
+        (MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && par_abs->BuildType() != nullptr &&
+         par_abs->BuildType()->isa<Number>())) {
       new_paras.push_back(param_node);
     }
   }
diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.cc b/mindspore/ccsrc/pipeline/jit/pipeline.cc
index c82a7e7fe82..ea41db48a31 100644
--- a/mindspore/ccsrc/pipeline/jit/pipeline.cc
+++ b/mindspore/ccsrc/pipeline/jit/pipeline.cc
@@ -98,7 +98,8 @@ std::string GetBaseNameForIR(int64_t stage_idx, const std::string &action_name)
 
 AbstractBasePtr ArgsToAbstract(const ValuePtr &value) {
   MS_EXCEPTION_IF_NULL(value);
-  bool broaden = value->isa<MetaTensor>() || value->isa<Scalar>();
+  bool broaden = value->isa<MetaTensor>() ||
+                 (MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && value->isa<Scalar>());
   return abstract::FromValue(value, broaden);
 }
 
diff --git a/mindspore/ccsrc/pybind_api/utils/ms_context_py.cc b/mindspore/ccsrc/pybind_api/utils/ms_context_py.cc
index 6554db29534..35f696e2433 100644
--- a/mindspore/ccsrc/pybind_api/utils/ms_context_py.cc
+++ b/mindspore/ccsrc/pybind_api/utils/ms_context_py.cc
@@ -95,7 +95,8 @@ REGISTER_PYBIND_DEFINE(MsContextPy, ([](const py::module *m) {
                           .value("variable_memory_max_size", MsCtxParam::MS_CTX_VARIABLE_MEMORY_MAX_SIZE)
                           .value("device_id", MsCtxParam::MS_CTX_DEVICE_ID)
                           .value("max_call_depth", MsCtxParam::MS_CTX_MAX_CALL_DEPTH)
-                          .value("env_config_path", MsCtxParam::MS_CTX_ENV_CONFIG_PATH);
+                          .value("env_config_path", MsCtxParam::MS_CTX_ENV_CONFIG_PATH)
.value("grad_for_scalar", MsCtxParam::MS_CTX_GRAD_FOR_SCALAR); (void)py::class_>(*m, "MSContext") .def_static("get_instance", &mindspore::MsContext::GetInstance, "Get ms context instance.") .def("get_param", &mindspore::MsCtxGetParameter, "Get value of specified parameter.") diff --git a/mindspore/common/api.py b/mindspore/common/api.py index 257ee65e195..d6ba6e90b90 100644 --- a/mindspore/common/api.py +++ b/mindspore/common/api.py @@ -210,7 +210,9 @@ class _MindSporeFunction: return None new_inputs = [] for i in args_list: - if isinstance(i, (Tensor, int, float)): + if isinstance(i, Tensor): + new_inputs.append(i) + elif context.get_context("grad_for_scalar") and isinstance(i, (int, float)): new_inputs.append(i) return self._executor(tuple(new_inputs), phase) diff --git a/mindspore/context.py b/mindspore/context.py index 4ea8c23373a..e15a8f4cfcd 100644 --- a/mindspore/context.py +++ b/mindspore/context.py @@ -533,6 +533,7 @@ def set_context(**kwargs): save_graphs variable_memory_max_size save_graphs_path env_config_path + grad_for_scalar =========================== =========================== ================= Args: @@ -602,6 +603,7 @@ def set_context(**kwargs): enable_sparse (bool): Whether to enable sparsity feature. Default: False. max_call_depth (int): Specify the maximum depth of function call. Default: 1000. env_config_path (str): Config path for DFX. + grad_for_scalar (bool): Whether to get gradient for scalar. Default: False. Raises: ValueError: If input key is not an attribute in context. diff --git a/mindspore/core/abstract/abstract_value.cc b/mindspore/core/abstract/abstract_value.cc index 8dd3e420658..7c8665a6ea9 100644 --- a/mindspore/core/abstract/abstract_value.cc +++ b/mindspore/core/abstract/abstract_value.cc @@ -22,6 +22,7 @@ #include "utils/symbolic.h" #include "abstract/utils.h" +#include "utils/ms_context.h" namespace mindspore { namespace abstract { @@ -88,7 +89,13 @@ std::string AbstractBase::ToString() const { return buffer.str(); } -AbstractBasePtr AbstractScalar::Broaden(uint8_t config) const { return AbstractBase::Broaden(config); } +AbstractBasePtr AbstractScalar::Broaden(uint8_t config) const { + if (MsContext::GetInstance()->get_param(MS_CTX_GRAD_FOR_SCALAR)) { + return AbstractBase::Broaden(config); + } else { + return Clone(); + } +} AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) { MS_EXCEPTION_IF_NULL(other); diff --git a/mindspore/core/abstract/prim_others.cc b/mindspore/core/abstract/prim_others.cc index c8af8416be3..7a15d3a0674 100644 --- a/mindspore/core/abstract/prim_others.cc +++ b/mindspore/core/abstract/prim_others.cc @@ -171,6 +171,12 @@ AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &p return args_spec_list[0]; } auto depends = args_spec_list[0]->Broaden(); + if (!MsContext::GetInstance()->get_param(MS_CTX_GRAD_FOR_SCALAR)) { + // For scalar, need to set value to kAnyValue, because broaden scalar will not change the value. 
+    if (depends->isa<AbstractScalar>()) {
+      depends->set_value(kAnyValue);
+    }
+  }
   return depends;
 }
 
diff --git a/mindspore/core/utils/ms_context.cc b/mindspore/core/utils/ms_context.cc
index 5ea13730937..54c6dba8378 100644
--- a/mindspore/core/utils/ms_context.cc
+++ b/mindspore/core/utils/ms_context.cc
@@ -74,6 +74,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
   set_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL, false);
   set_param<bool>(MS_CTX_ENABLE_SPARSE, false);
   set_param<bool>(MS_CTX_ENABLE_PARALLEL_SPLIT, false);
+  set_param<bool>(MS_CTX_GRAD_FOR_SCALAR, false);
 
   backend_policy_ = policy_map_[policy];
 }
diff --git a/mindspore/core/utils/ms_context.h b/mindspore/core/utils/ms_context.h
index 683394efea1..11c81992fbb 100644
--- a/mindspore/core/utils/ms_context.h
+++ b/mindspore/core/utils/ms_context.h
@@ -76,6 +76,7 @@ enum MsCtxParam : unsigned {
   MS_CTX_SAVE_GRAPHS_FLAG,
   MS_CTX_ENABLE_PARALLEL_SPLIT,
   MS_CTX_ENABLE_INFER_OPT,
+  MS_CTX_GRAD_FOR_SCALAR,
   MS_CTX_TYPE_BOOL_END,
 
   // parameter of type int
diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py
index 68e73152115..817b0091d3a 100755
--- a/mindspore/nn/cell.py
+++ b/mindspore/nn/cell.py
@@ -609,7 +609,9 @@ class Cell(Cell_):
 
         new_inputs = []
         for i in inputs:
-            if isinstance(i, (Tensor, int, float)):
+            if isinstance(i, Tensor):
+                new_inputs.append(i)
+            elif context.get_context("grad_for_scalar") and isinstance(i, (int, float)):
                 new_inputs.append(i)
 
         if self._auto_parallel_mode:
diff --git a/tests/ut/cpp/abstract/utils_test.cc b/tests/ut/cpp/abstract/utils_test.cc
index ea954c0641f..ff44c1c0409 100644
--- a/tests/ut/cpp/abstract/utils_test.cc
+++ b/tests/ut/cpp/abstract/utils_test.cc
@@ -32,26 +32,18 @@ TEST_F(TestUtils, test_join) {
   AbstractBasePtr abs_s1 = FromValue(static_cast<int64_t>(1), false);
   AbstractBasePtr abs_s2 = FromValue(static_cast<int64_t>(2), false);
   AbstractBasePtr abs_s_anything = FromValue(static_cast<int64_t>(2), true);
+  abs_s_anything->set_value(kAnyValue);
 
   AbstractBasePtr res_s1 = abs_s1->Join(abs_s2);
   ASSERT_EQ(*res_s1, *abs_s_anything);
 
-  // AbstractTuple join;
-  std::vector<int64_t> list1 = {1, 2, 3, 4, 5};
-  std::vector<int64_t> list2 = {5, 4, 3, 2, 1};
-  AbstractBasePtr abs_t1 = FromValue(list1, true);
-  AbstractBasePtr abs_t2 = FromValue(list2, true);
-
-  AbstractBasePtr res_t1 = abs_t1->Join(abs_t2);
-  ASSERT_EQ(res_t1, abs_t1);
-
   abs_s1 = FromValue(static_cast<int64_t>(1), false);
 
   AbstractBasePtr t1 = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_s1, abs_s_anything}));
   AbstractBasePtr t2 = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_s1, abs_s_anything}));
   AbstractBasePtr t3 = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_s_anything, abs_s_anything}));
 
-  res_t1 = t1->Join(t2);
+  AbstractBasePtr res_t1 = t1->Join(t2);
   ASSERT_EQ(res_t1, t1);
 
   res_t1 = t1->Join(t3);
diff --git a/tests/ut/cpp/optimizer/lib_test.cc b/tests/ut/cpp/optimizer/lib_test.cc
index 3c794f97a8d..38881362d5a 100644
--- a/tests/ut/cpp/optimizer/lib_test.cc
+++ b/tests/ut/cpp/optimizer/lib_test.cc
@@ -111,8 +111,11 @@ TEST_F(TestOptLib, test_inline) {
   // add infer and renormalize
   std::shared_ptr<pipeline::Resource> res = std::make_shared<pipeline::Resource>();
   AbstractBasePtrList args_spec_list;
-  AbstractBasePtr abstract_v1 = abstract::FromValue(static_cast<int64_t>(1), true);
-  AbstractBasePtr abstract_v2 = abstract::FromValue(static_cast<int64_t>(2), true);
+  tensor::TensorPtr x_tensor = std::make_shared<tensor::Tensor>(kFloat32->type_id(), std::vector<int64_t>{2, 3});
+  tensor::TensorPtr y_tensor = std::make_shared<tensor::Tensor>(kFloat32->type_id(), std::vector<int64_t>{2, 3});
+
+  AbstractBasePtr abstract_v1 = abstract::FromValue(x_tensor, true);
+  AbstractBasePtr abstract_v2 = abstract::FromValue(y_tensor, true);
   args_spec_list.push_back(abstract_v1);
   args_spec_list.push_back(abstract_v2);
   AnalysisResult result = pipeline::AbstractAnalyze(res, before1, args_spec_list);
diff --git a/tests/ut/cpp/pipeline/static_analysis/data_test.cc b/tests/ut/cpp/pipeline/static_analysis/data_test.cc
index 5c333ed52f9..a163b9db147 100644
--- a/tests/ut/cpp/pipeline/static_analysis/data_test.cc
+++ b/tests/ut/cpp/pipeline/static_analysis/data_test.cc
@@ -74,20 +74,17 @@ TEST_F(TestData, test_build_value) {
   AbstractBasePtr abs_f2 = FromValue(prim::kPrimScalarAdd, false);
   AbstractBasePtr abs_func_tuple = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_f1, abs_f2}));
   ValuePtr func_tuple_built = abs_func_tuple->BuildValue();
-  ASSERT_EQ(*func_tuple_built,
-            ValueTuple(std::vector<ValuePtr>{prim::kPrimReturn, prim::kPrimScalarAdd}));
+  ASSERT_EQ(*func_tuple_built, ValueTuple(std::vector<ValuePtr>{prim::kPrimReturn, prim::kPrimScalarAdd}));
 
   // BuildValue(List(AbstractFunction)) should return kAnyValue;
   AbstractBasePtr abs_func_list = std::make_shared<AbstractList>(AbstractBasePtrList({abs_f1, abs_f2}));
   ValuePtr func_list_built = abs_func_list->BuildValue();
-  ASSERT_EQ(*func_list_built,
-            ValueList(std::vector<ValuePtr>{prim::kPrimReturn, prim::kPrimScalarAdd}));
+  ASSERT_EQ(*func_list_built, ValueList(std::vector<ValuePtr>{prim::kPrimReturn, prim::kPrimScalarAdd}));
 
   // BuildValue(Tuple(AnyAbstractBase, AbstractFunction)) should return kAnyValue
   abs_func_tuple = std::make_shared<AbstractTuple>(AbstractBasePtrList({base1, abs_f2}));
   func_tuple_built = abs_func_tuple->BuildValue();
-  ASSERT_EQ(*func_tuple_built,
-            ValueTuple(std::vector<ValuePtr>{std::make_shared<Int64Imm>(1), prim::kPrimScalarAdd}));
+  ASSERT_EQ(*func_tuple_built, ValueTuple(std::vector<ValuePtr>{std::make_shared<Int64Imm>(1), prim::kPrimScalarAdd}));
 }
 
 TEST_F(TestData, test_build_type) {
@@ -129,7 +126,7 @@ TEST_F(TestData, test_build_shape) {
   AbstractBasePtr abstract_tup = FromValue(vec, true);
   std::shared_ptr<SequeueShape> shape_tuple = dyn_cast<SequeueShape>(abstract_tup->BuildShape());
   ASSERT_TRUE(shape_tuple);
-  const std::vector<BaseShapePtr>& ptr_vec = shape_tuple->shape();
+  const std::vector<BaseShapePtr> &ptr_vec = shape_tuple->shape();
   ASSERT_EQ(ptr_vec.size(), 2);
 
   ShapePtr shape1 = dyn_cast<Shape>(ptr_vec[0]);
@@ -148,14 +145,14 @@ TEST_F(TestData, test_clone) {
   ASSERT_TRUE(s1->GetValueTrack() == s2->GetValueTrack());
   ASSERT_TRUE(*s1->GetShapeTrack() == *s2->GetShapeTrack());
 
-  AbstractFunctionPtr f1 = std::make_shared<FuncGraphAbstractClosure>(std::make_shared<FuncGraph>(),
-                                                                      AnalysisContext::DummyContext());
+  AbstractFunctionPtr f1 =
+    std::make_shared<FuncGraphAbstractClosure>(std::make_shared<FuncGraph>(), AnalysisContext::DummyContext());
   AbstractBasePtr f2 = f1->Clone();
   ASSERT_TRUE(*f2 == *f1);
 
   AbstractList l1 = AbstractList({s1, s2});
   AbstractBasePtr l2 = l1.Clone();
-  AbstractList* l2_cast = dynamic_cast<AbstractList*>(l2.get());
+  AbstractList *l2_cast = dynamic_cast<AbstractList *>(l2.get());
   ASSERT_TRUE(l2_cast != nullptr);
   ASSERT_TRUE(l2_cast->GetValueTrack() == l1.GetValueTrack());
 
@@ -184,19 +181,19 @@ TEST_F(TestData, test_broaden) {
   AbstractBasePtr s2 = s1->Broaden();
   ASSERT_TRUE(*s1->GetTypeTrack() == *s2->GetTypeTrack());
   ASSERT_TRUE(*s1->GetValueTrack() == *MakeValue(int1));
-  ASSERT_TRUE(s2->GetValueTrack()->isa<AnyValue>());
+  ASSERT_TRUE(s2->GetValueTrack()->isa<Int64Imm>());
 
-  AbstractFunctionPtr f1 = std::make_shared<FuncGraphAbstractClosure>(std::make_shared<FuncGraph>(),
-                                                                      AnalysisContext::DummyContext());
+  AbstractFunctionPtr f1 =
+    std::make_shared<FuncGraphAbstractClosure>(std::make_shared<FuncGraph>(), AnalysisContext::DummyContext());
   AbstractBasePtr f2 = f1->Broaden();
   ASSERT_TRUE(f2 == f1);
 
   AbstractList l1 = AbstractList({s1, s2});
   AbstractBasePtr l2 = l1.Broaden();
-  AbstractList* l2_cast = dynamic_cast<AbstractList*>(l2.get());
+  AbstractList *l2_cast = dynamic_cast<AbstractList *>(l2.get());
   ASSERT_TRUE(l2_cast != nullptr);
   AbstractBasePtr csr = AbstractJoin(l2_cast->elements());
-  ASSERT_TRUE(csr->GetValueTrack()->isa<AnyValue>());
+  ASSERT_TRUE(csr->GetValueTrack()->isa<Int64Imm>());
 }
 
 }  // namespace abstract
diff --git a/tests/ut/python/pynative_mode/test_framstruct.py b/tests/ut/python/pynative_mode/test_framstruct.py
index 5140960451c..97166546410 100644
--- a/tests/ut/python/pynative_mode/test_framstruct.py
+++ b/tests/ut/python/pynative_mode/test_framstruct.py
@@ -14,7 +14,6 @@
 # ============================================================================
 """ test_framstruct """
 import numpy as np
-import pytest
 import mindspore as ms
 import mindspore.nn as nn
 from mindspore import context
@@ -76,9 +75,7 @@ def dynamic_make_tuple(x, lower, upper):
 
 
 def test_dynamic_make_tuple():
-    # Dynamically recursively creating static type is invalid in mindspore, as mindspore is a static language.
-    with pytest.raises(RuntimeError):
-        dynamic_make_tuple(2, 1, 5)
+    assert dynamic_make_tuple(2, 1, 5) == (2, 2, 2, 2)
 
 
 def test_make_tuple():
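Usage sketch (not part of the patch): with this change, scalar (int/float) network inputs are
treated as constants unless the new `grad_for_scalar` context flag is enabled. The snippet below
is a minimal illustration under that assumption; `scalar_fn` is a hypothetical function, and the
exact GradOperation call pattern and returned types may vary between MindSpore versions.

    import mindspore.ops as ops
    from mindspore import context

    # Enable gradients with respect to scalar inputs (default is False).
    context.set_context(grad_for_scalar=True)

    def scalar_fn(x, y):
        return x * y + y

    # GradOperation(get_all=True) returns gradients for all inputs.
    grad_all = ops.GradOperation(get_all=True)
    grads = grad_all(scalar_fn)(2.0, 3.0)
    # Expected values: d(x*y + y)/dx = y = 3.0, d(x*y + y)/dy = x + 1 = 3.0.
    print(grads)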