forked from mindspore-Ecosystem/mindspore
!12703 Add a switch to control grad for scalar
From: @ginfung Reviewed-by: Signed-off-by:
commit 614cf339ad
@@ -34,6 +34,7 @@
 #include "pybind_api/api_register.h"
 #include "ir/signature.h"
 #include "debug/trace.h"
+#include "utils/ms_context.h"
 
 namespace mindspore {
 // namespace to support composite operators definition
@@ -403,7 +404,8 @@ FuncGraphPtr Tail::GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr &
   if (tail_type_ == kGradFirst) {
     if (sequeue->size() > 1 && (*sequeue)[1] != nullptr &&
         ((*sequeue)[1]->isa<abstract::AbstractUndetermined>() ||
-         ((*sequeue)[1]->BuildType() != nullptr && (*sequeue)[1]->BuildType()->isa<Number>()))) {
+         (MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && (*sequeue)[1]->BuildType() != nullptr &&
+          (*sequeue)[1]->BuildType()->isa<Number>()))) {
       ret->set_output(ret->NewCNode({NewValueNode(op), ptrTup, NewValueNode(SizeToLong(1))}));
     } else {
       ret->set_output(NewValueNode(std::make_shared<ValueTuple>(std::vector<ValuePtr>{})));
@@ -416,7 +418,8 @@ FuncGraphPtr Tail::GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr &
     if (tail_type_ == kGradAll) {
       MS_EXCEPTION_IF_NULL((*sequeue)[i]);
       if ((*sequeue)[i]->isa<abstract::AbstractUndetermined>() ||
-          ((*sequeue)[i]->BuildType() != nullptr && (*sequeue)[i]->BuildType()->isa<Number>())) {
+          (MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && (*sequeue)[i]->BuildType() != nullptr &&
+           (*sequeue)[i]->BuildType()->isa<Number>())) {
         elems.push_back(ret->NewCNodeInOrder({NewValueNode(op), ptrTup, NewValueNode(SizeToLong(i))}));
       }
     } else {
@@ -490,7 +490,8 @@ void UpdateFuncGraphParameter(const FuncGraphPtr &func_graph) {
     }
     AbstractBasePtr par_abs = param_node->abstract();
     if (par_abs->isa<abstract::AbstractUndetermined>() ||
-        (par_abs->BuildType() != nullptr && par_abs->BuildType()->isa<Number>())) {
+        (MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && par_abs->BuildType() != nullptr &&
+         par_abs->BuildType()->isa<Number>())) {
       new_paras.push_back(param_node);
     }
   }
@@ -98,7 +98,8 @@ std::string GetBaseNameForIR(int64_t stage_idx, const std::string &action_name)
 
 AbstractBasePtr ArgsToAbstract(const ValuePtr &value) {
   MS_EXCEPTION_IF_NULL(value);
-  bool broaden = value->isa<MetaTensor>() || value->isa<Scalar>();
+  bool broaden = value->isa<MetaTensor>() ||
+                 (MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && value->isa<Scalar>());
   return abstract::FromValue(value, broaden);
 }
 
@@ -95,7 +95,8 @@ REGISTER_PYBIND_DEFINE(MsContextPy, ([](const py::module *m) {
                            .value("variable_memory_max_size", MsCtxParam::MS_CTX_VARIABLE_MEMORY_MAX_SIZE)
                            .value("device_id", MsCtxParam::MS_CTX_DEVICE_ID)
                            .value("max_call_depth", MsCtxParam::MS_CTX_MAX_CALL_DEPTH)
-                           .value("env_config_path", MsCtxParam::MS_CTX_ENV_CONFIG_PATH);
+                           .value("env_config_path", MsCtxParam::MS_CTX_ENV_CONFIG_PATH)
+                           .value("grad_for_scalar", MsCtxParam::MS_CTX_GRAD_FOR_SCALAR);
                          (void)py::class_<mindspore::MsContext, std::shared_ptr<mindspore::MsContext>>(*m, "MSContext")
                            .def_static("get_instance", &mindspore::MsContext::GetInstance, "Get ms context instance.")
                            .def("get_param", &mindspore::MsCtxGetParameter, "Get value of specified parameter.")
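The .value("grad_for_scalar", ...) entry above is what makes the new switch reachable from Python through the context API. A minimal round-trip sketch (assuming the key is registered exactly as in the hunk above; validation of set_context keys lives in the context module changed further down):

from mindspore import context

# Flip the new switch and read it back through the same pybind-backed parameter.
context.set_context(grad_for_scalar=True)
assert context.get_context("grad_for_scalar") is True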
@@ -210,7 +210,9 @@ class _MindSporeFunction:
             return None
         new_inputs = []
         for i in args_list:
-            if isinstance(i, (Tensor, int, float)):
+            if isinstance(i, Tensor):
+                new_inputs.append(i)
+            elif context.get_context("grad_for_scalar") and isinstance(i, (int, float)):
                 new_inputs.append(i)
         return self._executor(tuple(new_inputs), phase)
 
@@ -533,6 +533,7 @@ def set_context(**kwargs):
     save_graphs                 variable_memory_max_size
     save_graphs_path
     env_config_path
+    grad_for_scalar
     =========================== =========================== =================
 
     Args:
@@ -602,6 +603,7 @@ def set_context(**kwargs):
         enable_sparse (bool): Whether to enable sparsity feature. Default: False.
         max_call_depth (int): Specify the maximum depth of function call. Default: 1000.
         env_config_path (str): Config path for DFX.
+        grad_for_scalar (bool): Whether to get gradient for scalar. Default: False.
 
     Raises:
         ValueError: If input key is not an attribute in context.
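With the docstring entry above, the switch becomes part of the public set_context surface. A hedged end-to-end sketch of what it enables (the scale function is hypothetical; GradOperation and graph mode are standard MindSpore APIs assumed to behave as usual):

import mindspore as ms
from mindspore import context
from mindspore.ops import composite as C

context.set_context(mode=context.GRAPH_MODE, grad_for_scalar=True)
grad_all = C.GradOperation(get_all=True)

def scale(x, s):
    # s is a plain Python float; with grad_for_scalar=True it is kept as a
    # differentiable graph input instead of being folded away as a constant.
    return x * s

x = ms.Tensor([1.0, 2.0, 3.0], ms.float32)
grads = grad_all(scale)(x, 3.0)  # expected to also include a gradient for s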
@@ -22,6 +22,7 @@
 
 #include "utils/symbolic.h"
 #include "abstract/utils.h"
+#include "utils/ms_context.h"
 
 namespace mindspore {
 namespace abstract {
@@ -88,7 +89,13 @@ std::string AbstractBase::ToString() const {
   return buffer.str();
 }
 
-AbstractBasePtr AbstractScalar::Broaden(uint8_t config) const { return AbstractBase::Broaden(config); }
+AbstractBasePtr AbstractScalar::Broaden(uint8_t config) const {
+  if (MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR)) {
+    return AbstractBase::Broaden(config);
+  } else {
+    return Clone();
+  }
+}
 
 AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) {
   MS_EXCEPTION_IF_NULL(other);
@@ -171,6 +171,12 @@ AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &p
     return args_spec_list[0];
   }
   auto depends = args_spec_list[0]->Broaden();
+  if (!MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR)) {
+    // For scalar, need to set value to kAnyValue, because broaden scalar will not change the value.
+    if (depends->isa<AbstractScalar>()) {
+      depends->set_value(kAnyValue);
+    }
+  }
   return depends;
 }
 
@@ -74,6 +74,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
   set_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL, false);
   set_param<bool>(MS_CTX_ENABLE_SPARSE, false);
   set_param<bool>(MS_CTX_ENABLE_PARALLEL_SPLIT, false);
+  set_param<bool>(MS_CTX_GRAD_FOR_SCALAR, false);
 
   backend_policy_ = policy_map_[policy];
 }
@@ -76,6 +76,7 @@ enum MsCtxParam : unsigned {
   MS_CTX_SAVE_GRAPHS_FLAG,
   MS_CTX_ENABLE_PARALLEL_SPLIT,
   MS_CTX_ENABLE_INFER_OPT,
+  MS_CTX_GRAD_FOR_SCALAR,
   MS_CTX_TYPE_BOOL_END,
 
   // parameter of type int
@@ -609,7 +609,9 @@ class Cell(Cell_):
 
         new_inputs = []
         for i in inputs:
-            if isinstance(i, (Tensor, int, float)):
+            if isinstance(i, Tensor):
+                new_inputs.append(i)
+            elif context.get_context("grad_for_scalar") and isinstance(i, (int, float)):
                 new_inputs.append(i)
 
         if self._auto_parallel_mode:
@@ -32,26 +32,18 @@ TEST_F(TestUtils, test_join) {
   AbstractBasePtr abs_s1 = FromValue(static_cast<int64_t>(1), false);
   AbstractBasePtr abs_s2 = FromValue(static_cast<int64_t>(2), false);
   AbstractBasePtr abs_s_anything = FromValue(static_cast<int64_t>(2), true);
+  abs_s_anything->set_value(kAnyValue);
 
   AbstractBasePtr res_s1 = abs_s1->Join(abs_s2);
   ASSERT_EQ(*res_s1, *abs_s_anything);
 
-  // AbstractTuple join;
-  std::vector<int64_t> list1 = {1, 2, 3, 4, 5};
-  std::vector<int64_t> list2 = {5, 4, 3, 2, 1};
-  AbstractBasePtr abs_t1 = FromValue(list1, true);
-  AbstractBasePtr abs_t2 = FromValue(list2, true);
-
-  AbstractBasePtr res_t1 = abs_t1->Join(abs_t2);
-  ASSERT_EQ(res_t1, abs_t1);
-
   abs_s1 = FromValue(static_cast<int64_t>(1), false);
 
   AbstractBasePtr t1 = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_s1, abs_s_anything}));
   AbstractBasePtr t2 = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_s1, abs_s_anything}));
   AbstractBasePtr t3 = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_s_anything, abs_s_anything}));
 
-  res_t1 = t1->Join(t2);
+  AbstractBasePtr res_t1 = t1->Join(t2);
   ASSERT_EQ(res_t1, t1);
 
   res_t1 = t1->Join(t3);
@@ -111,8 +111,11 @@ TEST_F(TestOptLib, test_inline) {
   // add infer and renormalize
   std::shared_ptr<mindspore::pipeline::Resource> res = std::make_shared<mindspore::pipeline::Resource>();
   AbstractBasePtrList args_spec_list;
-  AbstractBasePtr abstract_v1 = abstract::FromValue(static_cast<int64_t>(1), true);
-  AbstractBasePtr abstract_v2 = abstract::FromValue(static_cast<int64_t>(2), true);
+  tensor::TensorPtr x_tensor = std::make_shared<tensor::Tensor>(kFloat32->type_id(), std::vector<int64_t>{2, 3});
+  tensor::TensorPtr y_tensor = std::make_shared<tensor::Tensor>(kFloat32->type_id(), std::vector<int64_t>{2, 3});
+
+  AbstractBasePtr abstract_v1 = abstract::FromValue(x_tensor, true);
+  AbstractBasePtr abstract_v2 = abstract::FromValue(y_tensor, true);
   args_spec_list.push_back(abstract_v1);
   args_spec_list.push_back(abstract_v2);
   AnalysisResult result = pipeline::AbstractAnalyze(res, before1, args_spec_list);
@@ -74,20 +74,17 @@ TEST_F(TestData, test_build_value) {
   AbstractBasePtr abs_f2 = FromValue(prim::kPrimScalarAdd, false);
   AbstractBasePtr abs_func_tuple = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_f1, abs_f2}));
   ValuePtr func_tuple_built = abs_func_tuple->BuildValue();
-  ASSERT_EQ(*func_tuple_built,
-            ValueTuple(std::vector<ValuePtr>{prim::kPrimReturn, prim::kPrimScalarAdd}));
+  ASSERT_EQ(*func_tuple_built, ValueTuple(std::vector<ValuePtr>{prim::kPrimReturn, prim::kPrimScalarAdd}));
 
   // BuildValue(List(AbstractFunction)) should return kAnyValue;
   AbstractBasePtr abs_func_list = std::make_shared<AbstractList>(AbstractBasePtrList({abs_f1, abs_f2}));
   ValuePtr func_list_built = abs_func_list->BuildValue();
-  ASSERT_EQ(*func_list_built,
-            ValueList(std::vector<ValuePtr>{prim::kPrimReturn, prim::kPrimScalarAdd}));
+  ASSERT_EQ(*func_list_built, ValueList(std::vector<ValuePtr>{prim::kPrimReturn, prim::kPrimScalarAdd}));
 
   // BuildValue(Tuple(AnyAbstractBase, AbstractFunction)) should return kAnyValue
   abs_func_tuple = std::make_shared<AbstractTuple>(AbstractBasePtrList({base1, abs_f2}));
   func_tuple_built = abs_func_tuple->BuildValue();
-  ASSERT_EQ(*func_tuple_built,
-            ValueTuple(std::vector<ValuePtr>{std::make_shared<Int64Imm>(1), prim::kPrimScalarAdd}));
+  ASSERT_EQ(*func_tuple_built, ValueTuple(std::vector<ValuePtr>{std::make_shared<Int64Imm>(1), prim::kPrimScalarAdd}));
 }
 
 TEST_F(TestData, test_build_type) {
@@ -129,7 +126,7 @@ TEST_F(TestData, test_build_shape) {
   AbstractBasePtr abstract_tup = FromValue(vec, true);
   std::shared_ptr<TupleShape> shape_tuple = dyn_cast<TupleShape>(abstract_tup->BuildShape());
   ASSERT_TRUE(shape_tuple);
-  const std::vector<BaseShapePtr>& ptr_vec = shape_tuple->shape();
+  const std::vector<BaseShapePtr> &ptr_vec = shape_tuple->shape();
   ASSERT_EQ(ptr_vec.size(), 2);
 
   ShapePtr shape1 = dyn_cast<Shape>(ptr_vec[0]);
@@ -148,14 +145,14 @@ TEST_F(TestData, test_clone) {
   ASSERT_TRUE(s1->GetValueTrack() == s2->GetValueTrack());
   ASSERT_TRUE(*s1->GetShapeTrack() == *s2->GetShapeTrack());
 
-  AbstractFunctionPtr f1 = std::make_shared<FuncGraphAbstractClosure>(std::make_shared<FuncGraph>(),
-                                                                      AnalysisContext::DummyContext());
+  AbstractFunctionPtr f1 =
+    std::make_shared<FuncGraphAbstractClosure>(std::make_shared<FuncGraph>(), AnalysisContext::DummyContext());
   AbstractBasePtr f2 = f1->Clone();
   ASSERT_TRUE(*f2 == *f1);
 
   AbstractList l1 = AbstractList({s1, s2});
   AbstractBasePtr l2 = l1.Clone();
-  AbstractList* l2_cast = dynamic_cast<AbstractList*>(l2.get());
+  AbstractList *l2_cast = dynamic_cast<AbstractList *>(l2.get());
   ASSERT_TRUE(l2_cast != nullptr);
   ASSERT_TRUE(l2_cast->GetValueTrack() == l1.GetValueTrack());
 
@@ -184,19 +181,19 @@ TEST_F(TestData, test_broaden) {
   AbstractBasePtr s2 = s1->Broaden();
   ASSERT_TRUE(*s1->GetTypeTrack() == *s2->GetTypeTrack());
   ASSERT_TRUE(*s1->GetValueTrack() == *MakeValue(int1));
-  ASSERT_TRUE(s2->GetValueTrack()->isa<AnyValue>());
+  ASSERT_TRUE(s2->GetValueTrack()->isa<Int64Imm>());
 
-  AbstractFunctionPtr f1 = std::make_shared<FuncGraphAbstractClosure>(std::make_shared<FuncGraph>(),
-                                                                      AnalysisContext::DummyContext());
+  AbstractFunctionPtr f1 =
+    std::make_shared<FuncGraphAbstractClosure>(std::make_shared<FuncGraph>(), AnalysisContext::DummyContext());
   AbstractBasePtr f2 = f1->Broaden();
   ASSERT_TRUE(f2 == f1);
 
   AbstractList l1 = AbstractList({s1, s2});
   AbstractBasePtr l2 = l1.Broaden();
-  AbstractList* l2_cast = dynamic_cast<AbstractList*>(l2.get());
+  AbstractList *l2_cast = dynamic_cast<AbstractList *>(l2.get());
   ASSERT_TRUE(l2_cast != nullptr);
   AbstractBasePtr csr = AbstractJoin(l2_cast->elements());
-  ASSERT_TRUE(csr->GetValueTrack()->isa<AnyValue>());
+  ASSERT_TRUE(csr->GetValueTrack()->isa<Int64Imm>());
 }
 
 } // namespace abstract
@@ -14,7 +14,6 @@
 # ============================================================================
 """ test_framstruct """
 import numpy as np
-import pytest
 import mindspore as ms
 import mindspore.nn as nn
 from mindspore import context
@@ -76,9 +75,7 @@ def dynamic_make_tuple(x, lower, upper):
 
 
 def test_dynamic_make_tuple():
-    # Dynamically recursively creating static type is invalid in mindspore, as mindspore is a static language.
-    with pytest.raises(RuntimeError):
-        dynamic_make_tuple(2, 1, 5)
+    assert dynamic_make_tuple(2, 1, 5) == (2, 2, 2, 2)
 
 
 def test_make_tuple():