forked from mindspore-Ecosystem/mindspore
!32373 Add class Variable to set constant mutable
Merge pull request !32373 from YuJianfeng/mutable
commit b12fefbded
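In short, wrapping a constant input in the new Variable class makes it a mutable, differentiable input of a Cell. A minimal usage sketch, taken from the docstring and tests added in this change (it assumes the mindspore.common.variable module introduced below):

import mindspore.nn as nn
from mindspore.ops.composite import GradOperation
from mindspore.common.variable import Variable


class Net(nn.Cell):
    def construct(self, x, y):
        return x * y


class GradNet(nn.Cell):
    def __init__(self, net):
        super(GradNet, self).__init__()
        self.net = net
        self.grad_op = GradOperation()

    def construct(self, x, y):
        # Differentiate the wrapped network with respect to its first input.
        return self.grad_op(self.net)(x, y)


# Without Variable, the scalar 2 would be folded as a constant; wrapped, it is
# treated as a variable input and the gradient d(x*y)/dx = y = 3 is returned.
x = Variable(2)
print(GradNet(Net())(x, 3))  # 3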
@@ -802,6 +802,38 @@ void FlatValueTupleValue(const ValuePtrList &value, ValuePtrList *flatted_value)
}
}

void FlattenValue(const BaseRef &arg, ValuePtrList *flatted_value) {
if (utils::isa<ValueSequencePtr>(arg)) {
auto value_sequence = utils::cast<ValueSequencePtr>(arg);
MS_EXCEPTION_IF_NULL(value_sequence);
auto sequence_value = value_sequence->value();
for (auto &value : sequence_value) {
MS_EXCEPTION_IF_NULL(value);
if (value->isa<tensor::Tensor>()) {
(void)flatted_value->emplace_back(value);
} else {
FlattenValue(value, flatted_value);
}
}
} else if (utils::isa<ValueDictionaryPtr>(arg)) {
auto value_dict = utils::cast<ValueDictionaryPtr>(arg);
MS_EXCEPTION_IF_NULL(value_dict);
auto dict_value = value_dict->value();
for (auto &iter : dict_value) {
auto value = iter.second;
MS_EXCEPTION_IF_NULL(value);
if (value->isa<tensor::Tensor>()) {
(void)flatted_value->emplace_back(value);
} else {
FlattenValue(value, flatted_value);
}
}
} else {
MS_LOG(EXCEPTION) << "The value input to flatten should only be a sequence or dictionary, but it is "
<< arg.ToString();
}
}

void PushTupleTensor(const VectorRef &args, const std::vector<AnfNodePtr> &parameters, const AnfNodePtr &front_node,
size_t index, std::vector<tensor::TensorPtr> *input_tensor) {
const auto &iter = std::find(parameters.begin(), parameters.end(), front_node);

@@ -814,11 +846,8 @@ void PushTupleTensor(const VectorRef &args, const std::vector<AnfNodePtr> &parameters, const AnfNodePtr &front_node,
(void)input_tensor->emplace_back(nullptr);
return;
}
auto value_tuple = utils::cast<ValueTuplePtr>(args[position]);
MS_EXCEPTION_IF_NULL(value_tuple);
auto value_tuple_value = value_tuple->value();
ValuePtrList flatted_value_tuple_value;
FlatValueTupleValue(value_tuple_value, &flatted_value_tuple_value);
FlattenValue(args[position], &flatted_value_tuple_value);
if (index >= flatted_value_tuple_value.size()) {
MS_LOG(EXCEPTION) << "Index out of flatted_value_tuple_value range, index value is " << index
<< " and flatted_value_tuple_value size is " << flatted_value_tuple_value.size() << ".";
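For intuition, the FlattenValue routine added above recursively walks a nested sequence or dictionary and collects the tensor leaves in order. A rough Python analogue, illustrative only and not part of this change (flatten_value is a hypothetical helper):

from mindspore import Tensor

def flatten_value(arg, flattened):
    # Recursively collect Tensor leaves from nested tuples/lists/dicts,
    # mirroring the C++ FlattenValue above.
    if isinstance(arg, (tuple, list)):
        values = arg
    elif isinstance(arg, dict):
        values = arg.values()
    else:
        raise TypeError("The value to flatten should be a sequence or dictionary, "
                        "but it is " + type(arg).__name__)
    for value in values:
        if isinstance(value, Tensor):
            flattened.append(value)
        else:
            flatten_value(value, flattened)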
@@ -358,6 +358,7 @@ REGISTER_PYBIND_DEFINE(HyperMap_, ([](const py::module *m) {
.def(py::init<bool>(), py::arg("reverse"));
}));

namespace {
bool CheckSequenceAllTensor(const abstract::AbstractTuplePtr &tuple) {
MS_EXCEPTION_IF_NULL(tuple);
for (size_t i = 0; i < tuple->size(); ++i) {

@@ -370,17 +371,23 @@ bool CheckSequenceAllTensor(const abstract::AbstractTuplePtr &tuple) {
return true;
}

bool EnableGradForScalar(const abstract::AbstractBasePtr &abs) {
return MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && abs->BuildType() != nullptr &&
abs->BuildType()->isa<Number>();
}

bool EnableGradForTuple(const abstract::AbstractBasePtr &abs, bool enable_tuple_grad) {
return abs->isa<abstract::AbstractTuple>() && enable_tuple_grad &&
CheckSequenceAllTensor(abs->cast<abstract::AbstractTuplePtr>());
}

bool CheckTailGradFristSequence(const abstract::AbstractSequencePtr &sequeue, bool enable_tuple_grad) {
MS_EXCEPTION_IF_NULL(sequeue);
return sequeue->size() > 1 && (*sequeue)[1] != nullptr &&
((*sequeue)[1]->isa<abstract::AbstractUndetermined>() ||
(MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && (*sequeue)[1]->BuildType() != nullptr &&
(*sequeue)[1]->BuildType()->isa<Number>()) ||
((*sequeue)[1]->isa<abstract::AbstractTuple>() && enable_tuple_grad &&
CheckSequenceAllTensor((*sequeue)[1]->cast<abstract::AbstractTuplePtr>())));
((*sequeue)[1]->isa<abstract::AbstractUndetermined>() || (*sequeue)[1]->BuildValue() == kAnyValue ||
EnableGradForScalar((*sequeue)[1]) || EnableGradForTuple((*sequeue)[1], enable_tuple_grad));
}

namespace {
void GenerateSequenceFuncGraphByPosition(const FuncGraphPtr &res, const abstract::AbstractSequencePtr &sequeue,
const abstract::AbstractSequencePtr &pos, bool enable_tuple_grad) {
if (pos == nullptr) {

@@ -470,9 +477,8 @@ FuncGraphPtr Tail::GenerateSequenceFuncGraph(const abstract::AbstractSequencePtr
for (size_t i = 1; i < sequeue->size(); ++i) {
if (tail_type_ == kGradAll) {
MS_EXCEPTION_IF_NULL((*sequeue)[i]);
if ((*sequeue)[i]->isa<abstract::AbstractUndetermined>() ||
(MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && (*sequeue)[i]->BuildType() != nullptr &&
(*sequeue)[i]->BuildType()->isa<Number>())) {
if ((*sequeue)[i]->isa<abstract::AbstractUndetermined>() || (*sequeue)[i]->BuildValue() == kAnyValue ||
EnableGradForScalar((*sequeue)[i])) {
elements.push_back(res->NewCNodeInOrder({NewValueNode(op), tuple_parameter, NewValueNode(SizeToLong(i))}));
}
} else {
@@ -43,6 +43,7 @@ using mindspore::abstract::AbstractBasePtr;
using mindspore::abstract::AbstractClass;
using mindspore::abstract::AbstractCOOTensor;
using mindspore::abstract::AbstractDictionary;
using mindspore::abstract::AbstractDictionaryPtr;
using mindspore::abstract::AbstractList;
using mindspore::abstract::AbstractListPtr;
using mindspore::abstract::AbstractRowTensor;

@@ -53,6 +54,7 @@ using mindspore::abstract::AbstractTuple;
using mindspore::abstract::AbstractTuplePtr;

namespace {
static constexpr size_t kMaxSeqRecursiveDepth = 5;
void CheckInputsSize(const CNodePtr &cnode, size_t expect_size) {
if (cnode->size() != expect_size) {
std::string op_name = GetCNodeFuncName(cnode);

@@ -450,11 +452,60 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
return std::make_shared<AbstractTuple>(std::move(elements));
}

AbstractBasePtr ConvertAbstract(const AbstractBasePtr &abs) override {
// AbstractDictionary, AbstractClass --> AbstractSequence.
static AbstractSequencePtr ConvertToAbstractSequence(const AbstractBasePtr &abs, size_t depth) {
if (depth > kMaxSeqRecursiveDepth) {
MS_LOG(EXCEPTION) << "List or Dict nesting is not allowed more than " << kMaxSeqRecursiveDepth << " levels.";
}
auto abs_seq = abs->cast<AbstractSequencePtr>();
if (abs_seq != nullptr) {
const auto &seq_elements = abs_seq->elements();
// First we check if elements should be converted,
// changed_elements maps old element to new element.
mindspore::HashMap<AbstractBasePtr, AbstractBasePtr> changed_elements;
for (const auto &element : seq_elements) {
auto new_element = ConvertToAbstractSequence(element, depth + 1);
if (new_element != nullptr) {
(void)changed_elements.emplace(element, new_element);
}
}
if (changed_elements.empty()) {
// Here the AbstractList don't need to convert to AbstractTuple.
return nullptr;
}
// Always make new AbstractSequence when elements changed.
std::vector<AbstractBasePtr> elements;
elements.reserve(seq_elements.size());
for (const auto &element : seq_elements) {
auto iter = changed_elements.find(element);
if (iter != changed_elements.end()) {
(void)elements.emplace_back(iter->second);
} else {
(void)elements.emplace_back(element);
}
}
// Here the AbstractList don't need to convert to AbstractTuple.
if (abs_seq->isa<AbstractList>()) {
return std::make_shared<AbstractList>(std::move(elements));
} else {
return std::make_shared<AbstractTuple>(std::move(elements));
}
}
// AbstractDictionary --> AbstractTuple.
auto abs_dict = abs->cast<abstract::AbstractDictionaryPtr>();
auto abs_dict = abs->cast<AbstractDictionaryPtr>();
if (abs_dict != nullptr) {
return MakeAbstractTuple(abs_dict->elements());
const auto &dict_elements = abs_dict->elements();
std::vector<AbstractBasePtr> elements;
elements.reserve(dict_elements.size());
for (const auto &element : dict_elements) {
auto new_element = ConvertToAbstractSequence(element.second, depth + 1);
if (new_element != nullptr) {
(void)elements.emplace_back(new_element);
} else {
(void)elements.emplace_back(element.second);
}
}
return std::make_shared<AbstractTuple>(elements);
}
// AbstractClass --> AbstractTuple.
auto abs_class = abs->cast<abstract::AbstractClassPtr>();

@@ -463,6 +514,11 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
}
return nullptr;
}

AbstractBasePtr ConvertAbstract(const AbstractBasePtr &abs) override {
// AbstractDictionary, AbstractClass --> AbstractSequence.
return ConvertToAbstractSequence(abs, 0);
}
};

// ==================================================================

@@ -597,12 +653,10 @@ class CleanAfterOptARewriter : public BaseRewriter {
return (this->*(iter->second))(cnode);
}

static constexpr size_t kMaxListRecursiveDepth = 5;

// ValueList --> ValueTuple
static ValueTuplePtr ConvertValueListToValueTuple(const ValueListPtr &value_list, size_t depth) {
if (depth > kMaxListRecursiveDepth) {
MS_LOG(EXCEPTION) << "List nesting is not allowed more than " << kMaxListRecursiveDepth << " levels.";
if (depth > kMaxSeqRecursiveDepth) {
MS_LOG(EXCEPTION) << "List nesting is not allowed more than " << kMaxSeqRecursiveDepth << " levels.";
}
const auto &list_elements = value_list->value();
std::vector<ValuePtr> elements;

@@ -625,50 +679,60 @@ class CleanAfterOptARewriter : public BaseRewriter {
return nullptr;
}

// AbstractSequence --> AbstractTuple
static AbstractTuplePtr ConvertAbstractSeqToAbstractTuple(const AbstractSequencePtr &abs_seq, size_t depth) {
if (depth > kMaxListRecursiveDepth) {
MS_LOG(EXCEPTION) << "List nesting is not allowed more than " << kMaxListRecursiveDepth << " levels.";
// AbstractSequence, AbstractDict, AbstractCOOTensor, AbstractRowTensor --> AbstractTuple.
static AbstractTuplePtr ConvertToAbstractTuple(const AbstractBasePtr &abs, size_t depth) {
if (depth > kMaxSeqRecursiveDepth) {
MS_LOG(EXCEPTION) << "List or Dict nesting is not allowed more than " << kMaxSeqRecursiveDepth << " levels.";
}
const auto &seq_elements = abs_seq->elements();
// First we check if elements should be converted,
// changed_elements maps old element to new element.
mindspore::HashMap<AbstractBasePtr, AbstractBasePtr> changed_elements;
for (const auto &element : seq_elements) {
if (element->isa<AbstractSequence>()) {
auto new_element = ConvertAbstractSeqToAbstractTuple(element->cast<AbstractSequencePtr>(), depth + 1);
// AbstractList --> AbstractTuple.
auto abs_seq = abs->cast<AbstractSequencePtr>();
if (abs_seq != nullptr) {
const auto &seq_elements = abs_seq->elements();
// First we check if elements should be converted,
// changed_elements maps old element to new element.
mindspore::HashMap<AbstractBasePtr, AbstractBasePtr> changed_elements;
for (const auto &element : seq_elements) {
auto new_element = ConvertToAbstractTuple(element, depth + 1);
if (new_element != nullptr) {
(void)changed_elements.emplace(element, new_element);
}
}
}
if (changed_elements.empty()) {
if (abs_seq->isa<AbstractTuple>()) {
// If no elements changed and it is an AbstractTuple, do not convert.
return nullptr;
if (changed_elements.empty()) {
if (abs->isa<AbstractTuple>()) {
// If no elements changed and it is an AbstractTuple, do not convert.
return nullptr;
}
// If no elements changed but it is not an AbstractTuple, convert it by copy elements.
return std::make_shared<AbstractTuple>(seq_elements);
}
// If no elements changed but it is not an AbstractTuple, convert it by copy elements.
return std::make_shared<AbstractTuple>(seq_elements);
}
// Always make new AbstractTuple when elements changed.
std::vector<AbstractBasePtr> elements;
elements.reserve(seq_elements.size());
for (const auto &element : seq_elements) {
auto iter = changed_elements.find(element);
if (iter != changed_elements.end()) {
(void)elements.emplace_back(iter->second);
} else {
(void)elements.emplace_back(element);
// Always make new AbstractTuple when elements changed.
std::vector<AbstractBasePtr> elements;
elements.reserve(seq_elements.size());
for (const auto &element : seq_elements) {
auto iter = changed_elements.find(element);
if (iter != changed_elements.end()) {
(void)elements.emplace_back(iter->second);
} else {
(void)elements.emplace_back(element);
}
}
return std::make_shared<AbstractTuple>(std::move(elements));
}
return std::make_shared<AbstractTuple>(std::move(elements));
}

AbstractBasePtr ConvertAbstract(const AbstractBasePtr &abs) override {
// AbstractSequence --> AbstractTuple.
auto abs_seq = abs->cast<AbstractSequencePtr>();
if (abs_seq != nullptr) {
return ConvertAbstractSeqToAbstractTuple(abs_seq, 0);
// AbstractDict --> AbstractTuple.
auto abs_dict = abs->cast<AbstractDictionaryPtr>();
if (abs_dict != nullptr) {
const auto &dict_elements = abs_dict->elements();
std::vector<AbstractBasePtr> elements;
elements.reserve(dict_elements.size());
for (const auto &element : dict_elements) {
auto new_element = ConvertToAbstractTuple(element.second, depth + 1);
if (new_element != nullptr) {
(void)elements.emplace_back(new_element);
} else {
(void)elements.emplace_back(element.second);
}
}
return std::make_shared<AbstractTuple>(elements);
}
// AbstractCOOTensor --> AbstractTuple.
auto abs_sparse = abs->cast<abstract::AbstractCOOTensorPtr>();

@@ -685,6 +749,11 @@ class CleanAfterOptARewriter : public BaseRewriter {
}
return nullptr;
}

AbstractBasePtr ConvertAbstract(const AbstractBasePtr &abs) override {
// AbstractSequence, AbstractDict, AbstractCOOTensor, AbstractRowTensor --> AbstractTuple.
return ConvertToAbstractTuple(abs, 0);
}
};
}  // namespace
@@ -66,6 +66,15 @@ namespace pipeline {
namespace {
bool ExistControlFlow(const FuncGraphPtr &func_graph) { return !func_graph->func_graphs_used_total().empty(); }

bool EnableGradForScalar(const abstract::AbstractBasePtr &abs) {
return MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && abs->BuildType() != nullptr &&
abs->BuildType()->isa<Number>();
}

bool EnableTupleBroaden(const abstract::AbstractBasePtr &abs) {
return abs->isa<abstract::AbstractTuple>() && abs->cast<abstract::AbstractTuplePtr>()->ContainsAllBroadenTensors();
}

void UpdateFuncGraphParameter(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
std::vector<AnfNodePtr> new_paras;

@@ -78,11 +87,8 @@ void UpdateFuncGraphParameter(const FuncGraphPtr &func_graph) {
}
AbstractBasePtr par_abs = param_node->abstract();
MS_EXCEPTION_IF_NULL(par_abs);
if (par_abs->isa<abstract::AbstractUndetermined>() ||
(MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && par_abs->BuildType() != nullptr &&
par_abs->BuildType()->isa<Number>()) ||
(par_abs->isa<abstract::AbstractTuple>() &&
par_abs->cast<abstract::AbstractTuplePtr>()->ContainsAllBroadenTensors())) {
if (par_abs->isa<abstract::AbstractUndetermined>() || par_abs->BuildValue() == kAnyValue ||
EnableGradForScalar(par_abs) || EnableTupleBroaden(par_abs)) {
new_paras.push_back(param_node);
}
}
@@ -32,6 +32,7 @@
#include "utils/symbolic.h"
#include "utils/ms_context.h"
#include "include/common/utils/utils.h"
#include "ir/variable.h"

namespace mindspore {
namespace parse {

@@ -506,6 +507,7 @@ static const std::vector<DataConverterPtr> &GetDataConverters() {
std::make_shared<ByTypeDataConverter<MetaTensor>>(ObjCast<MetaTensorPtr>),
std::make_shared<ByTypeDataConverter<CSRTensor>>(ObjCast<CSRTensorPtr>),
std::make_shared<ByTypeDataConverter<COOTensor>>(ObjCast<COOTensorPtr>),
std::make_shared<ByTypeDataConverter<Variable>>(ObjCast<VariablePtr>),
std::make_shared<ByTypeDataConverter<py::tuple>>(ConvertTuple),
std::make_shared<ByTypeDataConverter<py::list>>(ConvertList),
std::make_shared<ByTypeDataConverter<py::bool_>>(PyCast<BoolImm, bool>),
@@ -29,6 +29,7 @@
#include "utils/hash_map.h"
#include "pybind_api/pybind_patch.h"
#include "ir/param_info.h"
#include "ir/variable.h"
#include "pipeline/jit/pass.h"
#include "pipeline/jit/parse/data_converter.h"
#include "pipeline/jit/static_analysis/async_eval_result.h"

@@ -149,7 +150,7 @@ bool CheckAllTensor(const ValueTuplePtr &value_tuple) {

AbstractBasePtr ArgsToAbstract(const ValuePtr &value, bool enable_tuple_broaden = false) {
MS_EXCEPTION_IF_NULL(value);
bool broaden = value->isa<MetaTensor>() ||
bool broaden = value->isa<MetaTensor>() || value->isa<Variable>() ||
(enable_tuple_broaden && value->isa<ValueTuple>() && CheckAllTensor(value->cast<ValueTuplePtr>())) ||
(MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && value->isa<Scalar>());

@@ -167,6 +168,14 @@ bool CheckArgValid(const py::handle &arg) {
return std::all_of(dict_arg.begin(), dict_arg.end(), [](const auto &pair) { return CheckArgValid(pair.second); });
}

if (py::isinstance<Variable>(arg)) {
if (py::hasattr(arg, "value")) {
return CheckArgValid(arg.attr("value"));
}
MS_LOG(ERROR) << "There should be a python object value stored in the Variable " << py::str(arg);
return false;
}

if (py::isinstance<Tensor>(arg)) {
auto tensor = py::cast<TensorPtr>(arg);
if (tensor->data_type() == kNumberTypeBool) {
@@ -0,0 +1,36 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "include/common/pybind_api/api_register.h"
#include "ir/variable.h"
#include "pipeline/jit/parse/data_converter.h"

namespace py = pybind11;
namespace mindspore {
REGISTER_PYBIND_DEFINE(Variable_, ([](const py::module *m) {
                         (void)py::class_<Variable, VariablePtr>(*m, "Variable_")
                           .def(py::init([](const py::object &py_value) {
                                  ValuePtr real_value = nullptr;
                                  if (!parse::ConvertData(py_value, &real_value)) {
                                    MS_EXCEPTION(TypeError)
                                      << "Convert python object failed, the object type is " << py_value.get_type()
                                      << ", value is '" << py::str(py_value) << "'.";
                                  }
                                  return std::make_shared<Variable>(real_value);
                                }),
                                py::arg("py_value"));
                       }));
}  // namespace mindspore
@@ -158,7 +158,7 @@ std::string AbstractBase::ToString(bool verbose) const {
AbstractBasePtr AbstractScalar::Broaden() const {
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
if (context->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR)) {
if (context->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) || value_mutable()) {
return AbstractBase::Broaden();
}
auto type_id = GetTypeTrack()->type_id();
@@ -187,6 +187,10 @@ class MS_CORE_API AbstractBase : public Base {
/// \return A pointer to the broadened abstract.
virtual AbstractBasePtr PartialBroaden() const;

bool value_mutable() const { return value_mutable_; }

void set_value_mutable(bool value_mutable) { value_mutable_ = value_mutable; }

protected:
/// \brief Build a value when value is not set.
///

@@ -198,6 +202,7 @@ class MS_CORE_API AbstractBase : public Base {
TypePtr type_;
BaseShapePtr shape_;
std::string value_desc_;  // store initial value description for error report
bool value_mutable_{false};
};

/// \brief Class AbstractScalar describes a scalar's type and value.
@@ -0,0 +1,89 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ir/variable.h"
#include "abstract/abstract_value.h"

namespace mindspore {
namespace {
void SetValueMutable(const abstract::AbstractBasePtr &abs) {
  MS_EXCEPTION_IF_NULL(abs);
  if (abs->isa<abstract::AbstractTensor>()) {
    return;
  }

  auto abs_sequence = abs->cast<abstract::AbstractSequencePtr>();
  if (abs_sequence != nullptr) {
    const auto &elements = abs_sequence->elements();
    for (auto &ele : elements) {
      SetValueMutable(ele);
    }
    return;
  }

  auto abs_dict = abs->cast<abstract::AbstractDictionaryPtr>();
  if (abs_dict != nullptr) {
    const auto &elements = abs_dict->elements();
    for (auto &ele : elements) {
      SetValueMutable(ele.second);
    }
    return;
  }

  abs->set_value_mutable(true);
}
}  // namespace

abstract::AbstractBasePtr Variable::ToAbstract() {
  if (real_value_ == nullptr) {
    MS_LOG(EXCEPTION) << "Get abstract failed. The real value of Variable has not been set.";
  }
  auto abs = real_value_->ToAbstract();
  SetValueMutable(abs);
  return abs;
}

bool Variable::operator==(const Variable &other) const {
  if (this == &other) {
    return true;
  }
  auto other_real_value = other.real_value();
  if (real_value_ == nullptr || other_real_value == nullptr) {
    return false;
  }
  return *real_value_ == *other_real_value;
}

std::string Variable::ToString() const {
  std::ostringstream oss;
  if (real_value_ == nullptr) {
    oss << "Variable(NULL)";
  } else {
    oss << "Variable(" << real_value_->ToString() << ")";
  }
  return oss.str();
}

std::string Variable::DumpText() const {
  std::ostringstream oss;
  if (real_value_ == nullptr) {
    oss << type_name() << "(NULL)";
  } else {
    oss << type_name() << "(" << real_value_->DumpText() << ")";
  }
  return oss.str();
}
}  // namespace mindspore
@@ -0,0 +1,55 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CORE_IR_VARIABLE_H_
#define MINDSPORE_CORE_IR_VARIABLE_H_

#include <memory>
#include <string>
#include "ir/anf.h"

namespace mindspore {
class MS_CORE_API Variable : public Value {
 public:
  explicit Variable(const ValuePtr &real_value) : real_value_(real_value) {}
  ~Variable() override = default;
  MS_DECLARE_PARENT(Variable, Value)

  abstract::AbstractBasePtr ToAbstract() override;

  const ValuePtr &real_value() const { return real_value_; }

  bool operator==(const Variable &other) const;

  bool operator==(const Value &other) const override {
    if (other.isa<Variable>()) {
      auto other_variable = static_cast<const Variable &>(other);
      return *this == other_variable;
    }
    return false;
  }

  std::string ToString() const override;

  std::string DumpText() const override;

 private:
  ValuePtr real_value_{nullptr};
};
using VariablePtr = std::shared_ptr<Variable>;
}  // namespace mindspore

#endif  // MINDSPORE_CORE_IR_VARIABLE_H_
@@ -52,6 +52,14 @@ def TupleGetItem(x, index):
        x = x.asnumpy()
        y = x[index]
        return Tensor(y)

    if isinstance(x, dict):
        count = 0
        for value in x.values():
            if count == index:
                return value
            count = count + 1

    return x[index]
@@ -25,6 +25,7 @@ from .dump import set_dump
from .parameter import Parameter, ParameterTuple
from .seed import set_seed, get_seed
from .tensor import Tensor, RowTensor, SparseTensor, COOTensor, CSRTensor
from .variable import Variable

# symbols from dtype
__all__ = [

@@ -59,5 +60,6 @@ __all__.extend([
    "dtype", "_convert_data",
    "set_seed", "get_seed",  # random seed
    "set_dump",
    "ms_memory_recycle"
    "ms_memory_recycle",
    "Variable"
])
@@ -0,0 +1,76 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Variable class for setting constants mutable."""

from .._c_expression import Variable_
from ..common.tensor import Tensor, CSRTensor, COOTensor


class Variable(Variable_):
    """
    Currently, all inputs of a Cell other than Tensor, such as scalar, tuple, list and dict, are regarded as
    constant values. Constant values are non-differentiable and are used for constant folding in the optimization
    process. The `Variable` class stores such a constant value and makes the corresponding Cell input 'mutable':
    it is then treated as a variable input just like a Tensor and, most importantly, it becomes differentiable.

    .. warning::
        This is an experimental prototype that is subject to change or deletion.

    Args:
        value (Union[bool, float, int, tuple, list, dict, Tensor]): The value to be stored.

    Examples:
        >>> import mindspore.nn as nn
        >>> from mindspore.ops.composite import GradOperation
        >>> from mindspore.common.variable import Variable
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...
        ...     def construct(self, x, y):
        ...         return x * y
        ...
        >>> class GradNet(nn.Cell):
        ...     def __init__(self, net):
        ...         super(GradNet, self).__init__()
        ...         self.net = net
        ...         self.grad_op = GradOperation()
        ...
        ...     def construct(self, x, y):
        ...         gradient_function = self.grad_op(self.net)
        ...         return gradient_function(x, y)
        ...
        >>> x = Variable(2)
        >>> output = GradNet(Net())(x, 3)
        >>> print(output)
        3
    """

    def __init__(self, value):
        if not isinstance(value, (bool, int, float, tuple, list, dict, Tensor, COOTensor, CSRTensor)):
            raise TypeError(
                f"For 'Variable', the 'value' should be one of (bool, int, float, tuple, list, dict, Tensor, "
                f"COOTensor, CSRTensor), but got {type(value).__name__}")
        Variable_.__init__(self, value)
        self._value = value

    @property
    def value(self):
        return self._value

    @value.setter
    def value(self, value):
        self._value = value
@@ -33,6 +33,7 @@ from .._checkparam import Validator
from ..common import dtype as mstype
from ..common.api import _cell_graph_executor, _pynative_executor, _check_all_tensor, cells_compile_cache
from ..common.parameter import Parameter, ParameterTuple
from ..common.variable import Variable
from ..common.tensor import Tensor, CSRTensor, COOTensor
from ..ops.operations import Cast
from ..ops.primitive import Primitive

@@ -967,10 +968,10 @@ class Cell(Cell_):
                if i.has_init:
                    i.init_data()
                new_inputs.append(i)
            elif isinstance(i, COOTensor):
                new_inputs.append(i)
            elif isinstance(i, CSRTensor):
            elif isinstance(i, (COOTensor, CSRTensor)):
                new_inputs.append(i)
            elif isinstance(i, Variable):
                new_inputs.append(i.value)
            elif context.get_context("grad_for_scalar") and isinstance(i, (int, float)):
                new_inputs.append(i)
            elif hasattr(self, "enable_tuple_broaden") and self.enable_tuple_broaden and isinstance(i, tuple) and \
@@ -0,0 +1,497 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test getting gradient of Variable"""
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops.composite import GradOperation
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
from mindspore import Parameter, Variable


def compare(a, b):
    if isinstance(a, (list, tuple)):
        for aa, bb in zip(a, b):
            if not compare(aa, bb):
                return False
        return True

    return np.allclose(a.asnumpy(), b)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_variable_tuple_tensor():
    """
    Feature: Set Constants mutable.
    Description: Get gradient with respect to tuple tensor input.
    Expectation: Get the correct gradients.
    """

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.matmul = P.MatMul()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

        def construct(self, t):
            x = t[0]
            y = t[1]
            x = x * self.z
            out = self.matmul(x, y)
            return out

    class GradNetWrtX(nn.Cell):
        def __init__(self, net):
            super(GradNetWrtX, self).__init__()
            self.net = net
            self.grad_op = GradOperation()

        def construct(self, z):
            gradient_function = self.grad_op(self.net)
            return gradient_function(z)

    t = Variable((Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
                  Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)))
    output = GradNetWrtX(Net())(t)
    assert isinstance(output, tuple)
    expect = [np.array([[1.4100001, 1.5999999, 6.6],
                        [1.4100001, 1.5999999, 6.6]]).astype(np.float32),
              np.array([[1.7, 1.7, 1.7],
                        [1.9, 1.9, 1.9],
                        [1.5, 1.5, 1.5]]).astype(np.float32)]
    assert compare(output, expect)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_variable_list_tensor():
    """
    Feature: Set Constants mutable.
    Description: Get gradient with respect to list tensor input.
    Expectation: Get the correct gradients.
    """

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.matmul = P.MatMul()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

        def construct(self, t):
            x = t[0]
            y = t[1]
            x = x * self.z
            out = self.matmul(x, y)
            return out

    class GradNetWrtX(nn.Cell):
        def __init__(self, net):
            super(GradNetWrtX, self).__init__()
            self.net = net
            self.grad_op = GradOperation()

        def construct(self, z):
            gradient_function = self.grad_op(self.net)
            return gradient_function(z)

    t = Variable([Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
                  Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)])
    output = GradNetWrtX(Net())(t)
    assert isinstance(output, tuple)
    expect = [np.array([[1.4100001, 1.5999999, 6.6],
                        [1.4100001, 1.5999999, 6.6]]).astype(np.float32),
              np.array([[1.7, 1.7, 1.7],
                        [1.9, 1.9, 1.9],
                        [1.5, 1.5, 1.5]]).astype(np.float32)]
    assert compare(output, expect)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_variable_dict_tensor():
    """
    Feature: Set Constants mutable.
    Description: Get gradient with respect to dict tensor input.
    Expectation: Get the correct gradients.
    """

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.matmul = P.MatMul()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

        def construct(self, t):
            x = t['a']
            y = t['b']
            x = x * self.z
            out = self.matmul(x, y)
            return out

    class GradNetWrtX(nn.Cell):
        def __init__(self, net):
            super(GradNetWrtX, self).__init__()
            self.net = net
            self.grad_op = GradOperation()

        def construct(self, z):
            gradient_function = self.grad_op(self.net)
            return gradient_function(z)

    t = Variable({'a': Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
                  'b': Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)})
    output = GradNetWrtX(Net())(t)
    assert isinstance(output, tuple)
    expect = [np.array([[1.4100001, 1.5999999, 6.6],
                        [1.4100001, 1.5999999, 6.6]]).astype(np.float32),
              np.array([[1.7, 1.7, 1.7],
                        [1.9, 1.9, 1.9],
                        [1.5, 1.5, 1.5]]).astype(np.float32)]
    assert compare(output, expect)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_variable_tuple_tuple_tensor():
    """
    Feature: Set Constants mutable.
    Description: Get gradient with respect to nested tuple tensor input.
    Expectation: Get the correct gradients.
    """

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.matmul = P.MatMul()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

        def construct(self, t):
            x = t[0][0]
            y = t[1]
            x = x * self.z
            out = self.matmul(x, y)
            return out

    class GradNetWrtX(nn.Cell):
        def __init__(self, net):
            super(GradNetWrtX, self).__init__()
            self.net = net
            self.grad_op = GradOperation()

        def construct(self, z):
            gradient_function = self.grad_op(self.net)
            return gradient_function(z)

    t = Variable(((Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
                   Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)),
                  Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)))
    output = GradNetWrtX(Net())(t)
    assert isinstance(output, tuple)
    expect = [[np.array([[1.4100001, 1.5999999, 6.6],
                         [1.4100001, 1.5999999, 6.6]]).astype(np.float32),
               np.array([[0, 0, 0],
                         [0, 0, 0]]).astype(np.float32)],
              np.array([[1.7, 1.7, 1.7],
                        [1.9, 1.9, 1.9],
                        [1.5, 1.5, 1.5]]).astype(np.float32)]
    assert compare(output, expect)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_variable_tuple_list_tensor():
    """
    Feature: Set Constants mutable.
    Description: Get gradient with respect to nested tuple and list tensor input.
    Expectation: Get the correct gradients.
    """

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.matmul = P.MatMul()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

        def construct(self, t):
            x = t[0][0]
            y = t[1]
            x = x * self.z
            out = self.matmul(x, y)
            return out

    class GradNetWrtX(nn.Cell):
        def __init__(self, net):
            super(GradNetWrtX, self).__init__()
            self.net = net
            self.grad_op = GradOperation()

        def construct(self, z):
            gradient_function = self.grad_op(self.net)
            return gradient_function(z)

    t = Variable(([Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
                   Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)],
                  Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)))
    output = GradNetWrtX(Net())(t)
    assert isinstance(output, tuple)
    expect = [[np.array([[1.4100001, 1.5999999, 6.6],
                         [1.4100001, 1.5999999, 6.6]]).astype(np.float32),
               np.array([[0, 0, 0],
                         [0, 0, 0]]).astype(np.float32)],
              np.array([[1.7, 1.7, 1.7],
                        [1.9, 1.9, 1.9],
                        [1.5, 1.5, 1.5]]).astype(np.float32)]
    assert compare(output, expect)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_variable_list_tuple_tensor():
    """
    Feature: Set Constants mutable.
    Description: Get gradient with respect to nested list and tuple tensor input.
    Expectation: Get the correct gradients.
    """

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.matmul = P.MatMul()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

        def construct(self, t):
            x = t[0][0]
            y = t[1]
            x = x * self.z
            out = self.matmul(x, y)
            return out

    class GradNetWrtX(nn.Cell):
        def __init__(self, net):
            super(GradNetWrtX, self).__init__()
            self.net = net
            self.grad_op = GradOperation()

        def construct(self, z):
            gradient_function = self.grad_op(self.net)
            return gradient_function(z)

    t = Variable([(Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
                   Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)),
                  Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)])
    output = GradNetWrtX(Net())(t)
    assert isinstance(output, tuple)
    expect = [[np.array([[1.4100001, 1.5999999, 6.6],
                         [1.4100001, 1.5999999, 6.6]]).astype(np.float32),
               np.array([[0, 0, 0],
                         [0, 0, 0]]).astype(np.float32)],
              np.array([[1.7, 1.7, 1.7],
                        [1.9, 1.9, 1.9],
                        [1.5, 1.5, 1.5]]).astype(np.float32)]
    assert compare(output, expect)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_variable_tuple_dict_tensor():
    """
    Feature: Set Constants mutable.
    Description: Get gradient with respect to nested tuple and dict tensor input.
    Expectation: Get the correct gradients.
    """

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.matmul = P.MatMul()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

        def construct(self, t):
            x = t[0]['a']
            y = t[1]
            x = x * self.z
            out = self.matmul(x, y)
            return out

    class GradNetWrtX(nn.Cell):
        def __init__(self, net):
            super(GradNetWrtX, self).__init__()
            self.net = net
            self.grad_op = GradOperation()

        def construct(self, z):
            gradient_function = self.grad_op(self.net)
            return gradient_function(z)

    t = Variable(({'a': Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
                   'b': Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)},
                  Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)))
    output = GradNetWrtX(Net())(t)
    assert isinstance(output, tuple)
    expect = [[np.array([[1.4100001, 1.5999999, 6.6],
                         [1.4100001, 1.5999999, 6.6]]).astype(np.float32),
               np.array([[0, 0, 0],
                         [0, 0, 0]]).astype(np.float32)],
              np.array([[1.7, 1.7, 1.7],
                        [1.9, 1.9, 1.9],
                        [1.5, 1.5, 1.5]]).astype(np.float32)]
    assert compare(output, expect)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_variable_dict_tuple_tensor():
    """
    Feature: Set Constants mutable.
    Description: Get gradient with respect to nested dict and tuple tensor input.
    Expectation: Get the correct gradients.
    """

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.matmul = P.MatMul()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

        def construct(self, t):
            x = t['a'][0]
            y = t['b']
            x = x * self.z
            out = self.matmul(x, y)
            return out

    class GradNetWrtX(nn.Cell):
        def __init__(self, net):
            super(GradNetWrtX, self).__init__()
            self.net = net
            self.grad_op = GradOperation()

        def construct(self, z):
            gradient_function = self.grad_op(self.net)
            return gradient_function(z)

    t = Variable({'a': (Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
                        Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)),
                  'b': Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)})
    output = GradNetWrtX(Net())(t)
    assert isinstance(output, tuple)
    expect = [[np.array([[1.4100001, 1.5999999, 6.6],
                         [1.4100001, 1.5999999, 6.6]]).astype(np.float32),
               np.array([[0, 0, 0],
                         [0, 0, 0]]).astype(np.float32)],
              np.array([[1.7, 1.7, 1.7],
                        [1.9, 1.9, 1.9],
                        [1.5, 1.5, 1.5]]).astype(np.float32)]
    assert compare(output, expect)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_variable_list_dict_tensor():
    """
    Feature: Set Constants mutable.
    Description: Get gradient with respect to nested list and dict tensor input.
    Expectation: Get the correct gradients.
    """

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.matmul = P.MatMul()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

        def construct(self, t):
            x = t[0]['a']
            y = t[1]
            x = x * self.z
            out = self.matmul(x, y)
            return out

    class GradNetWrtX(nn.Cell):
        def __init__(self, net):
            super(GradNetWrtX, self).__init__()
            self.net = net
            self.grad_op = GradOperation()

        def construct(self, z):
            gradient_function = self.grad_op(self.net)
            return gradient_function(z)

    t = Variable([{'a': Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
                   'b': Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)},
                  Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)])
    output = GradNetWrtX(Net())(t)
    assert isinstance(output, tuple)
    expect = [[np.array([[1.4100001, 1.5999999, 6.6],
                         [1.4100001, 1.5999999, 6.6]]).astype(np.float32),
               np.array([[0, 0, 0],
                         [0, 0, 0]]).astype(np.float32)],
              np.array([[1.7, 1.7, 1.7],
                        [1.9, 1.9, 1.9],
                        [1.5, 1.5, 1.5]]).astype(np.float32)]
    assert compare(output, expect)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_variable_dict_list_tensor():
    """
    Feature: Set Constants mutable.
    Description: Get gradient with respect to nested dict and list tensor input.
    Expectation: Get the correct gradients.
    """

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.matmul = P.MatMul()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

        def construct(self, t):
            x = t['a'][0]
            y = t['b']
            x = x * self.z
            out = self.matmul(x, y)
            return out

    class GradNetWrtX(nn.Cell):
        def __init__(self, net):
            super(GradNetWrtX, self).__init__()
            self.net = net
            self.grad_op = GradOperation()

        def construct(self, z):
            gradient_function = self.grad_op(self.net)
            return gradient_function(z)

    t = Variable({'a': [Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
                        Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)],
                  'b': Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)})
    output = GradNetWrtX(Net())(t)
    assert isinstance(output, tuple)
    expect = [[np.array([[1.4100001, 1.5999999, 6.6],
                         [1.4100001, 1.5999999, 6.6]]).astype(np.float32),
               np.array([[0, 0, 0],
                         [0, 0, 0]]).astype(np.float32)],
              np.array([[1.7, 1.7, 1.7],
                        [1.9, 1.9, 1.9],
                        [1.5, 1.5, 1.5]]).astype(np.float32)]
    assert compare(output, expect)
@@ -0,0 +1,195 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test variable"""

import numpy as np
from mindspore.ops.composite import GradOperation
from mindspore.common.variable import Variable
from mindspore.common.api import _CellGraphExecutor
from mindspore.ops import operations as P
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor
from mindspore import Parameter


def test_variable_scalar_mul_grad_first():
    """
    Feature: Set Constants mutable.
    Description: Get gradient with respect to the first scalar input.
    Expectation: Get the correct gradient.
    """

    class Net(nn.Cell):
        def construct(self, x, y):
            return x * y

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.grad_op = GradOperation()

        def construct(self, x, y):
            gradient_function = self.grad_op(self.net)
            return gradient_function(x, y)

    x = Variable(2)
    output = GradNet(Net())(x, 3)
    assert output == 3


def test_variable_scalar_mul_grad_all():
    """
    Feature: Set Constants mutable.
    Description: Get gradient with respect to all scalar inputs.
    Expectation: Get the correct gradients.
    """

    class Net(nn.Cell):
        def construct(self, x, y):
            return x * y

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.grad_op = GradOperation(get_all=True)

        def construct(self, x, y):
            gradient_function = self.grad_op(self.net)
            return gradient_function(x, y)

    x = Variable(2)
    y = Variable(3)
    output = GradNet(Net())(x, y)
    assert output == (3, 2)


def test_variable_tuple_or_list_scalar_mul_grad():
    """
    Feature: Set Constants mutable.
    Description: Get gradient with respect to the tuple or list scalar input.
    Expectation: Get the correct gradients.
    """

    class Net(nn.Cell):
        def construct(self, x):
            return x[0] * x[1]

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.grad_op = GradOperation()

        def construct(self, x):
            gradient_function = self.grad_op(self.net)
            return gradient_function(x)

    x = Variable((2, 3))
    output = GradNet(Net())(x)
    assert output == (3, 2)

    x = Variable([2, 3])
    output = GradNet(Net())(x)
    assert output == (3, 2)


def test_variable_dict_scalar_mul_grad():
    """
    Feature: Set Constants mutable.
    Description: Get gradient with respect to the dict scalar input.
    Expectation: Get the correct gradients.
    """

    class Net(nn.Cell):
        def construct(self, x):
            return x['a'] * x['b']

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.grad_op = GradOperation()

        def construct(self, x):
            gradient_function = self.grad_op(self.net)
            return gradient_function(x)

    x = Variable({'a': 2, 'b': 3})
    output = GradNet(Net())(x)
    assert output == (3, 2)


def test_variable_mix_scalar_mul_grad_all():
    """
    Feature: Set Constants mutable.
    Description: Get gradient with respect to mixed scalar inputs including dict and tuple.
    Expectation: Get the correct gradients.
    """

    class Net(nn.Cell):
        def construct(self, x, y):
            return x['a'] * x['b'] * y[0]

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.grad_op = GradOperation(get_all=True)

        def construct(self, x, y):
            gradient_function = self.grad_op(self.net)
            return gradient_function(x, y)

    x = Variable({'a': 2, 'b': 3})
    y = Variable((4, 5))
    output = GradNet(Net())(x, y)
    assert output == ((12, 8), (6, 0))


def test_tuple_inputs_compile_phase():
    """
    Feature: Set Constants mutable.
    Description: Compile the same network twice with tuple inputs, with and without wrapping them in Variable,
        and compare the compilation phases.
    Expectation: The phases differ for plain tuple inputs and are the same for Variable-wrapped tuple inputs.
    """

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.matmul = P.MatMul()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

        def construct(self, tuple_input):
            x = tuple_input[0]
            y = tuple_input[1]
            x = x * self.z
            out = self.matmul(x, y)
            return out

    x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
    y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
    p = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
    q = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
    net = Net()
    _cell_graph_executor = _CellGraphExecutor()
    phase1, _ = _cell_graph_executor.compile(net, (x, y))
    phase2, _ = _cell_graph_executor.compile(net, (p, q))
    assert phase1 != phase2
    phase1, _ = _cell_graph_executor.compile(net, Variable((x, y)))
    phase2, _ = _cell_graph_executor.compile(net, Variable((p, q)))
    assert phase1 == phase2