!32373 Add class Variable to make constants mutable

Merge pull request !32373 from YuJianfeng/mutable
i-robot 2022-04-02 07:35:02 +00:00 committed by Gitee
commit b12fefbded
17 changed files with 1152 additions and 67 deletions
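
This merge makes non-Tensor positional inputs (scalar, tuple, list, dict) differentiable once they are wrapped in the new Variable class. A minimal usage sketch, distilled from the docstring and tests added below (Net and GradNet are illustrative names, not part of this change):

import mindspore.nn as nn
from mindspore import Variable
from mindspore.ops.composite import GradOperation

class Net(nn.Cell):
    def construct(self, x, y):
        # x arrives as a plain Python scalar; wrapping it in Variable at the
        # call site marks it as a mutable, differentiable input.
        return x * y

class GradNet(nn.Cell):
    def __init__(self, net):
        super(GradNet, self).__init__()
        self.net = net
        self.grad_op = GradOperation()

    def construct(self, x, y):
        return self.grad_op(self.net)(x, y)

x = Variable(2)
print(GradNet(Net())(x, 3))  # prints 3, the gradient of x * y with respect to x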

View File

@ -802,6 +802,38 @@ void FlatValueTupleValue(const ValuePtrList &value, ValuePtrList *flatted_value)
}
}
void FlattenValue(const BaseRef &arg, ValuePtrList *flatted_value) {
if (utils::isa<ValueSequencePtr>(arg)) {
auto value_sequence = utils::cast<ValueSequencePtr>(arg);
MS_EXCEPTION_IF_NULL(value_sequence);
auto sequence_value = value_sequence->value();
for (auto &value : sequence_value) {
MS_EXCEPTION_IF_NULL(value);
if (value->isa<tensor::Tensor>()) {
(void)flatted_value->emplace_back(value);
} else {
FlattenValue(value, flatted_value);
}
}
} else if (utils::isa<ValueDictionaryPtr>(arg)) {
auto value_dict = utils::cast<ValueDictionaryPtr>(arg);
MS_EXCEPTION_IF_NULL(value_dict);
auto dict_value = value_dict->value();
for (auto &iter : dict_value) {
auto value = iter.second;
MS_EXCEPTION_IF_NULL(value);
if (value->isa<tensor::Tensor>()) {
(void)flatted_value->emplace_back(value);
} else {
FlattenValue(value, flatted_value);
}
}
} else {
MS_LOG(EXCEPTION) << "The value input to flatten should only contains be sequence or dictionary, but it is "
<< arg.ToString();
}
}
void PushTupleTensor(const VectorRef &args, const std::vector<AnfNodePtr> &parameters, const AnfNodePtr &front_node,
size_t index, std::vector<tensor::TensorPtr> *input_tensor) {
const auto &iter = std::find(parameters.begin(), parameters.end(), front_node);
@ -814,11 +846,8 @@ void PushTupleTensor(const VectorRef &args, const std::vector<AnfNodePtr> &param
(void)input_tensor->emplace_back(nullptr);
return;
}
-auto value_tuple = utils::cast<ValueTuplePtr>(args[position]);
-MS_EXCEPTION_IF_NULL(value_tuple);
-auto value_tuple_value = value_tuple->value();
ValuePtrList flatted_value_tuple_value;
-FlatValueTupleValue(value_tuple_value, &flatted_value_tuple_value);
+FlattenValue(args[position], &flatted_value_tuple_value);
if (index >= flatted_value_tuple_value.size()) {
MS_LOG(EXCEPTION) << "Index out of flatted_value_tuple_value range, index value is " << index
<< " and flatted_value_tuple_value size is " << flatted_value_tuple_value.size() << ".";

View File

@ -358,6 +358,7 @@ REGISTER_PYBIND_DEFINE(HyperMap_, ([](const py::module *m) {
.def(py::init<bool>(), py::arg("reverse"));
}));
namespace {
bool CheckSequenceAllTensor(const abstract::AbstractTuplePtr &tuple) {
MS_EXCEPTION_IF_NULL(tuple);
for (size_t i = 0; i < tuple->size(); ++i) {
@ -370,17 +371,23 @@ bool CheckSequenceAllTensor(const abstract::AbstractTuplePtr &tuple) {
return true;
}
bool EnableGradForScalar(const abstract::AbstractBasePtr &abs) {
return MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && abs->BuildType() != nullptr &&
abs->BuildType()->isa<Number>();
}
bool EnableGradForTuple(const abstract::AbstractBasePtr &abs, bool enable_tuple_grad) {
return abs->isa<abstract::AbstractTuple>() && enable_tuple_grad &&
CheckSequenceAllTensor(abs->cast<abstract::AbstractTuplePtr>());
}
bool CheckTailGradFristSequence(const abstract::AbstractSequencePtr &sequeue, bool enable_tuple_grad) {
MS_EXCEPTION_IF_NULL(sequeue);
return sequeue->size() > 1 && (*sequeue)[1] != nullptr &&
-((*sequeue)[1]->isa<abstract::AbstractUndetermined>() ||
-(MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && (*sequeue)[1]->BuildType() != nullptr &&
-(*sequeue)[1]->BuildType()->isa<Number>()) ||
-((*sequeue)[1]->isa<abstract::AbstractTuple>() && enable_tuple_grad &&
-CheckSequenceAllTensor((*sequeue)[1]->cast<abstract::AbstractTuplePtr>())));
+((*sequeue)[1]->isa<abstract::AbstractUndetermined>() || (*sequeue)[1]->BuildValue() == kAnyValue ||
+EnableGradForScalar((*sequeue)[1]) || EnableGradForTuple((*sequeue)[1], enable_tuple_grad));
}
-namespace {
void GenerateSequenceFuncGraphByPosition(const FuncGraphPtr &res, const abstract::AbstractSequencePtr &sequeue,
const abstract::AbstractSequencePtr &pos, bool enable_tuple_grad) {
if (pos == nullptr) {
@ -470,9 +477,8 @@ FuncGraphPtr Tail::GenerateSequenceFuncGraph(const abstract::AbstractSequencePtr
for (size_t i = 1; i < sequeue->size(); ++i) {
if (tail_type_ == kGradAll) {
MS_EXCEPTION_IF_NULL((*sequeue)[i]);
-if ((*sequeue)[i]->isa<abstract::AbstractUndetermined>() ||
-(MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && (*sequeue)[i]->BuildType() != nullptr &&
-(*sequeue)[i]->BuildType()->isa<Number>())) {
+if ((*sequeue)[i]->isa<abstract::AbstractUndetermined>() || (*sequeue)[i]->BuildValue() == kAnyValue ||
+EnableGradForScalar((*sequeue)[i])) {
elements.push_back(res->NewCNodeInOrder({NewValueNode(op), tuple_parameter, NewValueNode(SizeToLong(i))}));
}
} else {

View File

@ -43,6 +43,7 @@ using mindspore::abstract::AbstractBasePtr;
using mindspore::abstract::AbstractClass;
using mindspore::abstract::AbstractCOOTensor;
using mindspore::abstract::AbstractDictionary;
using mindspore::abstract::AbstractDictionaryPtr;
using mindspore::abstract::AbstractList;
using mindspore::abstract::AbstractListPtr;
using mindspore::abstract::AbstractRowTensor;
@ -53,6 +54,7 @@ using mindspore::abstract::AbstractTuple;
using mindspore::abstract::AbstractTuplePtr;
namespace {
static constexpr size_t kMaxSeqRecursiveDepth = 5;
void CheckInputsSize(const CNodePtr &cnode, size_t expect_size) {
if (cnode->size() != expect_size) {
std::string op_name = GetCNodeFuncName(cnode);
@ -450,11 +452,60 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
return std::make_shared<AbstractTuple>(std::move(elements));
}
-AbstractBasePtr ConvertAbstract(const AbstractBasePtr &abs) override {
+// AbstractDictionary, AbstractClass --> AbstractSequence.
static AbstractSequencePtr ConvertToAbstractSequence(const AbstractBasePtr &abs, size_t depth) {
if (depth > kMaxSeqRecursiveDepth) {
MS_LOG(EXCEPTION) << "List or Dict nesting is not allowed more than " << kMaxSeqRecursiveDepth << " levels.";
}
auto abs_seq = abs->cast<AbstractSequencePtr>();
if (abs_seq != nullptr) {
const auto &seq_elements = abs_seq->elements();
// First we check if elements should be converted,
// changed_elements maps old element to new element.
mindspore::HashMap<AbstractBasePtr, AbstractBasePtr> changed_elements;
for (const auto &element : seq_elements) {
auto new_element = ConvertToAbstractSequence(element, depth + 1);
if (new_element != nullptr) {
(void)changed_elements.emplace(element, new_element);
}
}
if (changed_elements.empty()) {
// Here the AbstractList don't need to convert to AbstractTuple.
return nullptr;
}
// Always make new AbstractSequence when elements changed.
std::vector<AbstractBasePtr> elements;
elements.reserve(seq_elements.size());
for (const auto &element : seq_elements) {
auto iter = changed_elements.find(element);
if (iter != changed_elements.end()) {
(void)elements.emplace_back(iter->second);
} else {
(void)elements.emplace_back(element);
}
}
// Here the AbstractList don't need to convert to AbstractTuple.
if (abs_seq->isa<AbstractList>()) {
return std::make_shared<AbstractList>(std::move(elements));
} else {
return std::make_shared<AbstractTuple>(std::move(elements));
}
}
// AbstractDictionary --> AbstractTuple.
-auto abs_dict = abs->cast<abstract::AbstractDictionaryPtr>();
+auto abs_dict = abs->cast<AbstractDictionaryPtr>();
if (abs_dict != nullptr) {
-return MakeAbstractTuple(abs_dict->elements());
+const auto &dict_elements = abs_dict->elements();
std::vector<AbstractBasePtr> elements;
elements.reserve(dict_elements.size());
for (const auto &element : dict_elements) {
auto new_element = ConvertToAbstractSequence(element.second, depth + 1);
if (new_element != nullptr) {
(void)elements.emplace_back(new_element);
} else {
(void)elements.emplace_back(element.second);
}
}
return std::make_shared<AbstractTuple>(elements);
}
// AbstractClass --> AbstractTuple.
auto abs_class = abs->cast<abstract::AbstractClassPtr>();
@ -463,6 +514,11 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
}
return nullptr;
}
AbstractBasePtr ConvertAbstract(const AbstractBasePtr &abs) override {
// AbstractDictionary, AbstractClass --> AbstractSequence.
return ConvertToAbstractSequence(abs, 0);
}
};
// ==================================================================
@ -597,12 +653,10 @@ class CleanAfterOptARewriter : public BaseRewriter {
return (this->*(iter->second))(cnode);
}
-static constexpr size_t kMaxListRecursiveDepth = 5;
// ValueList --> ValueTuple
static ValueTuplePtr ConvertValueListToValueTuple(const ValueListPtr &value_list, size_t depth) {
-if (depth > kMaxListRecursiveDepth) {
-MS_LOG(EXCEPTION) << "List nesting is not allowed more than " << kMaxListRecursiveDepth << " levels.";
+if (depth > kMaxSeqRecursiveDepth) {
+MS_LOG(EXCEPTION) << "List nesting is not allowed more than " << kMaxSeqRecursiveDepth << " levels.";
}
const auto &list_elements = value_list->value();
std::vector<ValuePtr> elements;
@ -625,50 +679,60 @@ class CleanAfterOptARewriter : public BaseRewriter {
return nullptr;
}
-// AbstractSequence --> AbstractTuple
-static AbstractTuplePtr ConvertAbstractSeqToAbstractTuple(const AbstractSequencePtr &abs_seq, size_t depth) {
-if (depth > kMaxListRecursiveDepth) {
-MS_LOG(EXCEPTION) << "List nesting is not allowed more than " << kMaxListRecursiveDepth << " levels.";
-}
-const auto &seq_elements = abs_seq->elements();
-// First we check if elements should be converted,
-// changed_elements maps old element to new element.
-mindspore::HashMap<AbstractBasePtr, AbstractBasePtr> changed_elements;
-for (const auto &element : seq_elements) {
-if (element->isa<AbstractSequence>()) {
-auto new_element = ConvertAbstractSeqToAbstractTuple(element->cast<AbstractSequencePtr>(), depth + 1);
-if (new_element != nullptr) {
-(void)changed_elements.emplace(element, new_element);
-}
-}
-}
-if (changed_elements.empty()) {
-if (abs_seq->isa<AbstractTuple>()) {
-// If no elements changed and it is an AbstractTuple, do not convert.
-return nullptr;
-}
-// If no elements changed but it is not an AbstractTuple, convert it by copy elements.
-return std::make_shared<AbstractTuple>(seq_elements);
-}
-// Always make new AbstractTuple when elements changed.
-std::vector<AbstractBasePtr> elements;
-elements.reserve(seq_elements.size());
-for (const auto &element : seq_elements) {
-auto iter = changed_elements.find(element);
-if (iter != changed_elements.end()) {
-(void)elements.emplace_back(iter->second);
-} else {
-(void)elements.emplace_back(element);
-}
-}
-return std::make_shared<AbstractTuple>(std::move(elements));
-}
-AbstractBasePtr ConvertAbstract(const AbstractBasePtr &abs) override {
-// AbstractSequence --> AbstractTuple.
-auto abs_seq = abs->cast<AbstractSequencePtr>();
-if (abs_seq != nullptr) {
-return ConvertAbstractSeqToAbstractTuple(abs_seq, 0);
-}
+// AbstractSequence, AbstractDict, AbstractCOOTensor, AbstractRowTensor --> AbstractTuple.
+static AbstractTuplePtr ConvertToAbstractTuple(const AbstractBasePtr &abs, size_t depth) {
+if (depth > kMaxSeqRecursiveDepth) {
+MS_LOG(EXCEPTION) << "List or Dict nesting is not allowed more than " << kMaxSeqRecursiveDepth << " levels.";
+}
+// AbstractList --> AbstractTuple.
+auto abs_seq = abs->cast<AbstractSequencePtr>();
+if (abs_seq != nullptr) {
+const auto &seq_elements = abs_seq->elements();
+// First we check if elements should be converted,
+// changed_elements maps old element to new element.
+mindspore::HashMap<AbstractBasePtr, AbstractBasePtr> changed_elements;
+for (const auto &element : seq_elements) {
+auto new_element = ConvertToAbstractTuple(element, depth + 1);
+if (new_element != nullptr) {
+(void)changed_elements.emplace(element, new_element);
+}
+}
+if (changed_elements.empty()) {
+if (abs->isa<AbstractTuple>()) {
+// If no elements changed and it is an AbstractTuple, do not convert.
+return nullptr;
+}
+// If no elements changed but it is not an AbstractTuple, convert it by copy elements.
+return std::make_shared<AbstractTuple>(seq_elements);
+}
+// Always make new AbstractTuple when elements changed.
+std::vector<AbstractBasePtr> elements;
+elements.reserve(seq_elements.size());
+for (const auto &element : seq_elements) {
+auto iter = changed_elements.find(element);
+if (iter != changed_elements.end()) {
+(void)elements.emplace_back(iter->second);
+} else {
+(void)elements.emplace_back(element);
+}
+}
+return std::make_shared<AbstractTuple>(std::move(elements));
+}
+// AbstractDict --> AbstractTuple.
+auto abs_dict = abs->cast<AbstractDictionaryPtr>();
+if (abs_dict != nullptr) {
+const auto &dict_elements = abs_dict->elements();
+std::vector<AbstractBasePtr> elements;
+elements.reserve(dict_elements.size());
+for (const auto &element : dict_elements) {
+auto new_element = ConvertToAbstractTuple(element.second, depth + 1);
+if (new_element != nullptr) {
+(void)elements.emplace_back(new_element);
+} else {
+(void)elements.emplace_back(element.second);
+}
+}
+return std::make_shared<AbstractTuple>(elements);
+}
// AbstractCOOTensor --> AbstractTuple.
auto abs_sparse = abs->cast<abstract::AbstractCOOTensorPtr>();
@ -685,6 +749,11 @@ class CleanAfterOptARewriter : public BaseRewriter {
}
return nullptr;
}
AbstractBasePtr ConvertAbstract(const AbstractBasePtr &abs) override {
// AbstractSequence, AbstractDict, AbstractCOOTensor, AbstractRowTensor --> AbstractTuple.
return ConvertToAbstractTuple(abs, 0);
}
};
} // namespace

View File

@ -66,6 +66,15 @@ namespace pipeline {
namespace {
bool ExistControlFlow(const FuncGraphPtr &func_graph) { return !func_graph->func_graphs_used_total().empty(); }
bool EnableGradForScalar(const abstract::AbstractBasePtr &abs) {
return MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && abs->BuildType() != nullptr &&
abs->BuildType()->isa<Number>();
}
bool EnableTupleBroaden(const abstract::AbstractBasePtr &abs) {
return abs->isa<abstract::AbstractTuple>() && abs->cast<abstract::AbstractTuplePtr>()->ContainsAllBroadenTensors();
}
void UpdateFuncGraphParameter(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
std::vector<AnfNodePtr> new_paras;
@ -78,11 +87,8 @@ void UpdateFuncGraphParameter(const FuncGraphPtr &func_graph) {
}
AbstractBasePtr par_abs = param_node->abstract();
MS_EXCEPTION_IF_NULL(par_abs);
-if (par_abs->isa<abstract::AbstractUndetermined>() ||
-(MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && par_abs->BuildType() != nullptr &&
-par_abs->BuildType()->isa<Number>()) ||
-(par_abs->isa<abstract::AbstractTuple>() &&
-par_abs->cast<abstract::AbstractTuplePtr>()->ContainsAllBroadenTensors())) {
+if (par_abs->isa<abstract::AbstractUndetermined>() || par_abs->BuildValue() == kAnyValue ||
+EnableGradForScalar(par_abs) || EnableTupleBroaden(par_abs)) {
new_paras.push_back(param_node);
}
}

View File

@ -32,6 +32,7 @@
#include "utils/symbolic.h" #include "utils/symbolic.h"
#include "utils/ms_context.h" #include "utils/ms_context.h"
#include "include/common/utils/utils.h" #include "include/common/utils/utils.h"
#include "ir/variable.h"
namespace mindspore {
namespace parse {
@ -506,6 +507,7 @@ static const std::vector<DataConverterPtr> &GetDataConverters() {
std::make_shared<ByTypeDataConverter<MetaTensor>>(ObjCast<MetaTensorPtr>),
std::make_shared<ByTypeDataConverter<CSRTensor>>(ObjCast<CSRTensorPtr>),
std::make_shared<ByTypeDataConverter<COOTensor>>(ObjCast<COOTensorPtr>),
std::make_shared<ByTypeDataConverter<Variable>>(ObjCast<VariablePtr>),
std::make_shared<ByTypeDataConverter<py::tuple>>(ConvertTuple),
std::make_shared<ByTypeDataConverter<py::list>>(ConvertList),
std::make_shared<ByTypeDataConverter<py::bool_>>(PyCast<BoolImm, bool>),

View File

@ -29,6 +29,7 @@
#include "utils/hash_map.h" #include "utils/hash_map.h"
#include "pybind_api/pybind_patch.h" #include "pybind_api/pybind_patch.h"
#include "ir/param_info.h" #include "ir/param_info.h"
#include "ir/variable.h"
#include "pipeline/jit/pass.h" #include "pipeline/jit/pass.h"
#include "pipeline/jit/parse/data_converter.h" #include "pipeline/jit/parse/data_converter.h"
#include "pipeline/jit/static_analysis/async_eval_result.h" #include "pipeline/jit/static_analysis/async_eval_result.h"
@ -149,7 +150,7 @@ bool CheckAllTensor(const ValueTuplePtr &value_tuple) {
AbstractBasePtr ArgsToAbstract(const ValuePtr &value, bool enable_tuple_broaden = false) {
MS_EXCEPTION_IF_NULL(value);
-bool broaden = value->isa<MetaTensor>() ||
+bool broaden = value->isa<MetaTensor>() || value->isa<Variable>() ||
(enable_tuple_broaden && value->isa<ValueTuple>() && CheckAllTensor(value->cast<ValueTuplePtr>())) ||
(MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && value->isa<Scalar>());
@ -167,6 +168,14 @@ bool CheckArgValid(const py::handle &arg) {
return std::all_of(dict_arg.begin(), dict_arg.end(), [](const auto &pair) { return CheckArgValid(pair.second); });
}
if (py::isinstance<Variable>(arg)) {
if (py::hasattr(arg, "value")) {
return CheckArgValid(arg.attr("value"));
}
MS_LOG(ERROR) << "There should be a python object value stored in the Variable " << py::str(arg);
return false;
}
if (py::isinstance<Tensor>(arg)) {
auto tensor = py::cast<TensorPtr>(arg);
if (tensor->data_type() == kNumberTypeBool) {

View File

@ -0,0 +1,36 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "include/common/pybind_api/api_register.h"
#include "ir/variable.h"
#include "pipeline/jit/parse/data_converter.h"
namespace py = pybind11;
namespace mindspore {
REGISTER_PYBIND_DEFINE(Variable_, ([](const py::module *m) {
(void)py::class_<Variable, VariablePtr>(*m, "Variable_")
.def(py::init([](const py::object &py_value) {
ValuePtr real_value = nullptr;
if (!parse::ConvertData(py_value, &real_value)) {
MS_EXCEPTION(TypeError)
<< "Convert python object failed, the object type is " << py_value.get_type()
<< ", value is '" << py::str(py_value) << "'.";
}
return std::make_shared<Variable>(real_value);
}),
py::arg("py_value"));
}));
} // namespace mindspore

View File

@ -158,7 +158,7 @@ std::string AbstractBase::ToString(bool verbose) const {
AbstractBasePtr AbstractScalar::Broaden() const {
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
-if (context->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR)) {
+if (context->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) || value_mutable()) {
return AbstractBase::Broaden();
}
auto type_id = GetTypeTrack()->type_id();

View File

@ -187,6 +187,10 @@ class MS_CORE_API AbstractBase : public Base {
/// \return A pointer to the broadened abstract.
virtual AbstractBasePtr PartialBroaden() const;
bool value_mutable() const { return value_mutable_; }
void set_value_mutable(bool value_mutable) { value_mutable_ = value_mutable; }
protected:
/// \brief Build a value when value is not set.
///
@ -198,6 +202,7 @@ class MS_CORE_API AbstractBase : public Base {
TypePtr type_;
BaseShapePtr shape_;
std::string value_desc_; // store initial value description for error report
bool value_mutable_{false};
};
/// \brief Class AbstractScalar describes a scalar's type and value.

View File

@ -0,0 +1,89 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ir/variable.h"
#include "abstract/abstract_value.h"
namespace mindspore {
namespace {
void SetValueMutable(const abstract::AbstractBasePtr &abs) {
MS_EXCEPTION_IF_NULL(abs);
if (abs->isa<abstract::AbstractTensor>()) {
return;
}
auto abs_sequence = abs->cast<abstract::AbstractSequencePtr>();
if (abs_sequence != nullptr) {
const auto &elements = abs_sequence->elements();
for (auto &ele : elements) {
SetValueMutable(ele);
}
return;
}
auto abs_dict = abs->cast<abstract::AbstractDictionaryPtr>();
if (abs_dict != nullptr) {
const auto &elements = abs_dict->elements();
for (auto &ele : elements) {
SetValueMutable(ele.second);
}
return;
}
abs->set_value_mutable(true);
}
} // namespace
abstract::AbstractBasePtr Variable::ToAbstract() {
if (real_value_ == nullptr) {
MS_LOG(EXCEPTION) << "Get abstract failed. The real value of Variable has not been set.";
}
auto abs = real_value_->ToAbstract();
SetValueMutable(abs);
return abs;
}
bool Variable::operator==(const Variable &other) const {
if (this == &other) {
return true;
}
auto other_real_value = other.real_value();
if (real_value_ == nullptr || other_real_value == nullptr) {
return false;
}
return *real_value_ == *other_real_value;
}
std::string Variable::ToString() const {
std::ostringstream oss;
if (real_value_ == nullptr) {
oss << "Variable(NULL)";
} else {
oss << "Variable(" << real_value_->ToString() << ")";
}
return oss.str();
}
std::string Variable::DumpText() const {
std::ostringstream oss;
if (real_value_ == nullptr) {
oss << type_name() << "(NULL)";
} else {
oss << type_name() << "(" << real_value_->DumpText() << ")";
}
return oss.str();
}
} // namespace mindspore

View File

@ -0,0 +1,55 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_IR_VARIABLE_H_
#define MINDSPORE_CORE_IR_VARIABLE_H_
#include <memory>
#include <string>
#include "ir/anf.h"
namespace mindspore {
class MS_CORE_API Variable : public Value {
public:
explicit Variable(const ValuePtr &real_value) : real_value_(real_value) {}
~Variable() override = default;
MS_DECLARE_PARENT(Variable, Value)
abstract::AbstractBasePtr ToAbstract() override;
const ValuePtr &real_value() const { return real_value_; }
bool operator==(const Variable &other) const;
bool operator==(const Value &other) const override {
if (other.isa<Variable>()) {
auto other_variable = static_cast<const Variable &>(other);
return *this == other_variable;
}
return false;
}
std::string ToString() const override;
std::string DumpText() const override;
private:
ValuePtr real_value_{nullptr};
};
using VariablePtr = std::shared_ptr<Variable>;
} // namespace mindspore
#endif // MINDSPORE_CORE_IR_VARIABLE_H_

View File

@ -52,6 +52,14 @@ def TupleGetItem(x, index):
x = x.asnumpy()
y = x[index]
return Tensor(y)
if isinstance(x, dict):
count = 0
for value in x.values():
if count == index:
return value
count = count + 1
return x[index]
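
The dict branch added above makes this Python fallback of TupleGetItem return the index-th value of a dict in insertion order, matching how a dict input is flattened into a tuple on the graph side. A compact equivalent of the same behaviour (plain Python, illustrative only):

def dict_get_by_position(d, index):
    # Same effect as the counting loop above: pick the index-th value in insertion order.
    return list(d.values())[index]

assert dict_get_by_position({'a': 10, 'b': 20}, 1) == 20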

View File

@ -25,6 +25,7 @@ from .dump import set_dump
from .parameter import Parameter, ParameterTuple
from .seed import set_seed, get_seed
from .tensor import Tensor, RowTensor, SparseTensor, COOTensor, CSRTensor
from .variable import Variable
# symbols from dtype
__all__ = [
@ -59,5 +60,6 @@ __all__.extend([
"dtype", "_convert_data", "dtype", "_convert_data",
"set_seed", "get_seed", # random seed "set_seed", "get_seed", # random seed
"set_dump", "set_dump",
"ms_memory_recycle" "ms_memory_recycle",
"Variable"
]) ])

View File

@ -0,0 +1,76 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Variable class for setting constants mutable."""
from .._c_expression import Variable_
from ..common.tensor import Tensor, CSRTensor, COOTensor
class Variable(Variable_):
"""
Currently, all inputs of a Cell other than Tensor, such as scalar, tuple, list and dict, are regarded as constant
values. Constant values are non-differentiable and are used for constant folding in the optimization process.
The class 'Variable' stores a constant value so that the constant inputs of a Cell become 'mutable'.
A 'mutable' constant input is treated as a variable input, just like a Tensor, and most importantly it becomes
differentiable.
.. warning::
This is an experimental prototype that is subject to change or deletion.
Args:
value (Union[bool, float, int, tuple, list, dict, Tensor]): The value to be stored.
Examples:
>>> import mindspore.nn as nn
>>> from mindspore.ops.composite import GradOperation
>>> from mindspore.common.variable import Variable
>>> class Net(nn.Cell):
... def __init__(self):
... super(Net, self).__init__()
...
... def construct(self, x, y):
... return x * y
...
>>> class GradNet(nn.Cell):
... def __init__(self, net):
... super(GradNet, self).__init__()
... self.net = net
... self.grad_op = GradOperation()
...
... def construct(self, x, y):
... gradient_function = self.grad_op(self.net)
... return gradient_function(x, y)
...
>>> x = Variable(2)
>>> output = GradNet(Net())(x, 3)
>>> print(output)
3
"""
def __init__(self, value):
if not isinstance(value, (bool, int, float, tuple, list, dict, Tensor, COOTensor, CSRTensor)):
raise TypeError(
f"For 'Varibale', the 'value' should be one of (int, float, tuple, list, dict, Tensor, COOTensor, "
f"CSRTensor), but got {type(value).__name__}")
Variable_.__init__(self, value)
self._value = value
@property
def value(self):
return self._value
@value.setter
def value(self, value):
self._value = value
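
Downstream of this class, the Cell change in this merge request unwraps the stored object before compilation (new_inputs.append(i.value)), while the C++ side broadens its abstract so the whole container becomes a differentiable input. A sketch condensed from the new ST test (MatMulNet and GradNet are illustrative names):

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, Variable
from mindspore.ops import operations as P
from mindspore.ops.composite import GradOperation

class MatMulNet(nn.Cell):
    def __init__(self):
        super(MatMulNet, self).__init__()
        self.matmul = P.MatMul()

    def construct(self, t):
        # t behaves like an ordinary tuple inside construct.
        return self.matmul(t[0], t[1])

class GradNet(nn.Cell):
    def __init__(self, net):
        super(GradNet, self).__init__()
        self.net = net
        self.grad_op = GradOperation()

    def construct(self, t):
        return self.grad_op(self.net)(t)

t = Variable((Tensor(np.ones((2, 3), np.float32)),
              Tensor(np.ones((3, 2), np.float32))))
grads = GradNet(MatMulNet())(t)  # a tuple of gradients, one per wrapped tensor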

View File

@ -33,6 +33,7 @@ from .._checkparam import Validator
from ..common import dtype as mstype
from ..common.api import _cell_graph_executor, _pynative_executor, _check_all_tensor, cells_compile_cache
from ..common.parameter import Parameter, ParameterTuple
from ..common.variable import Variable
from ..common.tensor import Tensor, CSRTensor, COOTensor
from ..ops.operations import Cast
from ..ops.primitive import Primitive
@ -967,10 +968,10 @@ class Cell(Cell_):
if i.has_init:
i.init_data()
new_inputs.append(i)
-elif isinstance(i, COOTensor):
-new_inputs.append(i)
-elif isinstance(i, CSRTensor):
+elif isinstance(i, (COOTensor, CSRTensor)):
new_inputs.append(i)
+elif isinstance(i, Variable):
+new_inputs.append(i.value)
elif context.get_context("grad_for_scalar") and isinstance(i, (int, float)):
new_inputs.append(i)
elif hasattr(self, "enable_tuple_broaden") and self.enable_tuple_broaden and isinstance(i, tuple) and \

View File

@ -0,0 +1,497 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test getting gradient of Variable"""
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops.composite import GradOperation
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
from mindspore import Parameter, Variable
def compare(a, b):
if isinstance(a, (list, tuple)):
for aa, bb in zip(a, b):
if not compare(aa, bb):
return False
return True
return np.allclose(a.asnumpy(), b)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_variable_tuple_tensor():
"""
Feature: Set Constants mutable.
Description: Get gradient with respect to tuple tensor input.
Expectation: Get the correct gradients.
"""
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.matmul = P.MatMul()
self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
def construct(self, t):
x = t[0]
y = t[1]
x = x * self.z
out = self.matmul(x, y)
return out
class GradNetWrtX(nn.Cell):
def __init__(self, net):
super(GradNetWrtX, self).__init__()
self.net = net
self.grad_op = GradOperation()
def construct(self, z):
gradient_function = self.grad_op(self.net)
return gradient_function(z)
t = Variable((Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)))
output = GradNetWrtX(Net())(t)
assert isinstance(output, tuple)
expect = [np.array([[1.4100001, 1.5999999, 6.6],
[1.4100001, 1.5999999, 6.6]]).astype(np.float32),
np.array([[1.7, 1.7, 1.7],
[1.9, 1.9, 1.9],
[1.5, 1.5, 1.5]]).astype(np.float32)]
assert compare(output, expect)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_variable_list_tensor():
"""
Feature: Set Constants mutable.
Description: Get gradient with respect to list tensor input.
Expectation: Get the correct gradients.
"""
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.matmul = P.MatMul()
self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
def construct(self, t):
x = t[0]
y = t[1]
x = x * self.z
out = self.matmul(x, y)
return out
class GradNetWrtX(nn.Cell):
def __init__(self, net):
super(GradNetWrtX, self).__init__()
self.net = net
self.grad_op = GradOperation()
def construct(self, z):
gradient_function = self.grad_op(self.net)
return gradient_function(z)
t = Variable([Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)])
output = GradNetWrtX(Net())(t)
assert isinstance(output, tuple)
expect = [np.array([[1.4100001, 1.5999999, 6.6],
[1.4100001, 1.5999999, 6.6]]).astype(np.float32),
np.array([[1.7, 1.7, 1.7],
[1.9, 1.9, 1.9],
[1.5, 1.5, 1.5]]).astype(np.float32)]
assert compare(output, expect)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_variable_dict_tensor():
"""
Feature: Set Constants mutable.
Description: Get gradient with respect to dict tensor input.
Expectation: Get the correct gradients.
"""
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.matmul = P.MatMul()
self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
def construct(self, t):
x = t['a']
y = t['b']
x = x * self.z
out = self.matmul(x, y)
return out
class GradNetWrtX(nn.Cell):
def __init__(self, net):
super(GradNetWrtX, self).__init__()
self.net = net
self.grad_op = GradOperation()
def construct(self, z):
gradient_function = self.grad_op(self.net)
return gradient_function(z)
t = Variable({'a': Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
'b': Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)})
output = GradNetWrtX(Net())(t)
assert isinstance(output, tuple)
expect = [np.array([[1.4100001, 1.5999999, 6.6],
[1.4100001, 1.5999999, 6.6]]).astype(np.float32),
np.array([[1.7, 1.7, 1.7],
[1.9, 1.9, 1.9],
[1.5, 1.5, 1.5]]).astype(np.float32)]
assert compare(output, expect)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_variable_tuple_tuple_tensor():
"""
Feature: Set Constants mutable.
Description: Get gradient with respect to nested tuple tensor input.
Expectation: Get the correct gradients.
"""
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.matmul = P.MatMul()
self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
def construct(self, t):
x = t[0][0]
y = t[1]
x = x * self.z
out = self.matmul(x, y)
return out
class GradNetWrtX(nn.Cell):
def __init__(self, net):
super(GradNetWrtX, self).__init__()
self.net = net
self.grad_op = GradOperation()
def construct(self, z):
gradient_function = self.grad_op(self.net)
return gradient_function(z)
t = Variable(((Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)),
Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)))
output = GradNetWrtX(Net())(t)
assert isinstance(output, tuple)
expect = [[np.array([[1.4100001, 1.5999999, 6.6],
[1.4100001, 1.5999999, 6.6]]).astype(np.float32), np.array([[0, 0, 0],
[0, 0, 0]]).astype(np.float32)],
np.array([[1.7, 1.7, 1.7],
[1.9, 1.9, 1.9],
[1.5, 1.5, 1.5]]).astype(np.float32)]
assert compare(output, expect)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_variable_tuple_list_tensor():
"""
Feature: Set Constants mutable.
Description: Get gradient with respect to nested tuple and list tensor input.
Expectation: Get the correct gradients.
"""
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.matmul = P.MatMul()
self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
def construct(self, t):
x = t[0][0]
y = t[1]
x = x * self.z
out = self.matmul(x, y)
return out
class GradNetWrtX(nn.Cell):
def __init__(self, net):
super(GradNetWrtX, self).__init__()
self.net = net
self.grad_op = GradOperation()
def construct(self, z):
gradient_function = self.grad_op(self.net)
return gradient_function(z)
t = Variable(([Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)],
Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)))
output = GradNetWrtX(Net())(t)
assert isinstance(output, tuple)
expect = [[np.array([[1.4100001, 1.5999999, 6.6],
[1.4100001, 1.5999999, 6.6]]).astype(np.float32), np.array([[0, 0, 0],
[0, 0, 0]]).astype(np.float32)],
np.array([[1.7, 1.7, 1.7],
[1.9, 1.9, 1.9],
[1.5, 1.5, 1.5]]).astype(np.float32)]
assert compare(output, expect)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_variable_list_tuple_tensor():
"""
Feature: Set Constants mutable.
Description: Get gradient with respect to nested list and tuple tensor input.
Expectation: Get the correct gradients.
"""
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.matmul = P.MatMul()
self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
def construct(self, t):
x = t[0][0]
y = t[1]
x = x * self.z
out = self.matmul(x, y)
return out
class GradNetWrtX(nn.Cell):
def __init__(self, net):
super(GradNetWrtX, self).__init__()
self.net = net
self.grad_op = GradOperation()
def construct(self, z):
gradient_function = self.grad_op(self.net)
return gradient_function(z)
t = Variable([(Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)),
Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)])
output = GradNetWrtX(Net())(t)
assert isinstance(output, tuple)
expect = [[np.array([[1.4100001, 1.5999999, 6.6],
[1.4100001, 1.5999999, 6.6]]).astype(np.float32), np.array([[0, 0, 0],
[0, 0, 0]]).astype(np.float32)],
np.array([[1.7, 1.7, 1.7],
[1.9, 1.9, 1.9],
[1.5, 1.5, 1.5]]).astype(np.float32)]
assert compare(output, expect)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_variable_tuple_dict_tensor():
"""
Feature: Set Constants mutable.
Description: Get gradient with respect to nested tuple and dict tensor input.
Expectation: Get the correct gradients.
"""
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.matmul = P.MatMul()
self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
def construct(self, t):
x = t[0]['a']
y = t[1]
x = x * self.z
out = self.matmul(x, y)
return out
class GradNetWrtX(nn.Cell):
def __init__(self, net):
super(GradNetWrtX, self).__init__()
self.net = net
self.grad_op = GradOperation()
def construct(self, z):
gradient_function = self.grad_op(self.net)
return gradient_function(z)
t = Variable(({'a': Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
'b': Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)},
Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)))
output = GradNetWrtX(Net())(t)
assert isinstance(output, tuple)
expect = [[np.array([[1.4100001, 1.5999999, 6.6],
[1.4100001, 1.5999999, 6.6]]).astype(np.float32), np.array([[0, 0, 0],
[0, 0, 0]]).astype(np.float32)],
np.array([[1.7, 1.7, 1.7],
[1.9, 1.9, 1.9],
[1.5, 1.5, 1.5]]).astype(np.float32)]
assert compare(output, expect)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_variable_dict_tuple_tensor():
"""
Feature: Set Constants mutable.
Description: Get gradient with respect to nested dict and tuple tensor input.
Expectation: Get the correct gradients.
"""
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.matmul = P.MatMul()
self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
def construct(self, t):
x = t['a'][0]
y = t['b']
x = x * self.z
out = self.matmul(x, y)
return out
class GradNetWrtX(nn.Cell):
def __init__(self, net):
super(GradNetWrtX, self).__init__()
self.net = net
self.grad_op = GradOperation()
def construct(self, z):
gradient_function = self.grad_op(self.net)
return gradient_function(z)
t = Variable({'a': (Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)),
'b': Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)})
output = GradNetWrtX(Net())(t)
assert isinstance(output, tuple)
expect = [[np.array([[1.4100001, 1.5999999, 6.6],
[1.4100001, 1.5999999, 6.6]]).astype(np.float32), np.array([[0, 0, 0],
[0, 0, 0]]).astype(np.float32)],
np.array([[1.7, 1.7, 1.7],
[1.9, 1.9, 1.9],
[1.5, 1.5, 1.5]]).astype(np.float32)]
assert compare(output, expect)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_variable_list_dict_tensor():
"""
Feature: Set Constants mutable.
Description: Get gradient with respect to nested list and dict tensor input.
Expectation: Get the correct gradients.
"""
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.matmul = P.MatMul()
self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
def construct(self, t):
x = t[0]['a']
y = t[1]
x = x * self.z
out = self.matmul(x, y)
return out
class GradNetWrtX(nn.Cell):
def __init__(self, net):
super(GradNetWrtX, self).__init__()
self.net = net
self.grad_op = GradOperation()
def construct(self, z):
gradient_function = self.grad_op(self.net)
return gradient_function(z)
t = Variable([{'a': Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
'b': Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)},
Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)])
output = GradNetWrtX(Net())(t)
assert isinstance(output, tuple)
expect = [[np.array([[1.4100001, 1.5999999, 6.6],
[1.4100001, 1.5999999, 6.6]]).astype(np.float32), np.array([[0, 0, 0],
[0, 0, 0]]).astype(np.float32)],
np.array([[1.7, 1.7, 1.7],
[1.9, 1.9, 1.9],
[1.5, 1.5, 1.5]]).astype(np.float32)]
assert compare(output, expect)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_variable_dict_list_tensor():
"""
Feature: Set Constants mutable.
Description: Get gradient with respect to nested dict and list tensor input.
Expectation: Get the correct gradients.
"""
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.matmul = P.MatMul()
self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
def construct(self, t):
x = t['a'][0]
y = t['b']
x = x * self.z
out = self.matmul(x, y)
return out
class GradNetWrtX(nn.Cell):
def __init__(self, net):
super(GradNetWrtX, self).__init__()
self.net = net
self.grad_op = GradOperation()
def construct(self, z):
gradient_function = self.grad_op(self.net)
return gradient_function(z)
t = Variable({'a': [Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)],
'b': Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)})
output = GradNetWrtX(Net())(t)
assert isinstance(output, tuple)
expect = [[np.array([[1.4100001, 1.5999999, 6.6],
[1.4100001, 1.5999999, 6.6]]).astype(np.float32), np.array([[0, 0, 0],
[0, 0, 0]]).astype(np.float32)],
np.array([[1.7, 1.7, 1.7],
[1.9, 1.9, 1.9],
[1.5, 1.5, 1.5]]).astype(np.float32)]
assert compare(output, expect)

View File

@ -0,0 +1,195 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test variable"""
import numpy as np
from mindspore.ops.composite import GradOperation
from mindspore.common.variable import Variable
from mindspore.common.api import _CellGraphExecutor
from mindspore.ops import operations as P
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor
from mindspore import Parameter
def test_variable_scalar_mul_grad_first():
"""
Feature: Set Constants mutable.
Description: Get gradient with respect to the first scalar input.
Expectation: Get the correct gradient.
"""
class Net(nn.Cell):
def construct(self, x, y):
return x * y
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
self.grad_op = GradOperation()
def construct(self, x, y):
gradient_function = self.grad_op(self.net)
return gradient_function(x, y)
x = Variable(2)
output = GradNet(Net())(x, 3)
assert output == 3
def test_variable_scalar_mul_grad_all():
"""
Feature: Set Constants mutable.
Description: Get gradient with respect to all scalar inputs.
Expectation: Get the correct gradients.
"""
class Net(nn.Cell):
def construct(self, x, y):
return x * y
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
self.grad_op = GradOperation(get_all=True)
def construct(self, x, y):
gradient_function = self.grad_op(self.net)
return gradient_function(x, y)
x = Variable(2)
y = Variable(3)
output = GradNet(Net())(x, y)
assert output == (3, 2)
def test_variable_tuple_or_list_scalar_mul_grad():
"""
Feature: Set Constants mutable.
Description: Get gradient with respect to the tuple or list scalar input.
Expectation: Get the correct gradients.
"""
class Net(nn.Cell):
def construct(self, x):
return x[0] * x[1]
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
self.grad_op = GradOperation()
def construct(self, x):
gradient_function = self.grad_op(self.net)
return gradient_function(x)
x = Variable((2, 3))
output = GradNet(Net())(x)
assert output == (3, 2)
x = Variable([2, 3])
output = GradNet(Net())(x)
assert output == (3, 2)
def test_variable_dict_scalar_mul_grad():
"""
Feature: Set Constants mutable.
Description: Get gradient with respect to the dict scalar input.
Expectation: Get the correct gradients.
"""
class Net(nn.Cell):
def construct(self, x):
return x['a'] * x['b']
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
self.grad_op = GradOperation()
def construct(self, x):
gradient_function = self.grad_op(self.net)
return gradient_function(x)
x = Variable({'a': 2, 'b': 3})
output = GradNet(Net())(x)
assert output == (3, 2)
def test_variable_mix_scalar_mul_grad_all():
"""
Feature: Set Constants mutable.
Description: Get gradient with respect to the mix scalar input including dict and tuple.
Expectation: Get the correct gradients.
"""
class Net(nn.Cell):
def construct(self, x, y):
return x['a'] * x['b'] * y[0]
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
self.grad_op = GradOperation(get_all=True)
def construct(self, x, y):
gradient_function = self.grad_op(self.net)
return gradient_function(x, y)
x = Variable({'a': 2, 'b': 3})
y = Variable((4, 5))
output = GradNet(Net())(x, y)
assert output == ((12, 8), (6, 0))
def test_tuple_inputs_compile_phase():
"""
Feature: Set Constants mutable.
Description: Test whether two compilations with tuple inputs produce the same compilation phase.
Expectation: The phases are the same.
"""
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.matmul = P.MatMul()
self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
def construct(self, tuple_input):
x = tuple_input[0]
y = tuple_input[1]
x = x * self.z
out = self.matmul(x, y)
return out
x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
p = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
q = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
net = Net()
_cell_graph_executor = _CellGraphExecutor()
phase1, _ = _cell_graph_executor.compile(net, (x, y))
phase2, _ = _cell_graph_executor.compile(net, (p, q))
assert phase1 != phase2
phase1, _ = _cell_graph_executor.compile(net, Variable((x, y)))
phase2, _ = _cell_graph_executor.compile(net, Variable((p, q)))
assert phase1 == phase2