diff --git a/mindspore/ccsrc/backend/graph_compiler/backend.cc b/mindspore/ccsrc/backend/graph_compiler/backend.cc
index 9eab7662f67..c89a26982dd 100644
--- a/mindspore/ccsrc/backend/graph_compiler/backend.cc
+++ b/mindspore/ccsrc/backend/graph_compiler/backend.cc
@@ -802,6 +802,38 @@ void FlatValueTupleValue(const ValuePtrList &value, ValuePtrList *flatted_value)
   }
 }
 
+void FlattenValue(const BaseRef &arg, ValuePtrList *flatted_value) {
+  if (utils::isa<ValueSequencePtr>(arg)) {
+    auto value_sequence = utils::cast<ValueSequencePtr>(arg);
+    MS_EXCEPTION_IF_NULL(value_sequence);
+    auto sequence_value = value_sequence->value();
+    for (auto &value : sequence_value) {
+      MS_EXCEPTION_IF_NULL(value);
+      if (value->isa<tensor::Tensor>()) {
+        (void)flatted_value->emplace_back(value);
+      } else {
+        FlattenValue(value, flatted_value);
+      }
+    }
+  } else if (utils::isa<ValueDictionaryPtr>(arg)) {
+    auto value_dict = utils::cast<ValueDictionaryPtr>(arg);
+    MS_EXCEPTION_IF_NULL(value_dict);
+    auto dict_value = value_dict->value();
+    for (auto &iter : dict_value) {
+      auto value = iter.second;
+      MS_EXCEPTION_IF_NULL(value);
+      if (value->isa<tensor::Tensor>()) {
+        (void)flatted_value->emplace_back(value);
+      } else {
+        FlattenValue(value, flatted_value);
+      }
+    }
+  } else {
+    MS_LOG(EXCEPTION) << "The value input to flatten should only be a sequence or dictionary, but it is "
+                      << arg.ToString();
+  }
+}
+
 void PushTupleTensor(const VectorRef &args, const std::vector<AnfNodePtr> &parameters, const AnfNodePtr &front_node,
                      size_t index, std::vector<tensor::TensorPtr> *input_tensor) {
   const auto &iter = std::find(parameters.begin(), parameters.end(), front_node);
@@ -814,11 +846,8 @@ void PushTupleTensor(const VectorRef &args, const std::vector<AnfNodePtr> &param
     (void)input_tensor->emplace_back(nullptr);
     return;
   }
-  auto value_tuple = utils::cast<ValueTuplePtr>(args[position]);
-  MS_EXCEPTION_IF_NULL(value_tuple);
-  auto value_tuple_value = value_tuple->value();
   ValuePtrList flatted_value_tuple_value;
-  FlatValueTupleValue(value_tuple_value, &flatted_value_tuple_value);
+  FlattenValue(args[position], &flatted_value_tuple_value);
   if (index >= flatted_value_tuple_value.size()) {
     MS_LOG(EXCEPTION) << "Index out of flatted_value_tuple_value range, index value is " << index
                       << " and flatted_value_tuple_value size is " << flatted_value_tuple_value.size() << ".";
diff --git a/mindspore/ccsrc/frontend/operator/composite/composite.cc b/mindspore/ccsrc/frontend/operator/composite/composite.cc
index 39a4ad04746..d269942d1a1 100644
--- a/mindspore/ccsrc/frontend/operator/composite/composite.cc
+++ b/mindspore/ccsrc/frontend/operator/composite/composite.cc
@@ -358,6 +358,7 @@ REGISTER_PYBIND_DEFINE(HyperMap_, ([](const py::module *m) {
                            .def(py::init<bool>(), py::arg("reverse"));
                        }));
 
+namespace {
 bool CheckSequenceAllTensor(const abstract::AbstractTuplePtr &tuple) {
   MS_EXCEPTION_IF_NULL(tuple);
   for (size_t i = 0; i < tuple->size(); ++i) {
@@ -370,17 +371,23 @@ bool CheckSequenceAllTensor(const abstract::AbstractTuplePtr &tuple) {
   return true;
 }
 
+bool EnableGradForScalar(const abstract::AbstractBasePtr &abs) {
+  return MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && abs->BuildType() != nullptr &&
+         abs->BuildType()->isa<Number>();
+}
+
+bool EnableGradForTuple(const abstract::AbstractBasePtr &abs, bool enable_tuple_grad) {
+  return abs->isa<abstract::AbstractTuple>() && enable_tuple_grad &&
+         CheckSequenceAllTensor(abs->cast<abstract::AbstractTuplePtr>());
+}
+
 bool CheckTailGradFristSequence(const abstract::AbstractSequencePtr &sequeue, bool enable_tuple_grad) {
   MS_EXCEPTION_IF_NULL(sequeue);
   return sequeue->size() > 1 && (*sequeue)[1] != nullptr &&
-         ((*sequeue)[1]->isa<abstract::AbstractUndetermined>() ||
-          (MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) &&
(*sequeue)[1]->BuildType() != nullptr && - (*sequeue)[1]->BuildType()->isa()) || - ((*sequeue)[1]->isa() && enable_tuple_grad && - CheckSequenceAllTensor((*sequeue)[1]->cast()))); + ((*sequeue)[1]->isa() || (*sequeue)[1]->BuildValue() == kAnyValue || + EnableGradForScalar((*sequeue)[1]) || EnableGradForTuple((*sequeue)[1], enable_tuple_grad)); } -namespace { void GenerateSequenceFuncGraphByPosition(const FuncGraphPtr &res, const abstract::AbstractSequencePtr &sequeue, const abstract::AbstractSequencePtr &pos, bool enable_tuple_grad) { if (pos == nullptr) { @@ -470,9 +477,8 @@ FuncGraphPtr Tail::GenerateSequenceFuncGraph(const abstract::AbstractSequencePtr for (size_t i = 1; i < sequeue->size(); ++i) { if (tail_type_ == kGradAll) { MS_EXCEPTION_IF_NULL((*sequeue)[i]); - if ((*sequeue)[i]->isa() || - (MsContext::GetInstance()->get_param(MS_CTX_GRAD_FOR_SCALAR) && (*sequeue)[i]->BuildType() != nullptr && - (*sequeue)[i]->BuildType()->isa())) { + if ((*sequeue)[i]->isa() || (*sequeue)[i]->BuildValue() == kAnyValue || + EnableGradForScalar((*sequeue)[i])) { elements.push_back(res->NewCNodeInOrder({NewValueNode(op), tuple_parameter, NewValueNode(SizeToLong(i))})); } } else { diff --git a/mindspore/ccsrc/frontend/optimizer/clean.cc b/mindspore/ccsrc/frontend/optimizer/clean.cc index ece9875331e..7d337b4eb59 100644 --- a/mindspore/ccsrc/frontend/optimizer/clean.cc +++ b/mindspore/ccsrc/frontend/optimizer/clean.cc @@ -43,6 +43,7 @@ using mindspore::abstract::AbstractBasePtr; using mindspore::abstract::AbstractClass; using mindspore::abstract::AbstractCOOTensor; using mindspore::abstract::AbstractDictionary; +using mindspore::abstract::AbstractDictionaryPtr; using mindspore::abstract::AbstractList; using mindspore::abstract::AbstractListPtr; using mindspore::abstract::AbstractRowTensor; @@ -53,6 +54,7 @@ using mindspore::abstract::AbstractTuple; using mindspore::abstract::AbstractTuplePtr; namespace { +static constexpr size_t kMaxSeqRecursiveDepth = 5; void CheckInputsSize(const CNodePtr &cnode, size_t expect_size) { if (cnode->size() != expect_size) { std::string op_name = GetCNodeFuncName(cnode); @@ -450,11 +452,60 @@ class SimplifyDataStructuresRewriter : public BaseRewriter { return std::make_shared(std::move(elements)); } - AbstractBasePtr ConvertAbstract(const AbstractBasePtr &abs) override { + // AbstractDictionary, AbstractClass --> AbstractSequence. + static AbstractSequencePtr ConvertToAbstractSequence(const AbstractBasePtr &abs, size_t depth) { + if (depth > kMaxSeqRecursiveDepth) { + MS_LOG(EXCEPTION) << "List or Dict nesting is not allowed more than " << kMaxSeqRecursiveDepth << " levels."; + } + auto abs_seq = abs->cast(); + if (abs_seq != nullptr) { + const auto &seq_elements = abs_seq->elements(); + // First we check if elements should be converted, + // changed_elements maps old element to new element. + mindspore::HashMap changed_elements; + for (const auto &element : seq_elements) { + auto new_element = ConvertToAbstractSequence(element, depth + 1); + if (new_element != nullptr) { + (void)changed_elements.emplace(element, new_element); + } + } + if (changed_elements.empty()) { + // Here the AbstractList don't need to convert to AbstractTuple. + return nullptr; + } + // Always make new AbstractSequence when elements changed. 
+ std::vector elements; + elements.reserve(seq_elements.size()); + for (const auto &element : seq_elements) { + auto iter = changed_elements.find(element); + if (iter != changed_elements.end()) { + (void)elements.emplace_back(iter->second); + } else { + (void)elements.emplace_back(element); + } + } + // Here the AbstractList don't need to convert to AbstractTuple. + if (abs_seq->isa()) { + return std::make_shared(std::move(elements)); + } else { + return std::make_shared(std::move(elements)); + } + } // AbstractDictionary --> AbstractTuple. - auto abs_dict = abs->cast(); + auto abs_dict = abs->cast(); if (abs_dict != nullptr) { - return MakeAbstractTuple(abs_dict->elements()); + const auto &dict_elements = abs_dict->elements(); + std::vector elements; + elements.reserve(dict_elements.size()); + for (const auto &element : dict_elements) { + auto new_element = ConvertToAbstractSequence(element.second, depth + 1); + if (new_element != nullptr) { + (void)elements.emplace_back(new_element); + } else { + (void)elements.emplace_back(element.second); + } + } + return std::make_shared(elements); } // AbstractClass --> AbstractTuple. auto abs_class = abs->cast(); @@ -463,6 +514,11 @@ class SimplifyDataStructuresRewriter : public BaseRewriter { } return nullptr; } + + AbstractBasePtr ConvertAbstract(const AbstractBasePtr &abs) override { + // AbstractDictionary, AbstractClass --> AbstractSequence. + return ConvertToAbstractSequence(abs, 0); + } }; // ================================================================== @@ -597,12 +653,10 @@ class CleanAfterOptARewriter : public BaseRewriter { return (this->*(iter->second))(cnode); } - static constexpr size_t kMaxListRecursiveDepth = 5; - // ValueList --> ValueTuple static ValueTuplePtr ConvertValueListToValueTuple(const ValueListPtr &value_list, size_t depth) { - if (depth > kMaxListRecursiveDepth) { - MS_LOG(EXCEPTION) << "List nesting is not allowed more than " << kMaxListRecursiveDepth << " levels."; + if (depth > kMaxSeqRecursiveDepth) { + MS_LOG(EXCEPTION) << "List nesting is not allowed more than " << kMaxSeqRecursiveDepth << " levels."; } const auto &list_elements = value_list->value(); std::vector elements; @@ -625,50 +679,60 @@ class CleanAfterOptARewriter : public BaseRewriter { return nullptr; } - // AbstractSequence --> AbstractTuple - static AbstractTuplePtr ConvertAbstractSeqToAbstractTuple(const AbstractSequencePtr &abs_seq, size_t depth) { - if (depth > kMaxListRecursiveDepth) { - MS_LOG(EXCEPTION) << "List nesting is not allowed more than " << kMaxListRecursiveDepth << " levels."; + // AbstractSequence, AbstractDict, AbstractCOOTensor, AbstractRowTensor --> AbstractTuple. + static AbstractTuplePtr ConvertToAbstractTuple(const AbstractBasePtr &abs, size_t depth) { + if (depth > kMaxSeqRecursiveDepth) { + MS_LOG(EXCEPTION) << "List or Dict nesting is not allowed more than " << kMaxSeqRecursiveDepth << " levels."; } - const auto &seq_elements = abs_seq->elements(); - // First we check if elements should be converted, - // changed_elements maps old element to new element. - mindspore::HashMap changed_elements; - for (const auto &element : seq_elements) { - if (element->isa()) { - auto new_element = ConvertAbstractSeqToAbstractTuple(element->cast(), depth + 1); + // AbstractList --> AbstractTuple. + auto abs_seq = abs->cast(); + if (abs_seq != nullptr) { + const auto &seq_elements = abs_seq->elements(); + // First we check if elements should be converted, + // changed_elements maps old element to new element. 
+ mindspore::HashMap changed_elements; + for (const auto &element : seq_elements) { + auto new_element = ConvertToAbstractTuple(element, depth + 1); if (new_element != nullptr) { (void)changed_elements.emplace(element, new_element); } } - } - if (changed_elements.empty()) { - if (abs_seq->isa()) { - // If no elements changed and it is an AbstractTuple, do not convert. - return nullptr; + if (changed_elements.empty()) { + if (abs->isa()) { + // If no elements changed and it is an AbstractTuple, do not convert. + return nullptr; + } + // If no elements changed but it is not an AbstractTuple, convert it by copy elements. + return std::make_shared(seq_elements); } - // If no elements changed but it is not an AbstractTuple, convert it by copy elements. - return std::make_shared(seq_elements); - } - // Always make new AbstractTuple when elements changed. - std::vector elements; - elements.reserve(seq_elements.size()); - for (const auto &element : seq_elements) { - auto iter = changed_elements.find(element); - if (iter != changed_elements.end()) { - (void)elements.emplace_back(iter->second); - } else { - (void)elements.emplace_back(element); + // Always make new AbstractTuple when elements changed. + std::vector elements; + elements.reserve(seq_elements.size()); + for (const auto &element : seq_elements) { + auto iter = changed_elements.find(element); + if (iter != changed_elements.end()) { + (void)elements.emplace_back(iter->second); + } else { + (void)elements.emplace_back(element); + } } + return std::make_shared(std::move(elements)); } - return std::make_shared(std::move(elements)); - } - - AbstractBasePtr ConvertAbstract(const AbstractBasePtr &abs) override { - // AbstractSequence --> AbstractTuple. - auto abs_seq = abs->cast(); - if (abs_seq != nullptr) { - return ConvertAbstractSeqToAbstractTuple(abs_seq, 0); + // AbstractDict --> AbstractTuple. + auto abs_dict = abs->cast(); + if (abs_dict != nullptr) { + const auto &dict_elements = abs_dict->elements(); + std::vector elements; + elements.reserve(dict_elements.size()); + for (const auto &element : dict_elements) { + auto new_element = ConvertToAbstractTuple(element.second, depth + 1); + if (new_element != nullptr) { + (void)elements.emplace_back(new_element); + } else { + (void)elements.emplace_back(element.second); + } + } + return std::make_shared(elements); } // AbstractCOOTensor --> AbstractTuple. auto abs_sparse = abs->cast(); @@ -685,6 +749,11 @@ class CleanAfterOptARewriter : public BaseRewriter { } return nullptr; } + + AbstractBasePtr ConvertAbstract(const AbstractBasePtr &abs) override { + // AbstractSequence, AbstractDict, AbstractCOOTensor, AbstractRowTensor --> AbstractTuple. 
+ return ConvertToAbstractTuple(abs, 0); + } }; } // namespace diff --git a/mindspore/ccsrc/pipeline/jit/action.cc b/mindspore/ccsrc/pipeline/jit/action.cc index 8315c7774f5..686fc840b1f 100644 --- a/mindspore/ccsrc/pipeline/jit/action.cc +++ b/mindspore/ccsrc/pipeline/jit/action.cc @@ -66,6 +66,15 @@ namespace pipeline { namespace { bool ExistControlFlow(const FuncGraphPtr &func_graph) { return !func_graph->func_graphs_used_total().empty(); } +bool EnableGradForScalar(const abstract::AbstractBasePtr &abs) { + return MsContext::GetInstance()->get_param(MS_CTX_GRAD_FOR_SCALAR) && abs->BuildType() != nullptr && + abs->BuildType()->isa(); +} + +bool EnableTupleBroaden(const abstract::AbstractBasePtr &abs) { + return abs->isa() && abs->cast()->ContainsAllBroadenTensors(); +} + void UpdateFuncGraphParameter(const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(func_graph); std::vector new_paras; @@ -78,11 +87,8 @@ void UpdateFuncGraphParameter(const FuncGraphPtr &func_graph) { } AbstractBasePtr par_abs = param_node->abstract(); MS_EXCEPTION_IF_NULL(par_abs); - if (par_abs->isa() || - (MsContext::GetInstance()->get_param(MS_CTX_GRAD_FOR_SCALAR) && par_abs->BuildType() != nullptr && - par_abs->BuildType()->isa()) || - (par_abs->isa() && - par_abs->cast()->ContainsAllBroadenTensors())) { + if (par_abs->isa() || par_abs->BuildValue() == kAnyValue || + EnableGradForScalar(par_abs) || EnableTupleBroaden(par_abs)) { new_paras.push_back(param_node); } } diff --git a/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc b/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc index b38104ac969..6da9e08559c 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc @@ -32,6 +32,7 @@ #include "utils/symbolic.h" #include "utils/ms_context.h" #include "include/common/utils/utils.h" +#include "ir/variable.h" namespace mindspore { namespace parse { @@ -506,6 +507,7 @@ static const std::vector &GetDataConverters() { std::make_shared>(ObjCast), std::make_shared>(ObjCast), std::make_shared>(ObjCast), + std::make_shared>(ObjCast), std::make_shared>(ConvertTuple), std::make_shared>(ConvertList), std::make_shared>(PyCast), diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.cc b/mindspore/ccsrc/pipeline/jit/pipeline.cc index 1727d7f91c3..8e8ccf3ee61 100644 --- a/mindspore/ccsrc/pipeline/jit/pipeline.cc +++ b/mindspore/ccsrc/pipeline/jit/pipeline.cc @@ -29,6 +29,7 @@ #include "utils/hash_map.h" #include "pybind_api/pybind_patch.h" #include "ir/param_info.h" +#include "ir/variable.h" #include "pipeline/jit/pass.h" #include "pipeline/jit/parse/data_converter.h" #include "pipeline/jit/static_analysis/async_eval_result.h" @@ -149,7 +150,7 @@ bool CheckAllTensor(const ValueTuplePtr &value_tuple) { AbstractBasePtr ArgsToAbstract(const ValuePtr &value, bool enable_tuple_broaden = false) { MS_EXCEPTION_IF_NULL(value); - bool broaden = value->isa() || + bool broaden = value->isa() || value->isa() || (enable_tuple_broaden && value->isa() && CheckAllTensor(value->cast())) || (MsContext::GetInstance()->get_param(MS_CTX_GRAD_FOR_SCALAR) && value->isa()); @@ -167,6 +168,14 @@ bool CheckArgValid(const py::handle &arg) { return std::all_of(dict_arg.begin(), dict_arg.end(), [](const auto &pair) { return CheckArgValid(pair.second); }); } + if (py::isinstance(arg)) { + if (py::hasattr(arg, "value")) { + return CheckArgValid(arg.attr("value")); + } + MS_LOG(ERROR) << "There should be a python object value stored in the Variable " << py::str(arg); + return false; + } + if 
(py::isinstance(arg)) { auto tensor = py::cast(arg); if (tensor->data_type() == kNumberTypeBool) { diff --git a/mindspore/ccsrc/pybind_api/ir/variable_py.cc b/mindspore/ccsrc/pybind_api/ir/variable_py.cc new file mode 100644 index 00000000000..cf37df323c0 --- /dev/null +++ b/mindspore/ccsrc/pybind_api/ir/variable_py.cc @@ -0,0 +1,36 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "include/common/pybind_api/api_register.h" +#include "ir/variable.h" +#include "pipeline/jit/parse/data_converter.h" + +namespace py = pybind11; +namespace mindspore { +REGISTER_PYBIND_DEFINE(Variable_, ([](const py::module *m) { + (void)py::class_(*m, "Variable_") + .def(py::init([](const py::object &py_value) { + ValuePtr real_value = nullptr; + if (!parse::ConvertData(py_value, &real_value)) { + MS_EXCEPTION(TypeError) + << "Convert python object failed, the object type is " << py_value.get_type() + << ", value is '" << py::str(py_value) << "'."; + } + return std::make_shared(real_value); + }), + py::arg("py_value")); + })); +} // namespace mindspore diff --git a/mindspore/core/abstract/abstract_value.cc b/mindspore/core/abstract/abstract_value.cc index e14de06955b..02b378a8ecb 100644 --- a/mindspore/core/abstract/abstract_value.cc +++ b/mindspore/core/abstract/abstract_value.cc @@ -158,7 +158,7 @@ std::string AbstractBase::ToString(bool verbose) const { AbstractBasePtr AbstractScalar::Broaden() const { auto context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context); - if (context->get_param(MS_CTX_GRAD_FOR_SCALAR)) { + if (context->get_param(MS_CTX_GRAD_FOR_SCALAR) || value_mutable()) { return AbstractBase::Broaden(); } auto type_id = GetTypeTrack()->type_id(); diff --git a/mindspore/core/abstract/abstract_value.h b/mindspore/core/abstract/abstract_value.h index b4dc659d67a..0b6f225bedc 100644 --- a/mindspore/core/abstract/abstract_value.h +++ b/mindspore/core/abstract/abstract_value.h @@ -187,6 +187,10 @@ class MS_CORE_API AbstractBase : public Base { /// \return A pointer to the broadened abstract. virtual AbstractBasePtr PartialBroaden() const; + bool value_mutable() const { return value_mutable_; } + + void set_value_mutable(bool value_mutable) { value_mutable_ = value_mutable; } + protected: /// \brief Build a value when value is not set. /// @@ -198,6 +202,7 @@ class MS_CORE_API AbstractBase : public Base { TypePtr type_; BaseShapePtr shape_; std::string value_desc_; // store initial value description for error report + bool value_mutable_{false}; }; /// \brief Class AbstractScalar describes a scalar's type and value. diff --git a/mindspore/core/ir/variable.cc b/mindspore/core/ir/variable.cc new file mode 100644 index 00000000000..ee404791d43 --- /dev/null +++ b/mindspore/core/ir/variable.cc @@ -0,0 +1,89 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ir/variable.h" +#include "abstract/abstract_value.h" + +namespace mindspore { +namespace { +void SetValueMutable(const abstract::AbstractBasePtr &abs) { + MS_EXCEPTION_IF_NULL(abs); + if (abs->isa()) { + return; + } + + auto abs_sequence = abs->cast(); + if (abs_sequence != nullptr) { + const auto &elements = abs_sequence->elements(); + for (auto &ele : elements) { + SetValueMutable(ele); + } + return; + } + + auto abs_dict = abs->cast(); + if (abs_dict != nullptr) { + const auto &elements = abs_dict->elements(); + for (auto &ele : elements) { + SetValueMutable(ele.second); + } + return; + } + + abs->set_value_mutable(true); +} +} // namespace + +abstract::AbstractBasePtr Variable::ToAbstract() { + if (real_value_ == nullptr) { + MS_LOG(EXCEPTION) << "Get abstract failed. The real value of Variable has not been set."; + } + auto abs = real_value_->ToAbstract(); + SetValueMutable(abs); + return abs; +} + +bool Variable::operator==(const Variable &other) const { + if (this == &other) { + return true; + } + auto other_real_value = other.real_value(); + if (real_value_ == nullptr || other_real_value == nullptr) { + return false; + } + return *real_value_ == *other_real_value; +} + +std::string Variable::ToString() const { + std::ostringstream oss; + if (real_value_ == nullptr) { + oss << "Variable(NULL)"; + } else { + oss << "Variable(" << real_value_->ToString() << ")"; + } + return oss.str(); +} + +std::string Variable::DumpText() const { + std::ostringstream oss; + if (real_value_ == nullptr) { + oss << type_name() << "(NULL)"; + } else { + oss << type_name() << "(" << real_value_->DumpText() << ")"; + } + return oss.str(); +} +} // namespace mindspore diff --git a/mindspore/core/ir/variable.h b/mindspore/core/ir/variable.h new file mode 100644 index 00000000000..8769455167d --- /dev/null +++ b/mindspore/core/ir/variable.h @@ -0,0 +1,55 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CORE_IR_VARIABLE_H_ +#define MINDSPORE_CORE_IR_VARIABLE_H_ + +#include +#include +#include "ir/anf.h" + +namespace mindspore { +class MS_CORE_API Variable : public Value { + public: + explicit Variable(const ValuePtr &real_value) : real_value_(real_value) {} + ~Variable() override = default; + MS_DECLARE_PARENT(Variable, Value) + + abstract::AbstractBasePtr ToAbstract() override; + + const ValuePtr &real_value() const { return real_value_; } + + bool operator==(const Variable &other) const; + + bool operator==(const Value &other) const override { + if (other.isa()) { + auto other_variable = static_cast(other); + return *this == other_variable; + } + return false; + } + + std::string ToString() const override; + + std::string DumpText() const override; + + private: + ValuePtr real_value_{nullptr}; +}; +using VariablePtr = std::shared_ptr; +} // namespace mindspore + +#endif // MINDSPORE_CORE_IR_VARIABLE_H_ diff --git a/mindspore/python/mindspore/_extends/builtin_operations.py b/mindspore/python/mindspore/_extends/builtin_operations.py index 7348e0322a4..1b0e4e9ad06 100644 --- a/mindspore/python/mindspore/_extends/builtin_operations.py +++ b/mindspore/python/mindspore/_extends/builtin_operations.py @@ -52,6 +52,14 @@ def TupleGetItem(x, index): x = x.asnumpy() y = x[index] return Tensor(y) + + if isinstance(x, dict): + count = 0 + for value in x.values(): + if count == index: + return value + count = count + 1 + return x[index] diff --git a/mindspore/python/mindspore/common/__init__.py b/mindspore/python/mindspore/common/__init__.py index 50b5554d665..5f71ea6896d 100644 --- a/mindspore/python/mindspore/common/__init__.py +++ b/mindspore/python/mindspore/common/__init__.py @@ -25,6 +25,7 @@ from .dump import set_dump from .parameter import Parameter, ParameterTuple from .seed import set_seed, get_seed from .tensor import Tensor, RowTensor, SparseTensor, COOTensor, CSRTensor +from .variable import Variable # symbols from dtype __all__ = [ @@ -59,5 +60,6 @@ __all__.extend([ "dtype", "_convert_data", "set_seed", "get_seed", # random seed "set_dump", - "ms_memory_recycle" + "ms_memory_recycle", + "Variable" ]) diff --git a/mindspore/python/mindspore/common/variable.py b/mindspore/python/mindspore/common/variable.py new file mode 100644 index 00000000000..5ddf32b91cd --- /dev/null +++ b/mindspore/python/mindspore/common/variable.py @@ -0,0 +1,76 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Variable class for setting constants mutable.""" + +from .._c_expression import Variable_ +from ..common.tensor import Tensor, CSRTensor, COOTensor + + +class Variable(Variable_): + """ + Currently, all the inputs of Cell except Tensor such as scalar, tuple, list and dict, are regarded as constant + values. The constant values are non-differentiable and used to do constant folding in the optimization process. 
+    We provide the class 'Variable' to store a constant value and make the constant inputs of Cell 'mutable'.
+    A 'mutable' constant input means that it becomes a variable input, just like Tensor, and most importantly
+    it is differentiable from now on.
+
+    .. warning::
+        This is an experimental prototype that is subject to change or deletion.
+
+    Args:
+        value (Union[bool, float, int, tuple, list, dict, Tensor]): The value to be stored.
+
+    Examples:
+        >>> import mindspore.nn as nn
+        >>> from mindspore.ops.composite import GradOperation
+        >>> from mindspore.common.variable import Variable
+        >>> class Net(nn.Cell):
+        ...     def __init__(self):
+        ...         super(Net, self).__init__()
+        ...
+        ...     def construct(self, x, y):
+        ...         return x * y
+        ...
+        >>> class GradNet(nn.Cell):
+        ...     def __init__(self, net):
+        ...         super(GradNet, self).__init__()
+        ...         self.net = net
+        ...         self.grad_op = GradOperation()
+        ...
+        ...     def construct(self, x, y):
+        ...         gradient_function = self.grad_op(self.net)
+        ...         return gradient_function(x, y)
+        ...
+        >>> x = Variable(2)
+        >>> output = GradNet(Net())(x, 3)
+        >>> print(output)
+        3
+    """
+
+    def __init__(self, value):
+        if not isinstance(value, (bool, int, float, tuple, list, dict, Tensor, COOTensor, CSRTensor)):
+            raise TypeError(
+                f"For 'Variable', the 'value' should be one of (bool, int, float, tuple, list, dict, Tensor, "
+                f"COOTensor, CSRTensor), but got {type(value).__name__}")
+        Variable_.__init__(self, value)
+        self._value = value
+
+    @property
+    def value(self):
+        return self._value
+
+    @value.setter
+    def value(self, value):
+        self._value = value
diff --git a/mindspore/python/mindspore/nn/cell.py b/mindspore/python/mindspore/nn/cell.py
index 36f959ab0ad..6ce0f2e395d 100755
--- a/mindspore/python/mindspore/nn/cell.py
+++ b/mindspore/python/mindspore/nn/cell.py
@@ -33,6 +33,7 @@ from .._checkparam import Validator
 from ..common import dtype as mstype
 from ..common.api import _cell_graph_executor, _pynative_executor, _check_all_tensor, cells_compile_cache
 from ..common.parameter import Parameter, ParameterTuple
+from ..common.variable import Variable
 from ..common.tensor import Tensor, CSRTensor, COOTensor
 from ..ops.operations import Cast
 from ..ops.primitive import Primitive
@@ -967,10 +968,10 @@ class Cell(Cell_):
                 if i.has_init:
                     i.init_data()
                 new_inputs.append(i)
-            elif isinstance(i, COOTensor):
-                new_inputs.append(i)
-            elif isinstance(i, CSRTensor):
+            elif isinstance(i, (COOTensor, CSRTensor)):
                 new_inputs.append(i)
+            elif isinstance(i, Variable):
+                new_inputs.append(i.value)
             elif context.get_context("grad_for_scalar") and isinstance(i, (int, float)):
                 new_inputs.append(i)
             elif hasattr(self, "enable_tuple_broaden") and self.enable_tuple_broaden and isinstance(i, tuple) and \
diff --git a/tests/st/gradient/test_grad_variable.py b/tests/st/gradient/test_grad_variable.py
new file mode 100644
index 00000000000..362f938286d
--- /dev/null
+++ b/tests/st/gradient/test_grad_variable.py
@@ -0,0 +1,497 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""test getting gradient of Variable""" +import numpy as np +import pytest +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops.composite import GradOperation +from mindspore.ops import operations as P +from mindspore.common import dtype as mstype +from mindspore import Parameter, Variable + + +def compare(a, b): + if isinstance(a, (list, tuple)): + for aa, bb in zip(a, b): + if not compare(aa, bb): + return False + return True + + return np.allclose(a.asnumpy(), b) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_grad_variable_tuple_tensor(): + """ + Feature: Set Constants mutable. + Description: Get gradient with respect to tuple tensor input. + Expectation: Get the correct gradients. + """ + + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.matmul = P.MatMul() + self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z') + + def construct(self, t): + x = t[0] + y = t[1] + x = x * self.z + out = self.matmul(x, y) + return out + + class GradNetWrtX(nn.Cell): + def __init__(self, net): + super(GradNetWrtX, self).__init__() + self.net = net + self.grad_op = GradOperation() + + def construct(self, z): + gradient_function = self.grad_op(self.net) + return gradient_function(z) + + t = Variable((Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32), + Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32))) + output = GradNetWrtX(Net())(t) + assert isinstance(output, tuple) + expect = [np.array([[1.4100001, 1.5999999, 6.6], + [1.4100001, 1.5999999, 6.6]]).astype(np.float32), + np.array([[1.7, 1.7, 1.7], + [1.9, 1.9, 1.9], + [1.5, 1.5, 1.5]]).astype(np.float32)] + assert compare(output, expect) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_grad_variable_list_tensor(): + """ + Feature: Set Constants mutable. + Description: Get gradient with respect to list tensor input. + Expectation: Get the correct gradients. + """ + + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.matmul = P.MatMul() + self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z') + + def construct(self, t): + x = t[0] + y = t[1] + x = x * self.z + out = self.matmul(x, y) + return out + + class GradNetWrtX(nn.Cell): + def __init__(self, net): + super(GradNetWrtX, self).__init__() + self.net = net + self.grad_op = GradOperation() + + def construct(self, z): + gradient_function = self.grad_op(self.net) + return gradient_function(z) + + t = Variable([Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32), + Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)]) + output = GradNetWrtX(Net())(t) + assert isinstance(output, tuple) + expect = [np.array([[1.4100001, 1.5999999, 6.6], + [1.4100001, 1.5999999, 6.6]]).astype(np.float32), + np.array([[1.7, 1.7, 1.7], + [1.9, 1.9, 1.9], + [1.5, 1.5, 1.5]]).astype(np.float32)] + assert compare(output, expect) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_grad_variable_dict_tensor(): + """ + Feature: Set Constants mutable. + Description: Get gradient with respect to dict tensor input. + Expectation: Get the correct gradients. 
+ """ + + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.matmul = P.MatMul() + self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z') + + def construct(self, t): + x = t['a'] + y = t['b'] + x = x * self.z + out = self.matmul(x, y) + return out + + class GradNetWrtX(nn.Cell): + def __init__(self, net): + super(GradNetWrtX, self).__init__() + self.net = net + self.grad_op = GradOperation() + + def construct(self, z): + gradient_function = self.grad_op(self.net) + return gradient_function(z) + + t = Variable({'a': Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32), + 'b': Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)}) + output = GradNetWrtX(Net())(t) + assert isinstance(output, tuple) + expect = [np.array([[1.4100001, 1.5999999, 6.6], + [1.4100001, 1.5999999, 6.6]]).astype(np.float32), + np.array([[1.7, 1.7, 1.7], + [1.9, 1.9, 1.9], + [1.5, 1.5, 1.5]]).astype(np.float32)] + assert compare(output, expect) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_grad_variable_tuple_tuple_tensor(): + """ + Feature: Set Constants mutable. + Description: Get gradient with respect to nested tuple tensor input. + Expectation: Get the correct gradients. + """ + + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.matmul = P.MatMul() + self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z') + + def construct(self, t): + x = t[0][0] + y = t[1] + x = x * self.z + out = self.matmul(x, y) + return out + + class GradNetWrtX(nn.Cell): + def __init__(self, net): + super(GradNetWrtX, self).__init__() + self.net = net + self.grad_op = GradOperation() + + def construct(self, z): + gradient_function = self.grad_op(self.net) + return gradient_function(z) + + t = Variable(((Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32), + Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)), + Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32))) + output = GradNetWrtX(Net())(t) + assert isinstance(output, tuple) + expect = [[np.array([[1.4100001, 1.5999999, 6.6], + [1.4100001, 1.5999999, 6.6]]).astype(np.float32), np.array([[0, 0, 0], + [0, 0, 0]]).astype(np.float32)], + np.array([[1.7, 1.7, 1.7], + [1.9, 1.9, 1.9], + [1.5, 1.5, 1.5]]).astype(np.float32)] + assert compare(output, expect) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_grad_variable_tuple_list_tensor(): + """ + Feature: Set Constants mutable. + Description: Get gradient with respect to nested tuple and list tensor input. + Expectation: Get the correct gradients. 
+ """ + + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.matmul = P.MatMul() + self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z') + + def construct(self, t): + x = t[0][0] + y = t[1] + x = x * self.z + out = self.matmul(x, y) + return out + + class GradNetWrtX(nn.Cell): + def __init__(self, net): + super(GradNetWrtX, self).__init__() + self.net = net + self.grad_op = GradOperation() + + def construct(self, z): + gradient_function = self.grad_op(self.net) + return gradient_function(z) + + t = Variable(([Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32), + Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)], + Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32))) + output = GradNetWrtX(Net())(t) + assert isinstance(output, tuple) + expect = [[np.array([[1.4100001, 1.5999999, 6.6], + [1.4100001, 1.5999999, 6.6]]).astype(np.float32), np.array([[0, 0, 0], + [0, 0, 0]]).astype(np.float32)], + np.array([[1.7, 1.7, 1.7], + [1.9, 1.9, 1.9], + [1.5, 1.5, 1.5]]).astype(np.float32)] + assert compare(output, expect) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_grad_variable_list_tuple_tensor(): + """ + Feature: Set Constants mutable. + Description: Get gradient with respect to nested list and tuple tensor input. + Expectation: Get the correct gradients. + """ + + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.matmul = P.MatMul() + self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z') + + def construct(self, t): + x = t[0][0] + y = t[1] + x = x * self.z + out = self.matmul(x, y) + return out + + class GradNetWrtX(nn.Cell): + def __init__(self, net): + super(GradNetWrtX, self).__init__() + self.net = net + self.grad_op = GradOperation() + + def construct(self, z): + gradient_function = self.grad_op(self.net) + return gradient_function(z) + + t = Variable([(Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32), + Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)), + Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)]) + output = GradNetWrtX(Net())(t) + assert isinstance(output, tuple) + expect = [[np.array([[1.4100001, 1.5999999, 6.6], + [1.4100001, 1.5999999, 6.6]]).astype(np.float32), np.array([[0, 0, 0], + [0, 0, 0]]).astype(np.float32)], + np.array([[1.7, 1.7, 1.7], + [1.9, 1.9, 1.9], + [1.5, 1.5, 1.5]]).astype(np.float32)] + assert compare(output, expect) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_grad_variable_tuple_dict_tensor(): + """ + Feature: Set Constants mutable. + Description: Get gradient with respect to nested tuple and dict tensor input. + Expectation: Get the correct gradients. 
+ """ + + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.matmul = P.MatMul() + self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z') + + def construct(self, t): + x = t[0]['a'] + y = t[1] + x = x * self.z + out = self.matmul(x, y) + return out + + class GradNetWrtX(nn.Cell): + def __init__(self, net): + super(GradNetWrtX, self).__init__() + self.net = net + self.grad_op = GradOperation() + + def construct(self, z): + gradient_function = self.grad_op(self.net) + return gradient_function(z) + + t = Variable(({'a': Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32), + 'b': Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)}, + Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32))) + output = GradNetWrtX(Net())(t) + assert isinstance(output, tuple) + expect = [[np.array([[1.4100001, 1.5999999, 6.6], + [1.4100001, 1.5999999, 6.6]]).astype(np.float32), np.array([[0, 0, 0], + [0, 0, 0]]).astype(np.float32)], + np.array([[1.7, 1.7, 1.7], + [1.9, 1.9, 1.9], + [1.5, 1.5, 1.5]]).astype(np.float32)] + assert compare(output, expect) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_grad_variable_dict_tuple_tensor(): + """ + Feature: Set Constants mutable. + Description: Get gradient with respect to nested dict and tuple tensor input. + Expectation: Get the correct gradients. + """ + + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.matmul = P.MatMul() + self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z') + + def construct(self, t): + x = t['a'][0] + y = t['b'] + x = x * self.z + out = self.matmul(x, y) + return out + + class GradNetWrtX(nn.Cell): + def __init__(self, net): + super(GradNetWrtX, self).__init__() + self.net = net + self.grad_op = GradOperation() + + def construct(self, z): + gradient_function = self.grad_op(self.net) + return gradient_function(z) + + t = Variable({'a': (Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32), + Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)), + 'b': Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)}) + output = GradNetWrtX(Net())(t) + assert isinstance(output, tuple) + expect = [[np.array([[1.4100001, 1.5999999, 6.6], + [1.4100001, 1.5999999, 6.6]]).astype(np.float32), np.array([[0, 0, 0], + [0, 0, 0]]).astype(np.float32)], + np.array([[1.7, 1.7, 1.7], + [1.9, 1.9, 1.9], + [1.5, 1.5, 1.5]]).astype(np.float32)] + assert compare(output, expect) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_grad_variable_list_dict_tensor(): + """ + Feature: Set Constants mutable. + Description: Get gradient with respect to nested list and dict tensor input. + Expectation: Get the correct gradients. 
+ """ + + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.matmul = P.MatMul() + self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z') + + def construct(self, t): + x = t[0]['a'] + y = t[1] + x = x * self.z + out = self.matmul(x, y) + return out + + class GradNetWrtX(nn.Cell): + def __init__(self, net): + super(GradNetWrtX, self).__init__() + self.net = net + self.grad_op = GradOperation() + + def construct(self, z): + gradient_function = self.grad_op(self.net) + return gradient_function(z) + + t = Variable([{'a': Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32), + 'b': Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)}, + Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)]) + output = GradNetWrtX(Net())(t) + assert isinstance(output, tuple) + expect = [[np.array([[1.4100001, 1.5999999, 6.6], + [1.4100001, 1.5999999, 6.6]]).astype(np.float32), np.array([[0, 0, 0], + [0, 0, 0]]).astype(np.float32)], + np.array([[1.7, 1.7, 1.7], + [1.9, 1.9, 1.9], + [1.5, 1.5, 1.5]]).astype(np.float32)] + assert compare(output, expect) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_grad_variable_dict_list_tensor(): + """ + Feature: Set Constants mutable. + Description: Get gradient with respect to nested dict and list tensor input. + Expectation: Get the correct gradients. + """ + + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.matmul = P.MatMul() + self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z') + + def construct(self, t): + x = t['a'][0] + y = t['b'] + x = x * self.z + out = self.matmul(x, y) + return out + + class GradNetWrtX(nn.Cell): + def __init__(self, net): + super(GradNetWrtX, self).__init__() + self.net = net + self.grad_op = GradOperation() + + def construct(self, z): + gradient_function = self.grad_op(self.net) + return gradient_function(z) + + t = Variable({'a': [Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32), + Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)], + 'b': Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)}) + output = GradNetWrtX(Net())(t) + assert isinstance(output, tuple) + expect = [[np.array([[1.4100001, 1.5999999, 6.6], + [1.4100001, 1.5999999, 6.6]]).astype(np.float32), np.array([[0, 0, 0], + [0, 0, 0]]).astype(np.float32)], + np.array([[1.7, 1.7, 1.7], + [1.9, 1.9, 1.9], + [1.5, 1.5, 1.5]]).astype(np.float32)] + assert compare(output, expect) diff --git a/tests/ut/python/ir/test_variable.py b/tests/ut/python/ir/test_variable.py new file mode 100644 index 00000000000..0fbcaf9353c --- /dev/null +++ b/tests/ut/python/ir/test_variable.py @@ -0,0 +1,195 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""test variable""" + +import numpy as np +from mindspore.ops.composite import GradOperation +from mindspore.common.variable import Variable +from mindspore.common.api import _CellGraphExecutor +from mindspore.ops import operations as P +import mindspore.nn as nn +import mindspore.common.dtype as mstype +from mindspore import Tensor +from mindspore import Parameter + + +def test_variable_scalar_mul_grad_first(): + """ + Feature: Set Constants mutable. + Description: Get gradient with respect to the first scalar input. + Expectation: Get the correct gradient. + """ + + class Net(nn.Cell): + def construct(self, x, y): + return x * y + + class GradNet(nn.Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.net = net + self.grad_op = GradOperation() + + def construct(self, x, y): + gradient_function = self.grad_op(self.net) + return gradient_function(x, y) + + x = Variable(2) + output = GradNet(Net())(x, 3) + assert output == 3 + + +def test_variable_scalar_mul_grad_all(): + """ + Feature: Set Constants mutable. + Description: Get gradient with respect to all scalar inputs. + Expectation: Get the correct gradients. + """ + + class Net(nn.Cell): + def construct(self, x, y): + return x * y + + class GradNet(nn.Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.net = net + self.grad_op = GradOperation(get_all=True) + + def construct(self, x, y): + gradient_function = self.grad_op(self.net) + return gradient_function(x, y) + + x = Variable(2) + y = Variable(3) + output = GradNet(Net())(x, y) + assert output == (3, 2) + + +def test_variable_tuple_or_list_scalar_mul_grad(): + """ + Feature: Set Constants mutable. + Description: Get gradient with respect to the tuple or list scalar input. + Expectation: Get the correct gradients. + """ + + class Net(nn.Cell): + def construct(self, x): + return x[0] * x[1] + + class GradNet(nn.Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.net = net + self.grad_op = GradOperation() + + def construct(self, x): + gradient_function = self.grad_op(self.net) + return gradient_function(x) + + x = Variable((2, 3)) + output = GradNet(Net())(x) + assert output == (3, 2) + + x = Variable([2, 3]) + output = GradNet(Net())(x) + assert output == (3, 2) + + +def test_variable_dict_scalar_mul_grad(): + """ + Feature: Set Constants mutable. + Description: Get gradient with respect to the dict scalar input. + Expectation: Get the correct gradients. + """ + + class Net(nn.Cell): + def construct(self, x): + return x['a'] * x['b'] + + class GradNet(nn.Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.net = net + self.grad_op = GradOperation() + + def construct(self, x): + gradient_function = self.grad_op(self.net) + return gradient_function(x) + + x = Variable({'a': 2, 'b': 3}) + output = GradNet(Net())(x) + assert output == (3, 2) + + +def test_variable_mix_scalar_mul_grad_all(): + """ + Feature: Set Constants mutable. + Description: Get gradient with respect to the mix scalar input including dict and tuple. + Expectation: Get the correct gradients. 
+ """ + + class Net(nn.Cell): + def construct(self, x, y): + return x['a'] * x['b'] * y[0] + + class GradNet(nn.Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.net = net + self.grad_op = GradOperation(get_all=True) + + def construct(self, x, y): + gradient_function = self.grad_op(self.net) + return gradient_function(x, y) + + x = Variable({'a': 2, 'b': 3}) + y = Variable((4, 5)) + output = GradNet(Net())(x, y) + assert output == ((12, 8), (6, 0)) + + +def test_tuple_inputs_compile_phase(): + """ + Feature: Set Constants mutable. + Description: Test whether the compilation phase for tuple input twice are the same. + Expectation: The phases are the same. + """ + + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.matmul = P.MatMul() + self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z') + + def construct(self, tuple_input): + x = tuple_input[0] + y = tuple_input[1] + x = x * self.z + out = self.matmul(x, y) + return out + + x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32) + y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32) + p = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32) + q = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32) + net = Net() + _cell_graph_executor = _CellGraphExecutor() + phase1, _ = _cell_graph_executor.compile(net, (x, y)) + phase2, _ = _cell_graph_executor.compile(net, (p, q)) + assert phase1 != phase2 + phase1, _ = _cell_graph_executor.compile(net, Variable((x, y))) + phase2, _ = _cell_graph_executor.compile(net, Variable((p, q))) + assert phase1 == phase2