!32373 Add class Variable to set constant mutable

Merge pull request !32373 from YuJianfeng/mutable
i-robot 2022-04-02 07:35:02 +00:00 committed by Gitee
commit b12fefbded
17 changed files with 1152 additions and 67 deletions
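
For readers unfamiliar with the feature, the Python docstring added later in this PR shows the intended usage: wrapping a constant input in Variable turns it into a variable, differentiable input of a Cell. Below is a minimal sketch adapted from that docstring (the GradOperation-based GradNet wrapper and the scalar values are taken from the example there; this is an illustration, not part of the diff):

import mindspore.nn as nn
from mindspore.ops.composite import GradOperation
from mindspore.common.variable import Variable

class Net(nn.Cell):
    def construct(self, x, y):
        # With x wrapped in Variable, the gradient w.r.t. x is simply y.
        return x * y

class GradNet(nn.Cell):
    def __init__(self, net):
        super(GradNet, self).__init__()
        self.net = net
        self.grad_op = GradOperation()

    def construct(self, x, y):
        gradient_function = self.grad_op(self.net)
        return gradient_function(x, y)

x = Variable(2)                # the constant 2 is now mutable and differentiable
output = GradNet(Net())(x, 3)
print(output)                  # prints 3, i.e. d(x*y)/dx at y=3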


@ -802,6 +802,38 @@ void FlatValueTupleValue(const ValuePtrList &value, ValuePtrList *flatted_value)
}
}
void FlattenValue(const BaseRef &arg, ValuePtrList *flatted_value) {
if (utils::isa<ValueSequencePtr>(arg)) {
auto value_sequence = utils::cast<ValueSequencePtr>(arg);
MS_EXCEPTION_IF_NULL(value_sequence);
auto sequence_value = value_sequence->value();
for (auto &value : sequence_value) {
MS_EXCEPTION_IF_NULL(value);
if (value->isa<tensor::Tensor>()) {
(void)flatted_value->emplace_back(value);
} else {
FlattenValue(value, flatted_value);
}
}
} else if (utils::isa<ValueDictionaryPtr>(arg)) {
auto value_dict = utils::cast<ValueDictionaryPtr>(arg);
MS_EXCEPTION_IF_NULL(value_dict);
auto dict_value = value_dict->value();
for (auto &iter : dict_value) {
auto value = iter.second;
MS_EXCEPTION_IF_NULL(value);
if (value->isa<tensor::Tensor>()) {
(void)flatted_value->emplace_back(value);
} else {
FlattenValue(value, flatted_value);
}
}
} else {
MS_LOG(EXCEPTION) << "The value input to flatten should only contains be sequence or dictionary, but it is "
<< arg.ToString();
}
}
void PushTupleTensor(const VectorRef &args, const std::vector<AnfNodePtr> &parameters, const AnfNodePtr &front_node,
size_t index, std::vector<tensor::TensorPtr> *input_tensor) {
const auto &iter = std::find(parameters.begin(), parameters.end(), front_node);
@ -814,11 +846,8 @@ void PushTupleTensor(const VectorRef &args, const std::vector<AnfNodePtr> &param
(void)input_tensor->emplace_back(nullptr);
return;
}
auto value_tuple = utils::cast<ValueTuplePtr>(args[position]);
MS_EXCEPTION_IF_NULL(value_tuple);
auto value_tuple_value = value_tuple->value();
ValuePtrList flatted_value_tuple_value;
FlatValueTupleValue(value_tuple_value, &flatted_value_tuple_value);
FlattenValue(args[position], &flatted_value_tuple_value);
if (index >= flatted_value_tuple_value.size()) {
MS_LOG(EXCEPTION) << "Index out of flatted_value_tuple_value range, index value is " << index
<< " and flatted_value_tuple_value size is " << flatted_value_tuple_value.size() << ".";


@ -358,6 +358,7 @@ REGISTER_PYBIND_DEFINE(HyperMap_, ([](const py::module *m) {
.def(py::init<bool>(), py::arg("reverse"));
}));
namespace {
bool CheckSequenceAllTensor(const abstract::AbstractTuplePtr &tuple) {
MS_EXCEPTION_IF_NULL(tuple);
for (size_t i = 0; i < tuple->size(); ++i) {
@ -370,17 +371,23 @@ bool CheckSequenceAllTensor(const abstract::AbstractTuplePtr &tuple) {
return true;
}
bool EnableGradForScalar(const abstract::AbstractBasePtr &abs) {
return MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && abs->BuildType() != nullptr &&
abs->BuildType()->isa<Number>();
}
bool EnableGradForTuple(const abstract::AbstractBasePtr &abs, bool enable_tuple_grad) {
return abs->isa<abstract::AbstractTuple>() && enable_tuple_grad &&
CheckSequenceAllTensor(abs->cast<abstract::AbstractTuplePtr>());
}
bool CheckTailGradFristSequence(const abstract::AbstractSequencePtr &sequeue, bool enable_tuple_grad) {
MS_EXCEPTION_IF_NULL(sequeue);
return sequeue->size() > 1 && (*sequeue)[1] != nullptr &&
((*sequeue)[1]->isa<abstract::AbstractUndetermined>() ||
(MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && (*sequeue)[1]->BuildType() != nullptr &&
(*sequeue)[1]->BuildType()->isa<Number>()) ||
((*sequeue)[1]->isa<abstract::AbstractTuple>() && enable_tuple_grad &&
CheckSequenceAllTensor((*sequeue)[1]->cast<abstract::AbstractTuplePtr>())));
((*sequeue)[1]->isa<abstract::AbstractUndetermined>() || (*sequeue)[1]->BuildValue() == kAnyValue ||
EnableGradForScalar((*sequeue)[1]) || EnableGradForTuple((*sequeue)[1], enable_tuple_grad));
}
namespace {
void GenerateSequenceFuncGraphByPosition(const FuncGraphPtr &res, const abstract::AbstractSequencePtr &sequeue,
const abstract::AbstractSequencePtr &pos, bool enable_tuple_grad) {
if (pos == nullptr) {
@ -470,9 +477,8 @@ FuncGraphPtr Tail::GenerateSequenceFuncGraph(const abstract::AbstractSequencePtr
for (size_t i = 1; i < sequeue->size(); ++i) {
if (tail_type_ == kGradAll) {
MS_EXCEPTION_IF_NULL((*sequeue)[i]);
if ((*sequeue)[i]->isa<abstract::AbstractUndetermined>() ||
(MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && (*sequeue)[i]->BuildType() != nullptr &&
(*sequeue)[i]->BuildType()->isa<Number>())) {
if ((*sequeue)[i]->isa<abstract::AbstractUndetermined>() || (*sequeue)[i]->BuildValue() == kAnyValue ||
EnableGradForScalar((*sequeue)[i])) {
elements.push_back(res->NewCNodeInOrder({NewValueNode(op), tuple_parameter, NewValueNode(SizeToLong(i))}));
}
} else {


@ -43,6 +43,7 @@ using mindspore::abstract::AbstractBasePtr;
using mindspore::abstract::AbstractClass;
using mindspore::abstract::AbstractCOOTensor;
using mindspore::abstract::AbstractDictionary;
using mindspore::abstract::AbstractDictionaryPtr;
using mindspore::abstract::AbstractList;
using mindspore::abstract::AbstractListPtr;
using mindspore::abstract::AbstractRowTensor;
@ -53,6 +54,7 @@ using mindspore::abstract::AbstractTuple;
using mindspore::abstract::AbstractTuplePtr;
namespace {
static constexpr size_t kMaxSeqRecursiveDepth = 5;
void CheckInputsSize(const CNodePtr &cnode, size_t expect_size) {
if (cnode->size() != expect_size) {
std::string op_name = GetCNodeFuncName(cnode);
@ -450,11 +452,60 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
return std::make_shared<AbstractTuple>(std::move(elements));
}
AbstractBasePtr ConvertAbstract(const AbstractBasePtr &abs) override {
// AbstractDictionary, AbstractClass --> AbstractSequence.
static AbstractSequencePtr ConvertToAbstractSequence(const AbstractBasePtr &abs, size_t depth) {
if (depth > kMaxSeqRecursiveDepth) {
MS_LOG(EXCEPTION) << "List or Dict nesting is not allowed more than " << kMaxSeqRecursiveDepth << " levels.";
}
auto abs_seq = abs->cast<AbstractSequencePtr>();
if (abs_seq != nullptr) {
const auto &seq_elements = abs_seq->elements();
// First we check if elements should be converted,
// changed_elements maps old element to new element.
mindspore::HashMap<AbstractBasePtr, AbstractBasePtr> changed_elements;
for (const auto &element : seq_elements) {
auto new_element = ConvertToAbstractSequence(element, depth + 1);
if (new_element != nullptr) {
(void)changed_elements.emplace(element, new_element);
}
}
if (changed_elements.empty()) {
// Here the AbstractList does not need to be converted to AbstractTuple.
return nullptr;
}
// Always make new AbstractSequence when elements changed.
std::vector<AbstractBasePtr> elements;
elements.reserve(seq_elements.size());
for (const auto &element : seq_elements) {
auto iter = changed_elements.find(element);
if (iter != changed_elements.end()) {
(void)elements.emplace_back(iter->second);
} else {
(void)elements.emplace_back(element);
}
}
// Here the AbstractList does not need to be converted to AbstractTuple.
if (abs_seq->isa<AbstractList>()) {
return std::make_shared<AbstractList>(std::move(elements));
} else {
return std::make_shared<AbstractTuple>(std::move(elements));
}
}
// AbstractDictionary --> AbstractTuple.
auto abs_dict = abs->cast<abstract::AbstractDictionaryPtr>();
auto abs_dict = abs->cast<AbstractDictionaryPtr>();
if (abs_dict != nullptr) {
return MakeAbstractTuple(abs_dict->elements());
const auto &dict_elements = abs_dict->elements();
std::vector<AbstractBasePtr> elements;
elements.reserve(dict_elements.size());
for (const auto &element : dict_elements) {
auto new_element = ConvertToAbstractSequence(element.second, depth + 1);
if (new_element != nullptr) {
(void)elements.emplace_back(new_element);
} else {
(void)elements.emplace_back(element.second);
}
}
return std::make_shared<AbstractTuple>(elements);
}
// AbstractClass --> AbstractTuple.
auto abs_class = abs->cast<abstract::AbstractClassPtr>();
@ -463,6 +514,11 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
}
return nullptr;
}
AbstractBasePtr ConvertAbstract(const AbstractBasePtr &abs) override {
// AbstractDictionary, AbstractClass --> AbstractSequence.
return ConvertToAbstractSequence(abs, 0);
}
};
// ==================================================================
@ -597,12 +653,10 @@ class CleanAfterOptARewriter : public BaseRewriter {
return (this->*(iter->second))(cnode);
}
static constexpr size_t kMaxListRecursiveDepth = 5;
// ValueList --> ValueTuple
static ValueTuplePtr ConvertValueListToValueTuple(const ValueListPtr &value_list, size_t depth) {
if (depth > kMaxListRecursiveDepth) {
MS_LOG(EXCEPTION) << "List nesting is not allowed more than " << kMaxListRecursiveDepth << " levels.";
if (depth > kMaxSeqRecursiveDepth) {
MS_LOG(EXCEPTION) << "List nesting is not allowed more than " << kMaxSeqRecursiveDepth << " levels.";
}
const auto &list_elements = value_list->value();
std::vector<ValuePtr> elements;
@ -625,50 +679,60 @@ class CleanAfterOptARewriter : public BaseRewriter {
return nullptr;
}
// AbstractSequence --> AbstractTuple
static AbstractTuplePtr ConvertAbstractSeqToAbstractTuple(const AbstractSequencePtr &abs_seq, size_t depth) {
if (depth > kMaxListRecursiveDepth) {
MS_LOG(EXCEPTION) << "List nesting is not allowed more than " << kMaxListRecursiveDepth << " levels.";
// AbstractSequence, AbstractDict, AbstractCOOTensor, AbstractRowTensor --> AbstractTuple.
static AbstractTuplePtr ConvertToAbstractTuple(const AbstractBasePtr &abs, size_t depth) {
if (depth > kMaxSeqRecursiveDepth) {
MS_LOG(EXCEPTION) << "List or Dict nesting is not allowed more than " << kMaxSeqRecursiveDepth << " levels.";
}
const auto &seq_elements = abs_seq->elements();
// First we check if elements should be converted,
// changed_elements maps old element to new element.
mindspore::HashMap<AbstractBasePtr, AbstractBasePtr> changed_elements;
for (const auto &element : seq_elements) {
if (element->isa<AbstractSequence>()) {
auto new_element = ConvertAbstractSeqToAbstractTuple(element->cast<AbstractSequencePtr>(), depth + 1);
// AbstractList --> AbstractTuple.
auto abs_seq = abs->cast<AbstractSequencePtr>();
if (abs_seq != nullptr) {
const auto &seq_elements = abs_seq->elements();
// First we check if elements should be converted,
// changed_elements maps old element to new element.
mindspore::HashMap<AbstractBasePtr, AbstractBasePtr> changed_elements;
for (const auto &element : seq_elements) {
auto new_element = ConvertToAbstractTuple(element, depth + 1);
if (new_element != nullptr) {
(void)changed_elements.emplace(element, new_element);
}
}
}
if (changed_elements.empty()) {
if (abs_seq->isa<AbstractTuple>()) {
// If no elements changed and it is an AbstractTuple, do not convert.
return nullptr;
if (changed_elements.empty()) {
if (abs->isa<AbstractTuple>()) {
// If no elements changed and it is an AbstractTuple, do not convert.
return nullptr;
}
// If no elements changed but it is not an AbstractTuple, convert it by copy elements.
return std::make_shared<AbstractTuple>(seq_elements);
}
// If no elements changed but it is not an AbstractTuple, convert it by copy elements.
return std::make_shared<AbstractTuple>(seq_elements);
}
// Always make new AbstractTuple when elements changed.
std::vector<AbstractBasePtr> elements;
elements.reserve(seq_elements.size());
for (const auto &element : seq_elements) {
auto iter = changed_elements.find(element);
if (iter != changed_elements.end()) {
(void)elements.emplace_back(iter->second);
} else {
(void)elements.emplace_back(element);
// Always make new AbstractTuple when elements changed.
std::vector<AbstractBasePtr> elements;
elements.reserve(seq_elements.size());
for (const auto &element : seq_elements) {
auto iter = changed_elements.find(element);
if (iter != changed_elements.end()) {
(void)elements.emplace_back(iter->second);
} else {
(void)elements.emplace_back(element);
}
}
return std::make_shared<AbstractTuple>(std::move(elements));
}
return std::make_shared<AbstractTuple>(std::move(elements));
}
AbstractBasePtr ConvertAbstract(const AbstractBasePtr &abs) override {
// AbstractSequence --> AbstractTuple.
auto abs_seq = abs->cast<AbstractSequencePtr>();
if (abs_seq != nullptr) {
return ConvertAbstractSeqToAbstractTuple(abs_seq, 0);
// AbstractDict --> AbstractTuple.
auto abs_dict = abs->cast<AbstractDictionaryPtr>();
if (abs_dict != nullptr) {
const auto &dict_elements = abs_dict->elements();
std::vector<AbstractBasePtr> elements;
elements.reserve(dict_elements.size());
for (const auto &element : dict_elements) {
auto new_element = ConvertToAbstractTuple(element.second, depth + 1);
if (new_element != nullptr) {
(void)elements.emplace_back(new_element);
} else {
(void)elements.emplace_back(element.second);
}
}
return std::make_shared<AbstractTuple>(elements);
}
// AbstractCOOTensor --> AbstractTuple.
auto abs_sparse = abs->cast<abstract::AbstractCOOTensorPtr>();
@ -685,6 +749,11 @@ class CleanAfterOptARewriter : public BaseRewriter {
}
return nullptr;
}
AbstractBasePtr ConvertAbstract(const AbstractBasePtr &abs) override {
// AbstractSequence, AbstractDict, AbstractCOOTensor, AbstractRowTensor --> AbstractTuple.
return ConvertToAbstractTuple(abs, 0);
}
};
} // namespace


@ -66,6 +66,15 @@ namespace pipeline {
namespace {
bool ExistControlFlow(const FuncGraphPtr &func_graph) { return !func_graph->func_graphs_used_total().empty(); }
bool EnableGradForScalar(const abstract::AbstractBasePtr &abs) {
return MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && abs->BuildType() != nullptr &&
abs->BuildType()->isa<Number>();
}
bool EnableTupleBroaden(const abstract::AbstractBasePtr &abs) {
return abs->isa<abstract::AbstractTuple>() && abs->cast<abstract::AbstractTuplePtr>()->ContainsAllBroadenTensors();
}
void UpdateFuncGraphParameter(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
std::vector<AnfNodePtr> new_paras;
@ -78,11 +87,8 @@ void UpdateFuncGraphParameter(const FuncGraphPtr &func_graph) {
}
AbstractBasePtr par_abs = param_node->abstract();
MS_EXCEPTION_IF_NULL(par_abs);
if (par_abs->isa<abstract::AbstractUndetermined>() ||
(MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && par_abs->BuildType() != nullptr &&
par_abs->BuildType()->isa<Number>()) ||
(par_abs->isa<abstract::AbstractTuple>() &&
par_abs->cast<abstract::AbstractTuplePtr>()->ContainsAllBroadenTensors())) {
if (par_abs->isa<abstract::AbstractUndetermined>() || par_abs->BuildValue() == kAnyValue ||
EnableGradForScalar(par_abs) || EnableTupleBroaden(par_abs)) {
new_paras.push_back(param_node);
}
}


@ -32,6 +32,7 @@
#include "utils/symbolic.h"
#include "utils/ms_context.h"
#include "include/common/utils/utils.h"
#include "ir/variable.h"
namespace mindspore {
namespace parse {
@ -506,6 +507,7 @@ static const std::vector<DataConverterPtr> &GetDataConverters() {
std::make_shared<ByTypeDataConverter<MetaTensor>>(ObjCast<MetaTensorPtr>),
std::make_shared<ByTypeDataConverter<CSRTensor>>(ObjCast<CSRTensorPtr>),
std::make_shared<ByTypeDataConverter<COOTensor>>(ObjCast<COOTensorPtr>),
std::make_shared<ByTypeDataConverter<Variable>>(ObjCast<VariablePtr>),
std::make_shared<ByTypeDataConverter<py::tuple>>(ConvertTuple),
std::make_shared<ByTypeDataConverter<py::list>>(ConvertList),
std::make_shared<ByTypeDataConverter<py::bool_>>(PyCast<BoolImm, bool>),


@ -29,6 +29,7 @@
#include "utils/hash_map.h"
#include "pybind_api/pybind_patch.h"
#include "ir/param_info.h"
#include "ir/variable.h"
#include "pipeline/jit/pass.h"
#include "pipeline/jit/parse/data_converter.h"
#include "pipeline/jit/static_analysis/async_eval_result.h"
@ -149,7 +150,7 @@ bool CheckAllTensor(const ValueTuplePtr &value_tuple) {
AbstractBasePtr ArgsToAbstract(const ValuePtr &value, bool enable_tuple_broaden = false) {
MS_EXCEPTION_IF_NULL(value);
bool broaden = value->isa<MetaTensor>() ||
bool broaden = value->isa<MetaTensor>() || value->isa<Variable>() ||
(enable_tuple_broaden && value->isa<ValueTuple>() && CheckAllTensor(value->cast<ValueTuplePtr>())) ||
(MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && value->isa<Scalar>());
@ -167,6 +168,14 @@ bool CheckArgValid(const py::handle &arg) {
return std::all_of(dict_arg.begin(), dict_arg.end(), [](const auto &pair) { return CheckArgValid(pair.second); });
}
if (py::isinstance<Variable>(arg)) {
if (py::hasattr(arg, "value")) {
return CheckArgValid(arg.attr("value"));
}
MS_LOG(ERROR) << "There should be a python object value stored in the Variable " << py::str(arg);
return false;
}
if (py::isinstance<Tensor>(arg)) {
auto tensor = py::cast<TensorPtr>(arg);
if (tensor->data_type() == kNumberTypeBool) {


@ -0,0 +1,36 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "include/common/pybind_api/api_register.h"
#include "ir/variable.h"
#include "pipeline/jit/parse/data_converter.h"
namespace py = pybind11;
namespace mindspore {
REGISTER_PYBIND_DEFINE(Variable_, ([](const py::module *m) {
(void)py::class_<Variable, VariablePtr>(*m, "Variable_")
.def(py::init([](const py::object &py_value) {
ValuePtr real_value = nullptr;
if (!parse::ConvertData(py_value, &real_value)) {
MS_EXCEPTION(TypeError)
<< "Convert python object failed, the object type is " << py_value.get_type()
<< ", value is '" << py::str(py_value) << "'.";
}
return std::make_shared<Variable>(real_value);
}),
py::arg("py_value"));
}));
} // namespace mindspore


@ -158,7 +158,7 @@ std::string AbstractBase::ToString(bool verbose) const {
AbstractBasePtr AbstractScalar::Broaden() const {
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
if (context->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR)) {
if (context->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) || value_mutable()) {
return AbstractBase::Broaden();
}
auto type_id = GetTypeTrack()->type_id();


@ -187,6 +187,10 @@ class MS_CORE_API AbstractBase : public Base {
/// \return A pointer to the broadened abstract.
virtual AbstractBasePtr PartialBroaden() const;
bool value_mutable() const { return value_mutable_; }
void set_value_mutable(bool value_mutable) { value_mutable_ = value_mutable; }
protected:
/// \brief Build a value when value is not set.
///
@ -198,6 +202,7 @@ class MS_CORE_API AbstractBase : public Base {
TypePtr type_;
BaseShapePtr shape_;
std::string value_desc_; // store initial value description for error report
bool value_mutable_{false};
};
/// \brief Class AbstractScalar describes a scalar's type and value.


@ -0,0 +1,89 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ir/variable.h"
#include "abstract/abstract_value.h"
namespace mindspore {
namespace {
void SetValueMutable(const abstract::AbstractBasePtr &abs) {
MS_EXCEPTION_IF_NULL(abs);
if (abs->isa<abstract::AbstractTensor>()) {
return;
}
auto abs_sequence = abs->cast<abstract::AbstractSequencePtr>();
if (abs_sequence != nullptr) {
const auto &elements = abs_sequence->elements();
for (auto &ele : elements) {
SetValueMutable(ele);
}
return;
}
auto abs_dict = abs->cast<abstract::AbstractDictionaryPtr>();
if (abs_dict != nullptr) {
const auto &elements = abs_dict->elements();
for (auto &ele : elements) {
SetValueMutable(ele.second);
}
return;
}
abs->set_value_mutable(true);
}
} // namespace
abstract::AbstractBasePtr Variable::ToAbstract() {
if (real_value_ == nullptr) {
MS_LOG(EXCEPTION) << "Get abstract failed. The real value of Variable has not been set.";
}
auto abs = real_value_->ToAbstract();
SetValueMutable(abs);
return abs;
}
bool Variable::operator==(const Variable &other) const {
if (this == &other) {
return true;
}
auto other_real_value = other.real_value();
if (real_value_ == nullptr || other_real_value == nullptr) {
return false;
}
return *real_value_ == *other_real_value;
}
std::string Variable::ToString() const {
std::ostringstream oss;
if (real_value_ == nullptr) {
oss << "Variable(NULL)";
} else {
oss << "Variable(" << real_value_->ToString() << ")";
}
return oss.str();
}
std::string Variable::DumpText() const {
std::ostringstream oss;
if (real_value_ == nullptr) {
oss << type_name() << "(NULL)";
} else {
oss << type_name() << "(" << real_value_->DumpText() << ")";
}
return oss.str();
}
} // namespace mindspore


@ -0,0 +1,55 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_IR_VARIABLE_H_
#define MINDSPORE_CORE_IR_VARIABLE_H_
#include <memory>
#include <string>
#include "ir/anf.h"
namespace mindspore {
class MS_CORE_API Variable : public Value {
public:
explicit Variable(const ValuePtr &real_value) : real_value_(real_value) {}
~Variable() override = default;
MS_DECLARE_PARENT(Variable, Value)
abstract::AbstractBasePtr ToAbstract() override;
const ValuePtr &real_value() const { return real_value_; }
bool operator==(const Variable &other) const;
bool operator==(const Value &other) const override {
if (other.isa<Variable>()) {
auto other_variable = static_cast<const Variable &>(other);
return *this == other_variable;
}
return false;
}
std::string ToString() const override;
std::string DumpText() const override;
private:
ValuePtr real_value_{nullptr};
};
using VariablePtr = std::shared_ptr<Variable>;
} // namespace mindspore
#endif // MINDSPORE_CORE_IR_VARIABLE_H_


@ -52,6 +52,14 @@ def TupleGetItem(x, index):
x = x.asnumpy()
y = x[index]
return Tensor(y)
if isinstance(x, dict):
count = 0
for value in x.values():
if count == index:
return value
count = count + 1
return x[index]


@ -25,6 +25,7 @@ from .dump import set_dump
from .parameter import Parameter, ParameterTuple
from .seed import set_seed, get_seed
from .tensor import Tensor, RowTensor, SparseTensor, COOTensor, CSRTensor
from .variable import Variable
# symbols from dtype
__all__ = [
@ -59,5 +60,6 @@ __all__.extend([
"dtype", "_convert_data",
"set_seed", "get_seed", # random seed
"set_dump",
"ms_memory_recycle"
"ms_memory_recycle",
"Variable"
])


@ -0,0 +1,76 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Variable class for setting constants mutable."""
from .._c_expression import Variable_
from ..common.tensor import Tensor, CSRTensor, COOTensor
class Variable(Variable_):
"""
Currently, all inputs of a Cell other than Tensor, such as scalar, tuple, list and dict, are regarded as constant
values. Constant values are non-differentiable and are used for constant folding during optimization.
We provide the class 'Variable' to store a constant value and make such constant inputs of a Cell 'mutable'.
A 'mutable' constant input is treated as a variable input just like a Tensor and, most importantly, it becomes
differentiable.
.. warning::
This is an experimental prototype that is subject to change or deletion.
Args:
value (Union[bool, float, int, tuple, list, dict, Tensor, COOTensor, CSRTensor]): The value to be stored.
Examples:
>>> import mindspore.nn as nn
>>> from mindspore.ops.composite import GradOperation
>>> from mindspore.common.variable import Variable
>>> class Net(nn.Cell):
... def __init__(self):
... super(Net, self).__init__()
...
... def construct(self, x, y):
... return x * y
...
>>> class GradNet(nn.Cell):
... def __init__(self, net):
... super(GradNet, self).__init__()
... self.net = net
... self.grad_op = GradOperation()
...
... def construct(self, x, y):
... gradient_function = self.grad_op(self.net)
... return gradient_function(x, y)
...
>>> x = Variable(2)
>>> output = GradNet(Net())(x, 3)
>>> print(output)
3
"""
def __init__(self, value):
if not isinstance(value, (bool, int, float, tuple, list, dict, Tensor, COOTensor, CSRTensor)):
raise TypeError(
f"For 'Varibale', the 'value' should be one of (int, float, tuple, list, dict, Tensor, COOTensor, "
f"CSRTensor), but got {type(value).__name__}")
Variable_.__init__(self, value)
self._value = value
@property
def value(self):
return self._value
@value.setter
def value(self, value):
self._value = value


@ -33,6 +33,7 @@ from .._checkparam import Validator
from ..common import dtype as mstype
from ..common.api import _cell_graph_executor, _pynative_executor, _check_all_tensor, cells_compile_cache
from ..common.parameter import Parameter, ParameterTuple
from ..common.variable import Variable
from ..common.tensor import Tensor, CSRTensor, COOTensor
from ..ops.operations import Cast
from ..ops.primitive import Primitive
@ -967,10 +968,10 @@ class Cell(Cell_):
if i.has_init:
i.init_data()
new_inputs.append(i)
elif isinstance(i, COOTensor):
new_inputs.append(i)
elif isinstance(i, CSRTensor):
elif isinstance(i, (COOTensor, CSRTensor)):
new_inputs.append(i)
elif isinstance(i, Variable):
new_inputs.append(i.value)
elif context.get_context("grad_for_scalar") and isinstance(i, (int, float)):
new_inputs.append(i)
elif hasattr(self, "enable_tuple_broaden") and self.enable_tuple_broaden and isinstance(i, tuple) and \


@ -0,0 +1,497 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test getting gradient of Variable"""
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops.composite import GradOperation
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
from mindspore import Parameter, Variable
def compare(a, b):
if isinstance(a, (list, tuple)):
for aa, bb in zip(a, b):
if not compare(aa, bb):
return False
return True
return np.allclose(a.asnumpy(), b)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_variable_tuple_tensor():
"""
Feature: Set Constants mutable.
Description: Get gradient with respect to tuple tensor input.
Expectation: Get the correct gradients.
"""
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.matmul = P.MatMul()
self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
def construct(self, t):
x = t[0]
y = t[1]
x = x * self.z
out = self.matmul(x, y)
return out
class GradNetWrtX(nn.Cell):
def __init__(self, net):
super(GradNetWrtX, self).__init__()
self.net = net
self.grad_op = GradOperation()
def construct(self, z):
gradient_function = self.grad_op(self.net)
return gradient_function(z)
t = Variable((Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)))
output = GradNetWrtX(Net())(t)
assert isinstance(output, tuple)
expect = [np.array([[1.4100001, 1.5999999, 6.6],
[1.4100001, 1.5999999, 6.6]]).astype(np.float32),
np.array([[1.7, 1.7, 1.7],
[1.9, 1.9, 1.9],
[1.5, 1.5, 1.5]]).astype(np.float32)]
assert compare(output, expect)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_variable_list_tensor():
"""
Feature: Set Constants mutable.
Description: Get gradient with respect to list tensor input.
Expectation: Get the correct gradients.
"""
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.matmul = P.MatMul()
self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
def construct(self, t):
x = t[0]
y = t[1]
x = x * self.z
out = self.matmul(x, y)
return out
class GradNetWrtX(nn.Cell):
def __init__(self, net):
super(GradNetWrtX, self).__init__()
self.net = net
self.grad_op = GradOperation()
def construct(self, z):
gradient_function = self.grad_op(self.net)
return gradient_function(z)
t = Variable([Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)])
output = GradNetWrtX(Net())(t)
assert isinstance(output, tuple)
expect = [np.array([[1.4100001, 1.5999999, 6.6],
[1.4100001, 1.5999999, 6.6]]).astype(np.float32),
np.array([[1.7, 1.7, 1.7],
[1.9, 1.9, 1.9],
[1.5, 1.5, 1.5]]).astype(np.float32)]
assert compare(output, expect)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_variable_dict_tensor():
"""
Feature: Set Constants mutable.
Description: Get gradient with respect to dict tensor input.
Expectation: Get the correct gradients.
"""
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.matmul = P.MatMul()
self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
def construct(self, t):
x = t['a']
y = t['b']
x = x * self.z
out = self.matmul(x, y)
return out
class GradNetWrtX(nn.Cell):
def __init__(self, net):
super(GradNetWrtX, self).__init__()
self.net = net
self.grad_op = GradOperation()
def construct(self, z):
gradient_function = self.grad_op(self.net)
return gradient_function(z)
t = Variable({'a': Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
'b': Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)})
output = GradNetWrtX(Net())(t)
assert isinstance(output, tuple)
expect = [np.array([[1.4100001, 1.5999999, 6.6],
[1.4100001, 1.5999999, 6.6]]).astype(np.float32),
np.array([[1.7, 1.7, 1.7],
[1.9, 1.9, 1.9],
[1.5, 1.5, 1.5]]).astype(np.float32)]
assert compare(output, expect)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_variable_tuple_tuple_tensor():
"""
Feature: Set Constants mutable.
Description: Get gradient with respect to nested tuple tensor input.
Expectation: Get the correct gradients.
"""
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.matmul = P.MatMul()
self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
def construct(self, t):
x = t[0][0]
y = t[1]
x = x * self.z
out = self.matmul(x, y)
return out
class GradNetWrtX(nn.Cell):
def __init__(self, net):
super(GradNetWrtX, self).__init__()
self.net = net
self.grad_op = GradOperation()
def construct(self, z):
gradient_function = self.grad_op(self.net)
return gradient_function(z)
t = Variable(((Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)),
Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)))
output = GradNetWrtX(Net())(t)
assert isinstance(output, tuple)
expect = [[np.array([[1.4100001, 1.5999999, 6.6],
[1.4100001, 1.5999999, 6.6]]).astype(np.float32), np.array([[0, 0, 0],
[0, 0, 0]]).astype(np.float32)],
np.array([[1.7, 1.7, 1.7],
[1.9, 1.9, 1.9],
[1.5, 1.5, 1.5]]).astype(np.float32)]
assert compare(output, expect)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_variable_tuple_list_tensor():
"""
Feature: Set Constants mutable.
Description: Get gradient with respect to nested tuple and list tensor input.
Expectation: Get the correct gradients.
"""
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.matmul = P.MatMul()
self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
def construct(self, t):
x = t[0][0]
y = t[1]
x = x * self.z
out = self.matmul(x, y)
return out
class GradNetWrtX(nn.Cell):
def __init__(self, net):
super(GradNetWrtX, self).__init__()
self.net = net
self.grad_op = GradOperation()
def construct(self, z):
gradient_function = self.grad_op(self.net)
return gradient_function(z)
t = Variable(([Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)],
Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)))
output = GradNetWrtX(Net())(t)
assert isinstance(output, tuple)
expect = [[np.array([[1.4100001, 1.5999999, 6.6],
[1.4100001, 1.5999999, 6.6]]).astype(np.float32), np.array([[0, 0, 0],
[0, 0, 0]]).astype(np.float32)],
np.array([[1.7, 1.7, 1.7],
[1.9, 1.9, 1.9],
[1.5, 1.5, 1.5]]).astype(np.float32)]
assert compare(output, expect)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_variable_list_tuple_tensor():
"""
Feature: Set Constants mutable.
Description: Get gradient with respect to nested list and tuple tensor input.
Expectation: Get the correct gradients.
"""
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.matmul = P.MatMul()
self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
def construct(self, t):
x = t[0][0]
y = t[1]
x = x * self.z
out = self.matmul(x, y)
return out
class GradNetWrtX(nn.Cell):
def __init__(self, net):
super(GradNetWrtX, self).__init__()
self.net = net
self.grad_op = GradOperation()
def construct(self, z):
gradient_function = self.grad_op(self.net)
return gradient_function(z)
t = Variable([(Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)),
Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)])
output = GradNetWrtX(Net())(t)
assert isinstance(output, tuple)
expect = [[np.array([[1.4100001, 1.5999999, 6.6],
[1.4100001, 1.5999999, 6.6]]).astype(np.float32), np.array([[0, 0, 0],
[0, 0, 0]]).astype(np.float32)],
np.array([[1.7, 1.7, 1.7],
[1.9, 1.9, 1.9],
[1.5, 1.5, 1.5]]).astype(np.float32)]
assert compare(output, expect)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_variable_tuple_dict_tensor():
"""
Feature: Set Constants mutable.
Description: Get gradient with respect to nested tuple and dict tensor input.
Expectation: Get the correct gradients.
"""
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.matmul = P.MatMul()
self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
def construct(self, t):
x = t[0]['a']
y = t[1]
x = x * self.z
out = self.matmul(x, y)
return out
class GradNetWrtX(nn.Cell):
def __init__(self, net):
super(GradNetWrtX, self).__init__()
self.net = net
self.grad_op = GradOperation()
def construct(self, z):
gradient_function = self.grad_op(self.net)
return gradient_function(z)
t = Variable(({'a': Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
'b': Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)},
Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)))
output = GradNetWrtX(Net())(t)
assert isinstance(output, tuple)
expect = [[np.array([[1.4100001, 1.5999999, 6.6],
[1.4100001, 1.5999999, 6.6]]).astype(np.float32), np.array([[0, 0, 0],
[0, 0, 0]]).astype(np.float32)],
np.array([[1.7, 1.7, 1.7],
[1.9, 1.9, 1.9],
[1.5, 1.5, 1.5]]).astype(np.float32)]
assert compare(output, expect)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_variable_dict_tuple_tensor():
"""
Feature: Set Constants mutable.
Description: Get gradient with respect to nested dict and tuple tensor input.
Expectation: Get the correct gradients.
"""
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.matmul = P.MatMul()
self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
def construct(self, t):
x = t['a'][0]
y = t['b']
x = x * self.z
out = self.matmul(x, y)
return out
class GradNetWrtX(nn.Cell):
def __init__(self, net):
super(GradNetWrtX, self).__init__()
self.net = net
self.grad_op = GradOperation()
def construct(self, z):
gradient_function = self.grad_op(self.net)
return gradient_function(z)
t = Variable({'a': (Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)),
'b': Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)})
output = GradNetWrtX(Net())(t)
assert isinstance(output, tuple)
expect = [[np.array([[1.4100001, 1.5999999, 6.6],
[1.4100001, 1.5999999, 6.6]]).astype(np.float32), np.array([[0, 0, 0],
[0, 0, 0]]).astype(np.float32)],
np.array([[1.7, 1.7, 1.7],
[1.9, 1.9, 1.9],
[1.5, 1.5, 1.5]]).astype(np.float32)]
assert compare(output, expect)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_variable_list_dict_tensor():
"""
Feature: Set Constants mutable.
Description: Get gradient with respect to nested list and dict tensor input.
Expectation: Get the correct gradients.
"""
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.matmul = P.MatMul()
self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
def construct(self, t):
x = t[0]['a']
y = t[1]
x = x * self.z
out = self.matmul(x, y)
return out
class GradNetWrtX(nn.Cell):
def __init__(self, net):
super(GradNetWrtX, self).__init__()
self.net = net
self.grad_op = GradOperation()
def construct(self, z):
gradient_function = self.grad_op(self.net)
return gradient_function(z)
t = Variable([{'a': Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
'b': Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)},
Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)])
output = GradNetWrtX(Net())(t)
assert isinstance(output, tuple)
expect = [[np.array([[1.4100001, 1.5999999, 6.6],
[1.4100001, 1.5999999, 6.6]]).astype(np.float32), np.array([[0, 0, 0],
[0, 0, 0]]).astype(np.float32)],
np.array([[1.7, 1.7, 1.7],
[1.9, 1.9, 1.9],
[1.5, 1.5, 1.5]]).astype(np.float32)]
assert compare(output, expect)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_variable_dict_list_tensor():
"""
Feature: Set Constants mutable.
Description: Get gradient with respect to nested dict and list tensor input.
Expectation: Get the correct gradients.
"""
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.matmul = P.MatMul()
self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
def construct(self, t):
x = t['a'][0]
y = t['b']
x = x * self.z
out = self.matmul(x, y)
return out
class GradNetWrtX(nn.Cell):
def __init__(self, net):
super(GradNetWrtX, self).__init__()
self.net = net
self.grad_op = GradOperation()
def construct(self, z):
gradient_function = self.grad_op(self.net)
return gradient_function(z)
t = Variable({'a': [Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)],
'b': Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)})
output = GradNetWrtX(Net())(t)
assert isinstance(output, tuple)
expect = [[np.array([[1.4100001, 1.5999999, 6.6],
[1.4100001, 1.5999999, 6.6]]).astype(np.float32), np.array([[0, 0, 0],
[0, 0, 0]]).astype(np.float32)],
np.array([[1.7, 1.7, 1.7],
[1.9, 1.9, 1.9],
[1.5, 1.5, 1.5]]).astype(np.float32)]
assert compare(output, expect)


@ -0,0 +1,195 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test variable"""
import numpy as np
from mindspore.ops.composite import GradOperation
from mindspore.common.variable import Variable
from mindspore.common.api import _CellGraphExecutor
from mindspore.ops import operations as P
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor
from mindspore import Parameter
def test_variable_scalar_mul_grad_first():
"""
Feature: Set Constants mutable.
Description: Get gradient with respect to the first scalar input.
Expectation: Get the correct gradient.
"""
class Net(nn.Cell):
def construct(self, x, y):
return x * y
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
self.grad_op = GradOperation()
def construct(self, x, y):
gradient_function = self.grad_op(self.net)
return gradient_function(x, y)
x = Variable(2)
output = GradNet(Net())(x, 3)
assert output == 3
def test_variable_scalar_mul_grad_all():
"""
Feature: Set Constants mutable.
Description: Get gradient with respect to all scalar inputs.
Expectation: Get the correct gradients.
"""
class Net(nn.Cell):
def construct(self, x, y):
return x * y
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
self.grad_op = GradOperation(get_all=True)
def construct(self, x, y):
gradient_function = self.grad_op(self.net)
return gradient_function(x, y)
x = Variable(2)
y = Variable(3)
output = GradNet(Net())(x, y)
assert output == (3, 2)
def test_variable_tuple_or_list_scalar_mul_grad():
"""
Feature: Set Constants mutable.
Description: Get gradient with respect to the tuple or list scalar input.
Expectation: Get the correct gradients.
"""
class Net(nn.Cell):
def construct(self, x):
return x[0] * x[1]
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
self.grad_op = GradOperation()
def construct(self, x):
gradient_function = self.grad_op(self.net)
return gradient_function(x)
x = Variable((2, 3))
output = GradNet(Net())(x)
assert output == (3, 2)
x = Variable([2, 3])
output = GradNet(Net())(x)
assert output == (3, 2)
def test_variable_dict_scalar_mul_grad():
"""
Feature: Set Constants mutable.
Description: Get gradient with respect to the dict scalar input.
Expectation: Get the correct gradients.
"""
class Net(nn.Cell):
def construct(self, x):
return x['a'] * x['b']
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
self.grad_op = GradOperation()
def construct(self, x):
gradient_function = self.grad_op(self.net)
return gradient_function(x)
x = Variable({'a': 2, 'b': 3})
output = GradNet(Net())(x)
assert output == (3, 2)
def test_variable_mix_scalar_mul_grad_all():
"""
Feature: Set Constants mutable.
Description: Get gradient with respect to the mix scalar input including dict and tuple.
Expectation: Get the correct gradients.
"""
class Net(nn.Cell):
def construct(self, x, y):
return x['a'] * x['b'] * y[0]
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
self.grad_op = GradOperation(get_all=True)
def construct(self, x, y):
gradient_function = self.grad_op(self.net)
return gradient_function(x, y)
x = Variable({'a': 2, 'b': 3})
y = Variable((4, 5))
output = GradNet(Net())(x, y)
assert output == ((12, 8), (6, 0))
def test_tuple_inputs_compile_phase():
"""
Feature: Set Constants mutable.
Description: Test whether compiling twice with tuple inputs produces the same compilation phase.
Expectation: The phases are the same.
"""
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.matmul = P.MatMul()
self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
def construct(self, tuple_input):
x = tuple_input[0]
y = tuple_input[1]
x = x * self.z
out = self.matmul(x, y)
return out
x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
p = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
q = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
net = Net()
_cell_graph_executor = _CellGraphExecutor()
phase1, _ = _cell_graph_executor.compile(net, (x, y))
phase2, _ = _cell_graph_executor.compile(net, (p, q))
assert phase1 != phase2
phase1, _ = _cell_graph_executor.compile(net, Variable((x, y)))
phase2, _ = _cell_graph_executor.compile(net, Variable((p, q)))
assert phase1 == phase2