!12703 Add a switch to control grad for scalar

From: @ginfung
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-02-27 20:53:42 +08:00 committed by Gitee
commit 614cf339ad
15 changed files with 55 additions and 39 deletions

View File

@@ -34,6 +34,7 @@
#include "pybind_api/api_register.h"
#include "ir/signature.h"
#include "debug/trace.h"
#include "utils/ms_context.h"
namespace mindspore {
// namespace to support composite operators definition
@@ -403,7 +404,8 @@ FuncGraphPtr Tail::GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr &
if (tail_type_ == kGradFirst) {
if (sequeue->size() > 1 && (*sequeue)[1] != nullptr &&
((*sequeue)[1]->isa<abstract::AbstractUndetermined>() ||
((*sequeue)[1]->BuildType() != nullptr && (*sequeue)[1]->BuildType()->isa<Number>()))) {
(MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && (*sequeue)[1]->BuildType() != nullptr &&
(*sequeue)[1]->BuildType()->isa<Number>()))) {
ret->set_output(ret->NewCNode({NewValueNode(op), ptrTup, NewValueNode(SizeToLong(1))}));
} else {
ret->set_output(NewValueNode(std::make_shared<ValueTuple>(std::vector<ValuePtr>{})));
@@ -416,7 +418,8 @@ FuncGraphPtr Tail::GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr &
if (tail_type_ == kGradAll) {
MS_EXCEPTION_IF_NULL((*sequeue)[i]);
if ((*sequeue)[i]->isa<abstract::AbstractUndetermined>() ||
((*sequeue)[i]->BuildType() != nullptr && (*sequeue)[i]->BuildType()->isa<Number>())) {
(MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && (*sequeue)[i]->BuildType() != nullptr &&
(*sequeue)[i]->BuildType()->isa<Number>())) {
elems.push_back(ret->NewCNodeInOrder({NewValueNode(op), ptrTup, NewValueNode(SizeToLong(i))}));
}
} else {

View File

@@ -490,7 +490,8 @@ void UpdateFuncGraphParameter(const FuncGraphPtr &func_graph) {
}
AbstractBasePtr par_abs = param_node->abstract();
if (par_abs->isa<abstract::AbstractUndetermined>() ||
(par_abs->BuildType() != nullptr && par_abs->BuildType()->isa<Number>())) {
(MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && par_abs->BuildType() != nullptr &&
par_abs->BuildType()->isa<Number>())) {
new_paras.push_back(param_node);
}
}

View File

@@ -98,7 +98,8 @@ std::string GetBaseNameForIR(int64_t stage_idx, const std::string &action_name)
AbstractBasePtr ArgsToAbstract(const ValuePtr &value) {
MS_EXCEPTION_IF_NULL(value);
bool broaden = value->isa<MetaTensor>() || value->isa<Scalar>();
bool broaden = value->isa<MetaTensor>() ||
(MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && value->isa<Scalar>());
return abstract::FromValue(value, broaden);
}

View File

@@ -95,7 +95,8 @@ REGISTER_PYBIND_DEFINE(MsContextPy, ([](const py::module *m) {
.value("variable_memory_max_size", MsCtxParam::MS_CTX_VARIABLE_MEMORY_MAX_SIZE)
.value("device_id", MsCtxParam::MS_CTX_DEVICE_ID)
.value("max_call_depth", MsCtxParam::MS_CTX_MAX_CALL_DEPTH)
.value("env_config_path", MsCtxParam::MS_CTX_ENV_CONFIG_PATH);
.value("env_config_path", MsCtxParam::MS_CTX_ENV_CONFIG_PATH)
.value("grad_for_scalar", MsCtxParam::MS_CTX_GRAD_FOR_SCALAR);
(void)py::class_<mindspore::MsContext, std::shared_ptr<mindspore::MsContext>>(*m, "MSContext")
.def_static("get_instance", &mindspore::MsContext::GetInstance, "Get ms context instance.")
.def("get_param", &mindspore::MsCtxGetParameter, "Get value of specified parameter.")

View File

@@ -210,7 +210,9 @@ class _MindSporeFunction:
return None
new_inputs = []
for i in args_list:
if isinstance(i, (Tensor, int, float)):
if isinstance(i, Tensor):
new_inputs.append(i)
elif context.get_context("grad_for_scalar") and isinstance(i, (int, float)):
new_inputs.append(i)
return self._executor(tuple(new_inputs), phase)

View File

@@ -533,6 +533,7 @@ def set_context(**kwargs):
save_graphs variable_memory_max_size
save_graphs_path
env_config_path
grad_for_scalar
=========================== =========================== =================
Args:
@@ -602,6 +603,7 @@ def set_context(**kwargs):
enable_sparse (bool): Whether to enable sparsity feature. Default: False.
max_call_depth (int): Specify the maximum depth of function call. Default: 1000.
env_config_path (str): Config path for DFX.
grad_for_scalar (bool): Whether to get gradient for scalar. Default: False.
Raises:
ValueError: If input key is not an attribute in context.

View File

@@ -22,6 +22,7 @@
#include "utils/symbolic.h"
#include "abstract/utils.h"
#include "utils/ms_context.h"
namespace mindspore {
namespace abstract {
@@ -88,7 +89,13 @@ std::string AbstractBase::ToString() const {
return buffer.str();
}
AbstractBasePtr AbstractScalar::Broaden(uint8_t config) const { return AbstractBase::Broaden(config); }
AbstractBasePtr AbstractScalar::Broaden(uint8_t config) const {
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR)) {
return AbstractBase::Broaden(config);
} else {
return Clone();
}
}
AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) {
MS_EXCEPTION_IF_NULL(other);

View File

@@ -171,6 +171,12 @@ AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &p
return args_spec_list[0];
}
auto depends = args_spec_list[0]->Broaden();
if (!MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR)) {
// For scalar, need to set value to kAnyValue, because broaden scalar will not change the value.
if (depends->isa<AbstractScalar>()) {
depends->set_value(kAnyValue);
}
}
return depends;
}

View File

@@ -74,6 +74,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
set_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL, false);
set_param<bool>(MS_CTX_ENABLE_SPARSE, false);
set_param<bool>(MS_CTX_ENABLE_PARALLEL_SPLIT, false);
set_param<bool>(MS_CTX_GRAD_FOR_SCALAR, false);
backend_policy_ = policy_map_[policy];
}

View File

@@ -76,6 +76,7 @@ enum MsCtxParam : unsigned {
MS_CTX_SAVE_GRAPHS_FLAG,
MS_CTX_ENABLE_PARALLEL_SPLIT,
MS_CTX_ENABLE_INFER_OPT,
MS_CTX_GRAD_FOR_SCALAR,
MS_CTX_TYPE_BOOL_END,
// parameter of type int

View File

@@ -609,7 +609,9 @@ class Cell(Cell_):
new_inputs = []
for i in inputs:
if isinstance(i, (Tensor, int, float)):
if isinstance(i, Tensor):
new_inputs.append(i)
elif context.get_context("grad_for_scalar") and isinstance(i, (int, float)):
new_inputs.append(i)
if self._auto_parallel_mode:

View File

@@ -32,26 +32,18 @@ TEST_F(TestUtils, test_join) {
AbstractBasePtr abs_s1 = FromValue(static_cast<int64_t>(1), false);
AbstractBasePtr abs_s2 = FromValue(static_cast<int64_t>(2), false);
AbstractBasePtr abs_s_anything = FromValue(static_cast<int64_t>(2), true);
abs_s_anything->set_value(kAnyValue);
AbstractBasePtr res_s1 = abs_s1->Join(abs_s2);
ASSERT_EQ(*res_s1, *abs_s_anything);
// AbstractTuple join;
std::vector<int64_t> list1 = {1, 2, 3, 4, 5};
std::vector<int64_t> list2 = {5, 4, 3, 2, 1};
AbstractBasePtr abs_t1 = FromValue(list1, true);
AbstractBasePtr abs_t2 = FromValue(list2, true);
AbstractBasePtr res_t1 = abs_t1->Join(abs_t2);
ASSERT_EQ(res_t1, abs_t1);
abs_s1 = FromValue(static_cast<int64_t>(1), false);
AbstractBasePtr t1 = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_s1, abs_s_anything}));
AbstractBasePtr t2 = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_s1, abs_s_anything}));
AbstractBasePtr t3 = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_s_anything, abs_s_anything}));
res_t1 = t1->Join(t2);
AbstractBasePtr res_t1 = t1->Join(t2);
ASSERT_EQ(res_t1, t1);
res_t1 = t1->Join(t3);

View File

@@ -111,8 +111,11 @@ TEST_F(TestOptLib, test_inline) {
// add infer and renormalize
std::shared_ptr<mindspore::pipeline::Resource> res = std::make_shared<mindspore::pipeline::Resource>();
AbstractBasePtrList args_spec_list;
AbstractBasePtr abstract_v1 = abstract::FromValue(static_cast<int64_t>(1), true);
AbstractBasePtr abstract_v2 = abstract::FromValue(static_cast<int64_t>(2), true);
tensor::TensorPtr x_tensor = std::make_shared<tensor::Tensor>(kFloat32->type_id(), std::vector<int64_t>{2, 3});
tensor::TensorPtr y_tensor = std::make_shared<tensor::Tensor>(kFloat32->type_id(), std::vector<int64_t>{2, 3});
AbstractBasePtr abstract_v1 = abstract::FromValue(x_tensor, true);
AbstractBasePtr abstract_v2 = abstract::FromValue(y_tensor, true);
args_spec_list.push_back(abstract_v1);
args_spec_list.push_back(abstract_v2);
AnalysisResult result = pipeline::AbstractAnalyze(res, before1, args_spec_list);

View File

@@ -74,20 +74,17 @@ TEST_F(TestData, test_build_value) {
AbstractBasePtr abs_f2 = FromValue(prim::kPrimScalarAdd, false);
AbstractBasePtr abs_func_tuple = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_f1, abs_f2}));
ValuePtr func_tuple_built = abs_func_tuple->BuildValue();
ASSERT_EQ(*func_tuple_built,
ValueTuple(std::vector<ValuePtr>{prim::kPrimReturn, prim::kPrimScalarAdd}));
ASSERT_EQ(*func_tuple_built, ValueTuple(std::vector<ValuePtr>{prim::kPrimReturn, prim::kPrimScalarAdd}));
// BuildValue(List(AbstractFunction)) should return kAnyValue;
AbstractBasePtr abs_func_list = std::make_shared<AbstractList>(AbstractBasePtrList({abs_f1, abs_f2}));
ValuePtr func_list_built = abs_func_list->BuildValue();
ASSERT_EQ(*func_list_built,
ValueList(std::vector<ValuePtr>{prim::kPrimReturn, prim::kPrimScalarAdd}));
ASSERT_EQ(*func_list_built, ValueList(std::vector<ValuePtr>{prim::kPrimReturn, prim::kPrimScalarAdd}));
// BuildValue(Tuple(AnyAbstractBase, AbstractFunction)) should return kAnyValue
abs_func_tuple = std::make_shared<AbstractTuple>(AbstractBasePtrList({base1, abs_f2}));
func_tuple_built = abs_func_tuple->BuildValue();
ASSERT_EQ(*func_tuple_built,
ValueTuple(std::vector<ValuePtr>{std::make_shared<Int64Imm>(1), prim::kPrimScalarAdd}));
ASSERT_EQ(*func_tuple_built, ValueTuple(std::vector<ValuePtr>{std::make_shared<Int64Imm>(1), prim::kPrimScalarAdd}));
}
TEST_F(TestData, test_build_type) {
@@ -129,7 +126,7 @@ TEST_F(TestData, test_build_shape) {
AbstractBasePtr abstract_tup = FromValue(vec, true);
std::shared_ptr<TupleShape> shape_tuple = dyn_cast<TupleShape>(abstract_tup->BuildShape());
ASSERT_TRUE(shape_tuple);
const std::vector<BaseShapePtr>& ptr_vec = shape_tuple->shape();
const std::vector<BaseShapePtr> &ptr_vec = shape_tuple->shape();
ASSERT_EQ(ptr_vec.size(), 2);
ShapePtr shape1 = dyn_cast<Shape>(ptr_vec[0]);
@@ -148,14 +145,14 @@ TEST_F(TestData, test_clone) {
ASSERT_TRUE(s1->GetValueTrack() == s2->GetValueTrack());
ASSERT_TRUE(*s1->GetShapeTrack() == *s2->GetShapeTrack());
AbstractFunctionPtr f1 = std::make_shared<FuncGraphAbstractClosure>(std::make_shared<FuncGraph>(),
AnalysisContext::DummyContext());
AbstractFunctionPtr f1 =
std::make_shared<FuncGraphAbstractClosure>(std::make_shared<FuncGraph>(), AnalysisContext::DummyContext());
AbstractBasePtr f2 = f1->Clone();
ASSERT_TRUE(*f2 == *f1);
AbstractList l1 = AbstractList({s1, s2});
AbstractBasePtr l2 = l1.Clone();
AbstractList* l2_cast = dynamic_cast<AbstractList*>(l2.get());
AbstractList *l2_cast = dynamic_cast<AbstractList *>(l2.get());
ASSERT_TRUE(l2_cast != nullptr);
ASSERT_TRUE(l2_cast->GetValueTrack() == l1.GetValueTrack());
@@ -184,19 +181,19 @@ TEST_F(TestData, test_broaden) {
AbstractBasePtr s2 = s1->Broaden();
ASSERT_TRUE(*s1->GetTypeTrack() == *s2->GetTypeTrack());
ASSERT_TRUE(*s1->GetValueTrack() == *MakeValue(int1));
ASSERT_TRUE(s2->GetValueTrack()->isa<AnyValue>());
ASSERT_TRUE(s2->GetValueTrack()->isa<Int64Imm>());
AbstractFunctionPtr f1 = std::make_shared<FuncGraphAbstractClosure>(std::make_shared<FuncGraph>(),
AnalysisContext::DummyContext());
AbstractFunctionPtr f1 =
std::make_shared<FuncGraphAbstractClosure>(std::make_shared<FuncGraph>(), AnalysisContext::DummyContext());
AbstractBasePtr f2 = f1->Broaden();
ASSERT_TRUE(f2 == f1);
AbstractList l1 = AbstractList({s1, s2});
AbstractBasePtr l2 = l1.Broaden();
AbstractList* l2_cast = dynamic_cast<AbstractList*>(l2.get());
AbstractList *l2_cast = dynamic_cast<AbstractList *>(l2.get());
ASSERT_TRUE(l2_cast != nullptr);
AbstractBasePtr csr = AbstractJoin(l2_cast->elements());
ASSERT_TRUE(csr->GetValueTrack()->isa<AnyValue>());
ASSERT_TRUE(csr->GetValueTrack()->isa<Int64Imm>());
}
} // namespace abstract

View File

@@ -14,7 +14,6 @@
# ============================================================================
""" test_framstruct """
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
from mindspore import context
@@ -76,9 +75,7 @@ def dynamic_make_tuple(x, lower, upper):
def test_dynamic_make_tuple():
# Dynamically recursively creating static type is invalid in mindspore, as mindspore is a static language.
with pytest.raises(RuntimeError):
dynamic_make_tuple(2, 1, 5)
assert dynamic_make_tuple(2, 1, 5) == (2, 2, 2, 2)
def test_make_tuple():