forked from mindspore-Ecosystem/mindspore
!32373 Add class Variable to set constant mutable
Merge pull request !32373 from YuJianfeng/mutable
commit b12fefbded
@@ -802,6 +802,38 @@ void FlatValueTupleValue(const ValuePtrList &value, ValuePtrList *flatted_value)
   }
 }

+void FlattenValue(const BaseRef &arg, ValuePtrList *flatted_value) {
+  if (utils::isa<ValueSequencePtr>(arg)) {
+    auto value_sequence = utils::cast<ValueSequencePtr>(arg);
+    MS_EXCEPTION_IF_NULL(value_sequence);
+    auto sequence_value = value_sequence->value();
+    for (auto &value : sequence_value) {
+      MS_EXCEPTION_IF_NULL(value);
+      if (value->isa<tensor::Tensor>()) {
+        (void)flatted_value->emplace_back(value);
+      } else {
+        FlattenValue(value, flatted_value);
+      }
+    }
+  } else if (utils::isa<ValueDictionaryPtr>(arg)) {
+    auto value_dict = utils::cast<ValueDictionaryPtr>(arg);
+    MS_EXCEPTION_IF_NULL(value_dict);
+    auto dict_value = value_dict->value();
+    for (auto &iter : dict_value) {
+      auto value = iter.second;
+      MS_EXCEPTION_IF_NULL(value);
+      if (value->isa<tensor::Tensor>()) {
+        (void)flatted_value->emplace_back(value);
+      } else {
+        FlattenValue(value, flatted_value);
+      }
+    }
+  } else {
+    MS_LOG(EXCEPTION) << "The value input to flatten should only be a sequence or dictionary, but it is "
+                      << arg.ToString();
+  }
+}
+
 void PushTupleTensor(const VectorRef &args, const std::vector<AnfNodePtr> &parameters, const AnfNodePtr &front_node,
                      size_t index, std::vector<tensor::TensorPtr> *input_tensor) {
   const auto &iter = std::find(parameters.begin(), parameters.end(), front_node);
@@ -814,11 +846,8 @@ void PushTupleTensor(const VectorRef &args, const std::vector<AnfNodePtr> &param
     (void)input_tensor->emplace_back(nullptr);
     return;
   }
-  auto value_tuple = utils::cast<ValueTuplePtr>(args[position]);
-  MS_EXCEPTION_IF_NULL(value_tuple);
-  auto value_tuple_value = value_tuple->value();
   ValuePtrList flatted_value_tuple_value;
-  FlatValueTupleValue(value_tuple_value, &flatted_value_tuple_value);
+  FlattenValue(args[position], &flatted_value_tuple_value);
   if (index >= flatted_value_tuple_value.size()) {
     MS_LOG(EXCEPTION) << "Index out of flatted_value_tuple_value range, index value is " << index
                       << " and flatted_value_tuple_value size is " << flatted_value_tuple_value.size() << ".";
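Note: the new FlattenValue generalizes FlatValueTupleValue. It walks tuples, lists and dicts depth-first and collects only tensors, so a nested constant input maps onto one flat tensor list that PushTupleTensor can index. A minimal Python sketch of the same traversal (is_tensor is a hypothetical stand-in for value->isa<tensor::Tensor>()):

    def is_tensor(value):
        # Stand-in predicate, mirrors value->isa<tensor::Tensor>().
        return hasattr(value, "asnumpy")

    def flatten_value(arg, flatted_value):
        # Depth-first traversal; dict values are visited in insertion order.
        if isinstance(arg, (tuple, list)):
            values = arg
        elif isinstance(arg, dict):
            values = arg.values()
        else:
            raise TypeError("The value to flatten should only be a sequence or dictionary.")
        for value in values:
            if is_tensor(value):
                flatted_value.append(value)
            else:
                flatten_value(value, flatted_value)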
@@ -358,6 +358,7 @@ REGISTER_PYBIND_DEFINE(HyperMap_, ([](const py::module *m) {
                            .def(py::init<bool>(), py::arg("reverse"));
                        }));

+namespace {
 bool CheckSequenceAllTensor(const abstract::AbstractTuplePtr &tuple) {
   MS_EXCEPTION_IF_NULL(tuple);
   for (size_t i = 0; i < tuple->size(); ++i) {
@@ -370,17 +371,23 @@ bool CheckSequenceAllTensor(const abstract::AbstractTuplePtr &tuple) {
   return true;
 }

+bool EnableGradForScalar(const abstract::AbstractBasePtr &abs) {
+  return MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && abs->BuildType() != nullptr &&
+         abs->BuildType()->isa<Number>();
+}
+
+bool EnableGradForTuple(const abstract::AbstractBasePtr &abs, bool enable_tuple_grad) {
+  return abs->isa<abstract::AbstractTuple>() && enable_tuple_grad &&
+         CheckSequenceAllTensor(abs->cast<abstract::AbstractTuplePtr>());
+}
+
 bool CheckTailGradFristSequence(const abstract::AbstractSequencePtr &sequeue, bool enable_tuple_grad) {
   MS_EXCEPTION_IF_NULL(sequeue);
   return sequeue->size() > 1 && (*sequeue)[1] != nullptr &&
-         ((*sequeue)[1]->isa<abstract::AbstractUndetermined>() ||
-          (MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && (*sequeue)[1]->BuildType() != nullptr &&
-           (*sequeue)[1]->BuildType()->isa<Number>()) ||
-          ((*sequeue)[1]->isa<abstract::AbstractTuple>() && enable_tuple_grad &&
-           CheckSequenceAllTensor((*sequeue)[1]->cast<abstract::AbstractTuplePtr>())));
+         ((*sequeue)[1]->isa<abstract::AbstractUndetermined>() || (*sequeue)[1]->BuildValue() == kAnyValue ||
+          EnableGradForScalar((*sequeue)[1]) || EnableGradForTuple((*sequeue)[1], enable_tuple_grad));
 }

-namespace {
 void GenerateSequenceFuncGraphByPosition(const FuncGraphPtr &res, const abstract::AbstractSequencePtr &sequeue,
                                          const abstract::AbstractSequencePtr &pos, bool enable_tuple_grad) {
   if (pos == nullptr) {
@@ -470,9 +477,8 @@ FuncGraphPtr Tail::GenerateSequenceFuncGraph(const abstract::AbstractSequencePtr
   for (size_t i = 1; i < sequeue->size(); ++i) {
     if (tail_type_ == kGradAll) {
       MS_EXCEPTION_IF_NULL((*sequeue)[i]);
-      if ((*sequeue)[i]->isa<abstract::AbstractUndetermined>() ||
-          (MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && (*sequeue)[i]->BuildType() != nullptr &&
-           (*sequeue)[i]->BuildType()->isa<Number>())) {
+      if ((*sequeue)[i]->isa<abstract::AbstractUndetermined>() || (*sequeue)[i]->BuildValue() == kAnyValue ||
+          EnableGradForScalar((*sequeue)[i])) {
         elements.push_back(res->NewCNodeInOrder({NewValueNode(op), tuple_parameter, NewValueNode(SizeToLong(i))}));
       }
     } else {
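Note: this refactor folds the inline grad conditions into two named predicates and adds one new disjunct, BuildValue() == kAnyValue, which is what a value-mutable (Variable-wrapped) input reports after broadening. Roughly, in illustrative Python-style pseudocode (all helper names are stand-ins, not real APIs):

    def keep_for_grad(abs_item, enable_tuple_grad):
        # An element stays differentiable if it is undetermined, has no fixed
        # value (kAnyValue, e.g. set via Variable), is a scalar under
        # grad_for_scalar, or is an all-tensor tuple with tuple grad enabled.
        return (is_undetermined(abs_item)
                or build_value(abs_item) is ANY_VALUE
                or enable_grad_for_scalar(abs_item)
                or enable_grad_for_tuple(abs_item, enable_tuple_grad))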
@@ -43,6 +43,7 @@ using mindspore::abstract::AbstractBasePtr;
 using mindspore::abstract::AbstractClass;
 using mindspore::abstract::AbstractCOOTensor;
 using mindspore::abstract::AbstractDictionary;
+using mindspore::abstract::AbstractDictionaryPtr;
 using mindspore::abstract::AbstractList;
 using mindspore::abstract::AbstractListPtr;
 using mindspore::abstract::AbstractRowTensor;
@@ -53,6 +54,7 @@ using mindspore::abstract::AbstractTuple;
 using mindspore::abstract::AbstractTuplePtr;

 namespace {
+static constexpr size_t kMaxSeqRecursiveDepth = 5;
 void CheckInputsSize(const CNodePtr &cnode, size_t expect_size) {
   if (cnode->size() != expect_size) {
     std::string op_name = GetCNodeFuncName(cnode);
@@ -450,11 +452,60 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
     return std::make_shared<AbstractTuple>(std::move(elements));
   }

-  AbstractBasePtr ConvertAbstract(const AbstractBasePtr &abs) override {
+  // AbstractDictionary, AbstractClass --> AbstractSequence.
+  static AbstractSequencePtr ConvertToAbstractSequence(const AbstractBasePtr &abs, size_t depth) {
+    if (depth > kMaxSeqRecursiveDepth) {
+      MS_LOG(EXCEPTION) << "List or Dict nesting is not allowed more than " << kMaxSeqRecursiveDepth << " levels.";
+    }
+    auto abs_seq = abs->cast<AbstractSequencePtr>();
+    if (abs_seq != nullptr) {
+      const auto &seq_elements = abs_seq->elements();
+      // First we check if elements should be converted,
+      // changed_elements maps old element to new element.
+      mindspore::HashMap<AbstractBasePtr, AbstractBasePtr> changed_elements;
+      for (const auto &element : seq_elements) {
+        auto new_element = ConvertToAbstractSequence(element, depth + 1);
+        if (new_element != nullptr) {
+          (void)changed_elements.emplace(element, new_element);
+        }
+      }
+      if (changed_elements.empty()) {
+        // Here the AbstractList don't need to convert to AbstractTuple.
+        return nullptr;
+      }
+      // Always make new AbstractSequence when elements changed.
+      std::vector<AbstractBasePtr> elements;
+      elements.reserve(seq_elements.size());
+      for (const auto &element : seq_elements) {
+        auto iter = changed_elements.find(element);
+        if (iter != changed_elements.end()) {
+          (void)elements.emplace_back(iter->second);
+        } else {
+          (void)elements.emplace_back(element);
+        }
+      }
+      // Here the AbstractList don't need to convert to AbstractTuple.
+      if (abs_seq->isa<AbstractList>()) {
+        return std::make_shared<AbstractList>(std::move(elements));
+      } else {
+        return std::make_shared<AbstractTuple>(std::move(elements));
+      }
+    }
     // AbstractDictionary --> AbstractTuple.
-    auto abs_dict = abs->cast<abstract::AbstractDictionaryPtr>();
+    auto abs_dict = abs->cast<AbstractDictionaryPtr>();
     if (abs_dict != nullptr) {
-      return MakeAbstractTuple(abs_dict->elements());
+      const auto &dict_elements = abs_dict->elements();
+      std::vector<AbstractBasePtr> elements;
+      elements.reserve(dict_elements.size());
+      for (const auto &element : dict_elements) {
+        auto new_element = ConvertToAbstractSequence(element.second, depth + 1);
+        if (new_element != nullptr) {
+          (void)elements.emplace_back(new_element);
+        } else {
+          (void)elements.emplace_back(element.second);
+        }
+      }
+      return std::make_shared<AbstractTuple>(elements);
     }
     // AbstractClass --> AbstractTuple.
     auto abs_class = abs->cast<abstract::AbstractClassPtr>();
@@ -463,6 +514,11 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
     }
     return nullptr;
   }
+
+  AbstractBasePtr ConvertAbstract(const AbstractBasePtr &abs) override {
+    // AbstractDictionary, AbstractClass --> AbstractSequence.
+    return ConvertToAbstractSequence(abs, 0);
+  }
 };

 // ==================================================================
@@ -597,12 +653,10 @@ class CleanAfterOptARewriter : public BaseRewriter {
     return (this->*(iter->second))(cnode);
   }

-  static constexpr size_t kMaxListRecursiveDepth = 5;
-
   // ValueList --> ValueTuple
   static ValueTuplePtr ConvertValueListToValueTuple(const ValueListPtr &value_list, size_t depth) {
-    if (depth > kMaxListRecursiveDepth) {
-      MS_LOG(EXCEPTION) << "List nesting is not allowed more than " << kMaxListRecursiveDepth << " levels.";
+    if (depth > kMaxSeqRecursiveDepth) {
+      MS_LOG(EXCEPTION) << "List nesting is not allowed more than " << kMaxSeqRecursiveDepth << " levels.";
     }
     const auto &list_elements = value_list->value();
     std::vector<ValuePtr> elements;
@@ -625,50 +679,60 @@ class CleanAfterOptARewriter : public BaseRewriter {
     return nullptr;
   }

-  // AbstractSequence --> AbstractTuple
-  static AbstractTuplePtr ConvertAbstractSeqToAbstractTuple(const AbstractSequencePtr &abs_seq, size_t depth) {
-    if (depth > kMaxListRecursiveDepth) {
-      MS_LOG(EXCEPTION) << "List nesting is not allowed more than " << kMaxListRecursiveDepth << " levels.";
-    }
-    const auto &seq_elements = abs_seq->elements();
-    // First we check if elements should be converted,
-    // changed_elements maps old element to new element.
-    mindspore::HashMap<AbstractBasePtr, AbstractBasePtr> changed_elements;
-    for (const auto &element : seq_elements) {
-      if (element->isa<AbstractSequence>()) {
-        auto new_element = ConvertAbstractSeqToAbstractTuple(element->cast<AbstractSequencePtr>(), depth + 1);
-        if (new_element != nullptr) {
-          (void)changed_elements.emplace(element, new_element);
-        }
-      }
-    }
-    if (changed_elements.empty()) {
-      if (abs_seq->isa<AbstractTuple>()) {
-        // If no elements changed and it is an AbstractTuple, do not convert.
-        return nullptr;
-      }
-      // If no elements changed but it is not an AbstractTuple, convert it by copy elements.
-      return std::make_shared<AbstractTuple>(seq_elements);
-    }
-    // Always make new AbstractTuple when elements changed.
-    std::vector<AbstractBasePtr> elements;
-    elements.reserve(seq_elements.size());
-    for (const auto &element : seq_elements) {
-      auto iter = changed_elements.find(element);
-      if (iter != changed_elements.end()) {
-        (void)elements.emplace_back(iter->second);
-      } else {
-        (void)elements.emplace_back(element);
-      }
-    }
-    return std::make_shared<AbstractTuple>(std::move(elements));
-  }
-
-  AbstractBasePtr ConvertAbstract(const AbstractBasePtr &abs) override {
-    // AbstractSequence --> AbstractTuple.
-    auto abs_seq = abs->cast<AbstractSequencePtr>();
-    if (abs_seq != nullptr) {
-      return ConvertAbstractSeqToAbstractTuple(abs_seq, 0);
+  // AbstractSequence, AbstractDict, AbstractCOOTensor, AbstractRowTensor --> AbstractTuple.
+  static AbstractTuplePtr ConvertToAbstractTuple(const AbstractBasePtr &abs, size_t depth) {
+    if (depth > kMaxSeqRecursiveDepth) {
+      MS_LOG(EXCEPTION) << "List or Dict nesting is not allowed more than " << kMaxSeqRecursiveDepth << " levels.";
+    }
+    // AbstractList --> AbstractTuple.
+    auto abs_seq = abs->cast<AbstractSequencePtr>();
+    if (abs_seq != nullptr) {
+      const auto &seq_elements = abs_seq->elements();
+      // First we check if elements should be converted,
+      // changed_elements maps old element to new element.
+      mindspore::HashMap<AbstractBasePtr, AbstractBasePtr> changed_elements;
+      for (const auto &element : seq_elements) {
+        auto new_element = ConvertToAbstractTuple(element, depth + 1);
+        if (new_element != nullptr) {
+          (void)changed_elements.emplace(element, new_element);
+        }
+      }
+      if (changed_elements.empty()) {
+        if (abs->isa<AbstractTuple>()) {
+          // If no elements changed and it is an AbstractTuple, do not convert.
+          return nullptr;
+        }
+        // If no elements changed but it is not an AbstractTuple, convert it by copy elements.
+        return std::make_shared<AbstractTuple>(seq_elements);
+      }
+      // Always make new AbstractTuple when elements changed.
+      std::vector<AbstractBasePtr> elements;
+      elements.reserve(seq_elements.size());
+      for (const auto &element : seq_elements) {
+        auto iter = changed_elements.find(element);
+        if (iter != changed_elements.end()) {
+          (void)elements.emplace_back(iter->second);
+        } else {
+          (void)elements.emplace_back(element);
+        }
+      }
+      return std::make_shared<AbstractTuple>(std::move(elements));
+    }
+    // AbstractDict --> AbstractTuple.
+    auto abs_dict = abs->cast<AbstractDictionaryPtr>();
+    if (abs_dict != nullptr) {
+      const auto &dict_elements = abs_dict->elements();
+      std::vector<AbstractBasePtr> elements;
+      elements.reserve(dict_elements.size());
+      for (const auto &element : dict_elements) {
+        auto new_element = ConvertToAbstractTuple(element.second, depth + 1);
+        if (new_element != nullptr) {
+          (void)elements.emplace_back(new_element);
+        } else {
+          (void)elements.emplace_back(element.second);
+        }
+      }
+      return std::make_shared<AbstractTuple>(elements);
     }
     // AbstractCOOTensor --> AbstractTuple.
     auto abs_sparse = abs->cast<abstract::AbstractCOOTensorPtr>();
@@ -685,6 +749,11 @@ class CleanAfterOptARewriter : public BaseRewriter {
     }
     return nullptr;
   }
+
+  AbstractBasePtr ConvertAbstract(const AbstractBasePtr &abs) override {
+    // AbstractSequence, AbstractDict, AbstractCOOTensor, AbstractRowTensor --> AbstractTuple.
+    return ConvertToAbstractTuple(abs, 0);
+  }
 };
 }  // namespace
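Note: both rewriters now share a single recursion bound, kMaxSeqRecursiveDepth = 5, and convert nested lists and dicts recursively instead of only one level of sequence. A self-contained Python analogue of the shape of ConvertToAbstractTuple (plain values stand in for abstracts):

    K_MAX_SEQ_RECURSIVE_DEPTH = 5

    def to_tuple(value, depth=0):
        # Convert nested lists/dicts into tuples, refusing to recurse forever.
        if depth > K_MAX_SEQ_RECURSIVE_DEPTH:
            raise RuntimeError("List or Dict nesting is not allowed more than 5 levels.")
        if isinstance(value, (list, tuple)):
            return tuple(to_tuple(v, depth + 1) for v in value)
        if isinstance(value, dict):
            # Keys are dropped and values kept in insertion order, just as the
            # abstract conversion keeps only element.second.
            return tuple(to_tuple(v, depth + 1) for v in value.values())
        return value

The real pass additionally returns nullptr for an unchanged AbstractTuple so no new node is built; the sketch skips that caching detail.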
@@ -66,6 +66,15 @@ namespace pipeline {
 namespace {
 bool ExistControlFlow(const FuncGraphPtr &func_graph) { return !func_graph->func_graphs_used_total().empty(); }

+bool EnableGradForScalar(const abstract::AbstractBasePtr &abs) {
+  return MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && abs->BuildType() != nullptr &&
+         abs->BuildType()->isa<Number>();
+}
+
+bool EnableTupleBroaden(const abstract::AbstractBasePtr &abs) {
+  return abs->isa<abstract::AbstractTuple>() && abs->cast<abstract::AbstractTuplePtr>()->ContainsAllBroadenTensors();
+}
+
 void UpdateFuncGraphParameter(const FuncGraphPtr &func_graph) {
   MS_EXCEPTION_IF_NULL(func_graph);
   std::vector<AnfNodePtr> new_paras;
@@ -78,11 +87,8 @@ void UpdateFuncGraphParameter(const FuncGraphPtr &func_graph) {
     }
     AbstractBasePtr par_abs = param_node->abstract();
     MS_EXCEPTION_IF_NULL(par_abs);
-    if (par_abs->isa<abstract::AbstractUndetermined>() ||
-        (MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && par_abs->BuildType() != nullptr &&
-         par_abs->BuildType()->isa<Number>()) ||
-        (par_abs->isa<abstract::AbstractTuple>() &&
-         par_abs->cast<abstract::AbstractTuplePtr>()->ContainsAllBroadenTensors())) {
+    if (par_abs->isa<abstract::AbstractUndetermined>() || par_abs->BuildValue() == kAnyValue ||
+        EnableGradForScalar(par_abs) || EnableTupleBroaden(par_abs)) {
       new_paras.push_back(param_node);
     }
   }
@@ -32,6 +32,7 @@
 #include "utils/symbolic.h"
 #include "utils/ms_context.h"
 #include "include/common/utils/utils.h"
+#include "ir/variable.h"

 namespace mindspore {
 namespace parse {
@@ -506,6 +507,7 @@ static const std::vector<DataConverterPtr> &GetDataConverters() {
     std::make_shared<ByTypeDataConverter<MetaTensor>>(ObjCast<MetaTensorPtr>),
     std::make_shared<ByTypeDataConverter<CSRTensor>>(ObjCast<CSRTensorPtr>),
     std::make_shared<ByTypeDataConverter<COOTensor>>(ObjCast<COOTensorPtr>),
+    std::make_shared<ByTypeDataConverter<Variable>>(ObjCast<VariablePtr>),
     std::make_shared<ByTypeDataConverter<py::tuple>>(ConvertTuple),
     std::make_shared<ByTypeDataConverter<py::list>>(ConvertList),
     std::make_shared<ByTypeDataConverter<py::bool_>>(PyCast<BoolImm, bool>),
@@ -29,6 +29,7 @@
 #include "utils/hash_map.h"
 #include "pybind_api/pybind_patch.h"
 #include "ir/param_info.h"
+#include "ir/variable.h"
 #include "pipeline/jit/pass.h"
 #include "pipeline/jit/parse/data_converter.h"
 #include "pipeline/jit/static_analysis/async_eval_result.h"
@@ -149,7 +150,7 @@ bool CheckAllTensor(const ValueTuplePtr &value_tuple) {

 AbstractBasePtr ArgsToAbstract(const ValuePtr &value, bool enable_tuple_broaden = false) {
   MS_EXCEPTION_IF_NULL(value);
-  bool broaden = value->isa<MetaTensor>() ||
+  bool broaden = value->isa<MetaTensor>() || value->isa<Variable>() ||
                  (enable_tuple_broaden && value->isa<ValueTuple>() && CheckAllTensor(value->cast<ValueTuplePtr>())) ||
                  (MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && value->isa<Scalar>());

@@ -167,6 +168,14 @@ bool CheckArgValid(const py::handle &arg) {
     return std::all_of(dict_arg.begin(), dict_arg.end(), [](const auto &pair) { return CheckArgValid(pair.second); });
   }

+  if (py::isinstance<Variable>(arg)) {
+    if (py::hasattr(arg, "value")) {
+      return CheckArgValid(arg.attr("value"));
+    }
+    MS_LOG(ERROR) << "There should be a python object value stored in the Variable " << py::str(arg);
+    return false;
+  }
+
   if (py::isinstance<Tensor>(arg)) {
     auto tensor = py::cast<TensorPtr>(arg);
     if (tensor->data_type() == kNumberTypeBool) {
@@ -0,0 +1,36 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "include/common/pybind_api/api_register.h"
#include "ir/variable.h"
#include "pipeline/jit/parse/data_converter.h"

namespace py = pybind11;
namespace mindspore {
REGISTER_PYBIND_DEFINE(Variable_, ([](const py::module *m) {
                         (void)py::class_<Variable, VariablePtr>(*m, "Variable_")
                           .def(py::init([](const py::object &py_value) {
                                  ValuePtr real_value = nullptr;
                                  if (!parse::ConvertData(py_value, &real_value)) {
                                    MS_EXCEPTION(TypeError)
                                      << "Convert python object failed, the object type is " << py_value.get_type()
                                      << ", value is '" << py::str(py_value) << "'.";
                                  }
                                  return std::make_shared<Variable>(real_value);
                                }),
                                py::arg("py_value"));
                       }));
}  // namespace mindspore
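Note: the binding above runs the incoming Python object through parse::ConvertData before wrapping it, so anything the data converters accept can be stored. From Python the raw class is reachable as:

    from mindspore._c_expression import Variable_

    v = Variable_((1, 2.0, True))  # raises TypeError if the object cannot be converted

In practice users go through the mindspore.common.variable.Variable wrapper added later in this diff rather than the raw binding.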
@@ -158,7 +158,7 @@ std::string AbstractBase::ToString(bool verbose) const {
 AbstractBasePtr AbstractScalar::Broaden() const {
   auto context = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context);
-  if (context->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR)) {
+  if (context->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) || value_mutable()) {
     return AbstractBase::Broaden();
   }
   auto type_id = GetTypeTrack()->type_id();
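Note: this one-line change is the pivot of the feature. Broadening an AbstractScalar now erases its concrete value not only under the global grad_for_scalar flag but also whenever the abstract itself is flagged value-mutable. Conceptually (illustrative pseudocode, not the real API):

    def broaden_scalar(abs_scalar):
        # Keep the type track, drop the value track, so the compiler treats
        # the scalar as a runtime input instead of a foldable constant.
        if grad_for_scalar_enabled() or abs_scalar.value_mutable:
            return erase_value(abs_scalar)
        return abs_scalar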
@@ -187,6 +187,10 @@ class MS_CORE_API AbstractBase : public Base {
   /// \return A pointer to the broadened abstract.
   virtual AbstractBasePtr PartialBroaden() const;

+  bool value_mutable() const { return value_mutable_; }
+
+  void set_value_mutable(bool value_mutable) { value_mutable_ = value_mutable; }
+
  protected:
   /// \brief Build a value when value is not set.
   ///
@@ -198,6 +202,7 @@ class MS_CORE_API AbstractBase : public Base {
   TypePtr type_;
   BaseShapePtr shape_;
   std::string value_desc_;  // store initial value description for error report
+  bool value_mutable_{false};
 };

 /// \brief Class AbstractScalar describes a scalar's type and value.
@@ -0,0 +1,89 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ir/variable.h"
#include "abstract/abstract_value.h"

namespace mindspore {
namespace {
void SetValueMutable(const abstract::AbstractBasePtr &abs) {
  MS_EXCEPTION_IF_NULL(abs);
  if (abs->isa<abstract::AbstractTensor>()) {
    return;
  }

  auto abs_sequence = abs->cast<abstract::AbstractSequencePtr>();
  if (abs_sequence != nullptr) {
    const auto &elements = abs_sequence->elements();
    for (auto &ele : elements) {
      SetValueMutable(ele);
    }
    return;
  }

  auto abs_dict = abs->cast<abstract::AbstractDictionaryPtr>();
  if (abs_dict != nullptr) {
    const auto &elements = abs_dict->elements();
    for (auto &ele : elements) {
      SetValueMutable(ele.second);
    }
    return;
  }

  abs->set_value_mutable(true);
}
}  // namespace

abstract::AbstractBasePtr Variable::ToAbstract() {
  if (real_value_ == nullptr) {
    MS_LOG(EXCEPTION) << "Get abstract failed. The real value of Variable has not been set.";
  }
  auto abs = real_value_->ToAbstract();
  SetValueMutable(abs);
  return abs;
}

bool Variable::operator==(const Variable &other) const {
  if (this == &other) {
    return true;
  }
  auto other_real_value = other.real_value();
  if (real_value_ == nullptr || other_real_value == nullptr) {
    return false;
  }
  return *real_value_ == *other_real_value;
}

std::string Variable::ToString() const {
  std::ostringstream oss;
  if (real_value_ == nullptr) {
    oss << "Variable(NULL)";
  } else {
    oss << "Variable(" << real_value_->ToString() << ")";
  }
  return oss.str();
}

std::string Variable::DumpText() const {
  std::ostringstream oss;
  if (real_value_ == nullptr) {
    oss << type_name() << "(NULL)";
  } else {
    oss << type_name() << "(" << real_value_->DumpText() << ")";
  }
  return oss.str();
}
}  // namespace mindspore
@@ -0,0 +1,55 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CORE_IR_VARIABLE_H_
#define MINDSPORE_CORE_IR_VARIABLE_H_

#include <memory>
#include <string>
#include "ir/anf.h"

namespace mindspore {
class MS_CORE_API Variable : public Value {
 public:
  explicit Variable(const ValuePtr &real_value) : real_value_(real_value) {}
  ~Variable() override = default;
  MS_DECLARE_PARENT(Variable, Value)

  abstract::AbstractBasePtr ToAbstract() override;

  const ValuePtr &real_value() const { return real_value_; }

  bool operator==(const Variable &other) const;

  bool operator==(const Value &other) const override {
    if (other.isa<Variable>()) {
      auto other_variable = static_cast<const Variable &>(other);
      return *this == other_variable;
    }
    return false;
  }

  std::string ToString() const override;

  std::string DumpText() const override;

 private:
  ValuePtr real_value_{nullptr};
};
using VariablePtr = std::shared_ptr<Variable>;
}  // namespace mindspore

#endif  // MINDSPORE_CORE_IR_VARIABLE_H_
@@ -52,6 +52,14 @@ def TupleGetItem(x, index):
         x = x.asnumpy()
         y = x[index]
         return Tensor(y)
+
+    if isinstance(x, dict):
+        count = 0
+        for value in x.values():
+            if count == index:
+                return value
+            count = count + 1
+
     return x[index]
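Note: with the dict branch above, the getitem fallback treats a dict like a positional sequence: index i returns the i-th value in insertion order, matching how dict abstracts are flattened to tuples on the graph side. An illustrative use:

    x = {'a': 10, 'b': 20}
    assert TupleGetItem(x, 1) == 20  # values() order: 10, 20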
@@ -25,6 +25,7 @@ from .dump import set_dump
 from .parameter import Parameter, ParameterTuple
 from .seed import set_seed, get_seed
 from .tensor import Tensor, RowTensor, SparseTensor, COOTensor, CSRTensor
+from .variable import Variable

 # symbols from dtype
 __all__ = [
@@ -59,5 +60,6 @@ __all__.extend([
     "dtype", "_convert_data",
     "set_seed", "get_seed",  # random seed
     "set_dump",
-    "ms_memory_recycle"
+    "ms_memory_recycle",
+    "Variable"
 ])
@@ -0,0 +1,76 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Variable class for setting constants mutable."""

from .._c_expression import Variable_
from ..common.tensor import Tensor, CSRTensor, COOTensor


class Variable(Variable_):
    """
    Currently, all inputs of a Cell other than Tensor, such as scalar, tuple, list and dict, are regarded as
    constant values. Constant values are non-differentiable and are used for constant folding during optimization.
    The Variable class stores a constant value and makes it 'mutable': the wrapped input is treated as a variable
    input, just like a Tensor, and, most importantly, it becomes differentiable.

    .. warning::
        This is an experimental prototype that is subject to change or deletion.

    Args:
        value (Union[bool, float, int, tuple, list, dict, Tensor]): The value to be stored.

    Examples:
        >>> import mindspore.nn as nn
        >>> from mindspore.ops.composite import GradOperation
        >>> from mindspore.common.variable import Variable
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...
        ...     def construct(self, x, y):
        ...         return x * y
        ...
        >>> class GradNet(nn.Cell):
        ...     def __init__(self, net):
        ...         super(GradNet, self).__init__()
        ...         self.net = net
        ...         self.grad_op = GradOperation()
        ...
        ...     def construct(self, x, y):
        ...         gradient_function = self.grad_op(self.net)
        ...         return gradient_function(x, y)
        ...
        >>> x = Variable(2)
        >>> output = GradNet(Net())(x, 3)
        >>> print(output)
        3
    """

    def __init__(self, value):
        if not isinstance(value, (bool, int, float, tuple, list, dict, Tensor, COOTensor, CSRTensor)):
            raise TypeError(
                f"For 'Variable', the 'value' should be one of (bool, int, float, tuple, list, dict, Tensor, "
                f"COOTensor, CSRTensor), but got {type(value).__name__}")
        Variable_.__init__(self, value)
        self._value = value

    @property
    def value(self):
        return self._value

    @value.setter
    def value(self, value):
        self._value = value
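Note: because value is a plain read/write property, the payload can be swapped without rebuilding the wrapper, e.g.:

    x = Variable(2)
    x.value = 5  # same Variable object, new stored value

As written, the setter only updates the Python-side cache (self._value); it does not re-run the C++-side conversion performed in __init__.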
@@ -33,6 +33,7 @@ from .._checkparam import Validator
 from ..common import dtype as mstype
 from ..common.api import _cell_graph_executor, _pynative_executor, _check_all_tensor, cells_compile_cache
 from ..common.parameter import Parameter, ParameterTuple
+from ..common.variable import Variable
 from ..common.tensor import Tensor, CSRTensor, COOTensor
 from ..ops.operations import Cast
 from ..ops.primitive import Primitive
@@ -967,10 +968,10 @@ class Cell(Cell_):
                 if i.has_init:
                     i.init_data()
                 new_inputs.append(i)
-            elif isinstance(i, COOTensor):
-                new_inputs.append(i)
-            elif isinstance(i, CSRTensor):
+            elif isinstance(i, (COOTensor, CSRTensor)):
                 new_inputs.append(i)
+            elif isinstance(i, Variable):
+                new_inputs.append(i.value)
             elif context.get_context("grad_for_scalar") and isinstance(i, (int, float)):
                 new_inputs.append(i)
             elif hasattr(self, "enable_tuple_broaden") and self.enable_tuple_broaden and isinstance(i, tuple) and \
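Note: with the change above, a Cell call can take a Variable directly; at compile time the wrapper is replaced by its payload, while the graph side (via Variable::ToAbstract and the broadened abstracts) still sees the input as mutable and differentiable. A sketch, reusing Net from the docstring example in variable.py:

    net = Net()               # Net.construct(x, y) returns x * y
    out = net(Variable(2), 3)  # the Variable is unpacked to 2 for compilation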
@ -0,0 +1,497 @@
|
||||||
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""test getting gradient of Variable"""
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore.ops.composite import GradOperation
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.common import dtype as mstype
|
||||||
|
from mindspore import Parameter, Variable
|
||||||
|
|
||||||
|
|
||||||
|
def compare(a, b):
|
||||||
|
if isinstance(a, (list, tuple)):
|
||||||
|
for aa, bb in zip(a, b):
|
||||||
|
if not compare(aa, bb):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
return np.allclose(a.asnumpy(), b)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_grad_variable_tuple_tensor():
|
||||||
|
"""
|
||||||
|
Feature: Set Constants mutable.
|
||||||
|
Description: Get gradient with respect to tuple tensor input.
|
||||||
|
Expectation: Get the correct gradients.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.matmul = P.MatMul()
|
||||||
|
self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
|
||||||
|
|
||||||
|
def construct(self, t):
|
||||||
|
x = t[0]
|
||||||
|
y = t[1]
|
||||||
|
x = x * self.z
|
||||||
|
out = self.matmul(x, y)
|
||||||
|
return out
|
||||||
|
|
||||||
|
class GradNetWrtX(nn.Cell):
|
||||||
|
def __init__(self, net):
|
||||||
|
super(GradNetWrtX, self).__init__()
|
||||||
|
self.net = net
|
||||||
|
self.grad_op = GradOperation()
|
||||||
|
|
||||||
|
def construct(self, z):
|
||||||
|
gradient_function = self.grad_op(self.net)
|
||||||
|
return gradient_function(z)
|
||||||
|
|
||||||
|
t = Variable((Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
|
||||||
|
Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)))
|
||||||
|
output = GradNetWrtX(Net())(t)
|
||||||
|
assert isinstance(output, tuple)
|
||||||
|
expect = [np.array([[1.4100001, 1.5999999, 6.6],
|
||||||
|
[1.4100001, 1.5999999, 6.6]]).astype(np.float32),
|
||||||
|
np.array([[1.7, 1.7, 1.7],
|
||||||
|
[1.9, 1.9, 1.9],
|
||||||
|
[1.5, 1.5, 1.5]]).astype(np.float32)]
|
||||||
|
assert compare(output, expect)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_grad_variable_list_tensor():
|
||||||
|
"""
|
||||||
|
Feature: Set Constants mutable.
|
||||||
|
Description: Get gradient with respect to list tensor input.
|
||||||
|
Expectation: Get the correct gradients.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.matmul = P.MatMul()
|
||||||
|
self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
|
||||||
|
|
||||||
|
def construct(self, t):
|
||||||
|
x = t[0]
|
||||||
|
y = t[1]
|
||||||
|
x = x * self.z
|
||||||
|
out = self.matmul(x, y)
|
||||||
|
return out
|
||||||
|
|
||||||
|
class GradNetWrtX(nn.Cell):
|
||||||
|
def __init__(self, net):
|
||||||
|
super(GradNetWrtX, self).__init__()
|
||||||
|
self.net = net
|
||||||
|
self.grad_op = GradOperation()
|
||||||
|
|
||||||
|
def construct(self, z):
|
||||||
|
gradient_function = self.grad_op(self.net)
|
||||||
|
return gradient_function(z)
|
||||||
|
|
||||||
|
t = Variable([Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
|
||||||
|
Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)])
|
||||||
|
output = GradNetWrtX(Net())(t)
|
||||||
|
assert isinstance(output, tuple)
|
||||||
|
expect = [np.array([[1.4100001, 1.5999999, 6.6],
|
||||||
|
[1.4100001, 1.5999999, 6.6]]).astype(np.float32),
|
||||||
|
np.array([[1.7, 1.7, 1.7],
|
||||||
|
[1.9, 1.9, 1.9],
|
||||||
|
[1.5, 1.5, 1.5]]).astype(np.float32)]
|
||||||
|
assert compare(output, expect)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_grad_variable_dict_tensor():
|
||||||
|
"""
|
||||||
|
Feature: Set Constants mutable.
|
||||||
|
Description: Get gradient with respect to dict tensor input.
|
||||||
|
Expectation: Get the correct gradients.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.matmul = P.MatMul()
|
||||||
|
self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
|
||||||
|
|
||||||
|
def construct(self, t):
|
||||||
|
x = t['a']
|
||||||
|
y = t['b']
|
||||||
|
x = x * self.z
|
||||||
|
out = self.matmul(x, y)
|
||||||
|
return out
|
||||||
|
|
||||||
|
class GradNetWrtX(nn.Cell):
|
||||||
|
def __init__(self, net):
|
||||||
|
super(GradNetWrtX, self).__init__()
|
||||||
|
self.net = net
|
||||||
|
self.grad_op = GradOperation()
|
||||||
|
|
||||||
|
def construct(self, z):
|
||||||
|
gradient_function = self.grad_op(self.net)
|
||||||
|
return gradient_function(z)
|
||||||
|
|
||||||
|
t = Variable({'a': Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
|
||||||
|
'b': Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)})
|
||||||
|
output = GradNetWrtX(Net())(t)
|
||||||
|
assert isinstance(output, tuple)
|
||||||
|
expect = [np.array([[1.4100001, 1.5999999, 6.6],
|
||||||
|
[1.4100001, 1.5999999, 6.6]]).astype(np.float32),
|
||||||
|
np.array([[1.7, 1.7, 1.7],
|
||||||
|
[1.9, 1.9, 1.9],
|
||||||
|
[1.5, 1.5, 1.5]]).astype(np.float32)]
|
||||||
|
assert compare(output, expect)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_grad_variable_tuple_tuple_tensor():
|
||||||
|
"""
|
||||||
|
Feature: Set Constants mutable.
|
||||||
|
Description: Get gradient with respect to nested tuple tensor input.
|
||||||
|
Expectation: Get the correct gradients.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.matmul = P.MatMul()
|
||||||
|
self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
|
||||||
|
|
||||||
|
def construct(self, t):
|
||||||
|
x = t[0][0]
|
||||||
|
y = t[1]
|
||||||
|
x = x * self.z
|
||||||
|
out = self.matmul(x, y)
|
||||||
|
return out
|
||||||
|
|
||||||
|
class GradNetWrtX(nn.Cell):
|
||||||
|
def __init__(self, net):
|
||||||
|
super(GradNetWrtX, self).__init__()
|
||||||
|
self.net = net
|
||||||
|
self.grad_op = GradOperation()
|
||||||
|
|
||||||
|
def construct(self, z):
|
||||||
|
gradient_function = self.grad_op(self.net)
|
||||||
|
return gradient_function(z)
|
||||||
|
|
||||||
|
t = Variable(((Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
|
||||||
|
Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)),
|
||||||
|
Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)))
|
||||||
|
output = GradNetWrtX(Net())(t)
|
||||||
|
assert isinstance(output, tuple)
|
||||||
|
expect = [[np.array([[1.4100001, 1.5999999, 6.6],
|
||||||
|
[1.4100001, 1.5999999, 6.6]]).astype(np.float32), np.array([[0, 0, 0],
|
||||||
|
[0, 0, 0]]).astype(np.float32)],
|
||||||
|
np.array([[1.7, 1.7, 1.7],
|
||||||
|
[1.9, 1.9, 1.9],
|
||||||
|
[1.5, 1.5, 1.5]]).astype(np.float32)]
|
||||||
|
assert compare(output, expect)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_grad_variable_tuple_list_tensor():
|
||||||
|
"""
|
||||||
|
Feature: Set Constants mutable.
|
||||||
|
Description: Get gradient with respect to nested tuple and list tensor input.
|
||||||
|
Expectation: Get the correct gradients.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.matmul = P.MatMul()
|
||||||
|
self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
|
||||||
|
|
||||||
|
def construct(self, t):
|
||||||
|
x = t[0][0]
|
||||||
|
y = t[1]
|
||||||
|
x = x * self.z
|
||||||
|
out = self.matmul(x, y)
|
||||||
|
return out
|
||||||
|
|
||||||
|
class GradNetWrtX(nn.Cell):
|
||||||
|
def __init__(self, net):
|
||||||
|
super(GradNetWrtX, self).__init__()
|
||||||
|
self.net = net
|
||||||
|
self.grad_op = GradOperation()
|
||||||
|
|
||||||
|
def construct(self, z):
|
||||||
|
gradient_function = self.grad_op(self.net)
|
||||||
|
return gradient_function(z)
|
||||||
|
|
||||||
|
t = Variable(([Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
|
||||||
|
Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)],
|
||||||
|
Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)))
|
||||||
|
output = GradNetWrtX(Net())(t)
|
||||||
|
assert isinstance(output, tuple)
|
||||||
|
expect = [[np.array([[1.4100001, 1.5999999, 6.6],
|
||||||
|
[1.4100001, 1.5999999, 6.6]]).astype(np.float32), np.array([[0, 0, 0],
|
||||||
|
[0, 0, 0]]).astype(np.float32)],
|
||||||
|
np.array([[1.7, 1.7, 1.7],
|
||||||
|
[1.9, 1.9, 1.9],
|
||||||
|
[1.5, 1.5, 1.5]]).astype(np.float32)]
|
||||||
|
assert compare(output, expect)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_grad_variable_list_tuple_tensor():
|
||||||
|
"""
|
||||||
|
Feature: Set Constants mutable.
|
||||||
|
Description: Get gradient with respect to nested list and tuple tensor input.
|
||||||
|
Expectation: Get the correct gradients.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.matmul = P.MatMul()
|
||||||
|
self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
|
||||||
|
|
||||||
|
def construct(self, t):
|
||||||
|
x = t[0][0]
|
||||||
|
y = t[1]
|
||||||
|
x = x * self.z
|
||||||
|
out = self.matmul(x, y)
|
||||||
|
return out
|
||||||
|
|
||||||
|
class GradNetWrtX(nn.Cell):
|
||||||
|
def __init__(self, net):
|
||||||
|
super(GradNetWrtX, self).__init__()
|
||||||
|
self.net = net
|
||||||
|
self.grad_op = GradOperation()
|
||||||
|
|
||||||
|
def construct(self, z):
|
||||||
|
gradient_function = self.grad_op(self.net)
|
||||||
|
return gradient_function(z)
|
||||||
|
|
||||||
|
t = Variable([(Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
|
||||||
|
Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)),
|
||||||
|
Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)])
|
||||||
|
output = GradNetWrtX(Net())(t)
|
||||||
|
assert isinstance(output, tuple)
|
||||||
|
expect = [[np.array([[1.4100001, 1.5999999, 6.6],
|
||||||
|
[1.4100001, 1.5999999, 6.6]]).astype(np.float32), np.array([[0, 0, 0],
|
||||||
|
[0, 0, 0]]).astype(np.float32)],
|
||||||
|
np.array([[1.7, 1.7, 1.7],
|
||||||
|
[1.9, 1.9, 1.9],
|
||||||
|
[1.5, 1.5, 1.5]]).astype(np.float32)]
|
||||||
|
assert compare(output, expect)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_grad_variable_tuple_dict_tensor():
|
||||||
|
"""
|
||||||
|
Feature: Set Constants mutable.
|
||||||
|
Description: Get gradient with respect to nested tuple and dict tensor input.
|
||||||
|
Expectation: Get the correct gradients.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.matmul = P.MatMul()
|
||||||
|
self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
|
||||||
|
|
||||||
|
def construct(self, t):
|
||||||
|
x = t[0]['a']
|
||||||
|
y = t[1]
|
||||||
|
x = x * self.z
|
||||||
|
out = self.matmul(x, y)
|
||||||
|
return out
|
||||||
|
|
||||||
|
class GradNetWrtX(nn.Cell):
|
||||||
|
def __init__(self, net):
|
||||||
|
super(GradNetWrtX, self).__init__()
|
||||||
|
self.net = net
|
||||||
|
self.grad_op = GradOperation()
|
||||||
|
|
||||||
|
def construct(self, z):
|
||||||
|
gradient_function = self.grad_op(self.net)
|
||||||
|
return gradient_function(z)
|
||||||
|
|
||||||
|
t = Variable(({'a': Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
|
||||||
|
'b': Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)},
|
||||||
|
Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)))
|
||||||
|
output = GradNetWrtX(Net())(t)
|
||||||
|
assert isinstance(output, tuple)
|
||||||
|
expect = [[np.array([[1.4100001, 1.5999999, 6.6],
|
||||||
|
[1.4100001, 1.5999999, 6.6]]).astype(np.float32), np.array([[0, 0, 0],
|
||||||
|
[0, 0, 0]]).astype(np.float32)],
|
||||||
|
np.array([[1.7, 1.7, 1.7],
|
||||||
|
[1.9, 1.9, 1.9],
|
||||||
|
[1.5, 1.5, 1.5]]).astype(np.float32)]
|
||||||
|
assert compare(output, expect)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_grad_variable_dict_tuple_tensor():
|
||||||
|
"""
|
||||||
|
Feature: Set Constants mutable.
|
||||||
|
Description: Get gradient with respect to nested dict and tuple tensor input.
|
||||||
|
Expectation: Get the correct gradients.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.matmul = P.MatMul()
|
||||||
|
self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
|
||||||
|
|
||||||
|
def construct(self, t):
|
||||||
|
x = t['a'][0]
|
||||||
|
y = t['b']
|
||||||
|
x = x * self.z
|
||||||
|
out = self.matmul(x, y)
|
||||||
|
return out
|
||||||
|
|
||||||
|
class GradNetWrtX(nn.Cell):
|
||||||
|
def __init__(self, net):
|
||||||
|
super(GradNetWrtX, self).__init__()
|
||||||
|
self.net = net
|
||||||
|
self.grad_op = GradOperation()
|
||||||
|
|
||||||
|
def construct(self, z):
|
||||||
|
gradient_function = self.grad_op(self.net)
|
||||||
|
return gradient_function(z)
|
||||||
|
|
||||||
|
t = Variable({'a': (Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
|
||||||
|
Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)),
|
||||||
|
'b': Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)})
|
||||||
|
output = GradNetWrtX(Net())(t)
|
||||||
|
assert isinstance(output, tuple)
|
||||||
|
expect = [[np.array([[1.4100001, 1.5999999, 6.6],
|
||||||
|
[1.4100001, 1.5999999, 6.6]]).astype(np.float32), np.array([[0, 0, 0],
|
||||||
|
[0, 0, 0]]).astype(np.float32)],
|
||||||
|
np.array([[1.7, 1.7, 1.7],
|
||||||
|
[1.9, 1.9, 1.9],
|
||||||
|
[1.5, 1.5, 1.5]]).astype(np.float32)]
|
||||||
|
assert compare(output, expect)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_variable_list_dict_tensor():
    """
    Feature: Set Constants mutable.
    Description: Get gradient with respect to nested list and dict tensor input.
    Expectation: Get the correct gradients.
    """
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.matmul = P.MatMul()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

        def construct(self, t):
            x = t[0]['a']
            y = t[1]
            x = x * self.z
            out = self.matmul(x, y)
            return out

    class GradNetWrtX(nn.Cell):
        def __init__(self, net):
            super(GradNetWrtX, self).__init__()
            self.net = net
            self.grad_op = GradOperation()

        def construct(self, z):
            gradient_function = self.grad_op(self.net)
            return gradient_function(z)

    t = Variable([{'a': Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
                   'b': Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)},
                  Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)])
    output = GradNetWrtX(Net())(t)
    assert isinstance(output, tuple)
    expect = [[np.array([[1.4100001, 1.5999999, 6.6],
                         [1.4100001, 1.5999999, 6.6]]).astype(np.float32),
               np.array([[0, 0, 0],
                         [0, 0, 0]]).astype(np.float32)],
              np.array([[1.7, 1.7, 1.7],
                        [1.9, 1.9, 1.9],
                        [1.5, 1.5, 1.5]]).astype(np.float32)]
    assert compare(output, expect)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_variable_dict_list_tensor():
    """
    Feature: Set Constants mutable.
    Description: Get gradient with respect to nested dict and list tensor input.
    Expectation: Get the correct gradients.
    """
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.matmul = P.MatMul()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

        def construct(self, t):
            x = t['a'][0]
            y = t['b']
            x = x * self.z
            out = self.matmul(x, y)
            return out

    class GradNetWrtX(nn.Cell):
        def __init__(self, net):
            super(GradNetWrtX, self).__init__()
            self.net = net
            self.grad_op = GradOperation()

        def construct(self, z):
            gradient_function = self.grad_op(self.net)
            return gradient_function(z)

    t = Variable({'a': [Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
                        Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)],
                  'b': Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)})
    output = GradNetWrtX(Net())(t)
    assert isinstance(output, tuple)
    expect = [[np.array([[1.4100001, 1.5999999, 6.6],
                         [1.4100001, 1.5999999, 6.6]]).astype(np.float32),
               np.array([[0, 0, 0],
                         [0, 0, 0]]).astype(np.float32)],
              np.array([[1.7, 1.7, 1.7],
                        [1.9, 1.9, 1.9],
                        [1.5, 1.5, 1.5]]).astype(np.float32)]
    assert compare(output, expect)

@ -0,0 +1,195 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test variable"""

import numpy as np
from mindspore.ops.composite import GradOperation
from mindspore.common.variable import Variable
from mindspore.common.api import _CellGraphExecutor
from mindspore.ops import operations as P
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor
from mindspore import Parameter

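
# Variable marks a constant input (a scalar, or a possibly nested tuple, list, or
# dict of scalars and tensors) as mutable, so that compilation treats it as a real
# network input rather than a graph constant and gradients can be taken through it.
# The tests below exercise this for scalars, nested containers, and compile phases.
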
def test_variable_scalar_mul_grad_first():
    """
    Feature: Set Constants mutable.
    Description: Get gradient with respect to the first scalar input.
    Expectation: Get the correct gradient.
    """
    class Net(nn.Cell):
        def construct(self, x, y):
            return x * y

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.grad_op = GradOperation()

        def construct(self, x, y):
            gradient_function = self.grad_op(self.net)
            return gradient_function(x, y)

    x = Variable(2)
    output = GradNet(Net())(x, 3)
    assert output == 3

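
# With default arguments, GradOperation differentiates only with respect to the
# first positional input, so the test above expects d(x * y)/dx = y = 3; the plain
# constant 3 passed for y is never differentiated and needs no Variable wrapper.
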
def test_variable_scalar_mul_grad_all():
    """
    Feature: Set Constants mutable.
    Description: Get gradient with respect to all scalar inputs.
    Expectation: Get the correct gradients.
    """
    class Net(nn.Cell):
        def construct(self, x, y):
            return x * y

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.grad_op = GradOperation(get_all=True)

        def construct(self, x, y):
            gradient_function = self.grad_op(self.net)
            return gradient_function(x, y)

    x = Variable(2)
    y = Variable(3)
    output = GradNet(Net())(x, y)
    assert output == (3, 2)


def test_variable_tuple_or_list_scalar_mul_grad():
    """
    Feature: Set Constants mutable.
    Description: Get gradient with respect to a tuple or list of scalar inputs.
    Expectation: Get the correct gradients.
    """
    class Net(nn.Cell):
        def construct(self, x):
            return x[0] * x[1]

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.grad_op = GradOperation()

        def construct(self, x):
            gradient_function = self.grad_op(self.net)
            return gradient_function(x)

    x = Variable((2, 3))
    output = GradNet(Net())(x)
    assert output == (3, 2)

    x = Variable([2, 3])
    output = GradNet(Net())(x)
    assert output == (3, 2)


def test_variable_dict_scalar_mul_grad():
    """
    Feature: Set Constants mutable.
    Description: Get gradient with respect to the dict scalar input.
    Expectation: Get the correct gradients.
    """
    class Net(nn.Cell):
        def construct(self, x):
            return x['a'] * x['b']

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.grad_op = GradOperation()

        def construct(self, x):
            gradient_function = self.grad_op(self.net)
            return gradient_function(x)

    x = Variable({'a': 2, 'b': 3})
    output = GradNet(Net())(x)
    assert output == (3, 2)

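
# For f = x['a'] * x['b'] with x = {'a': 2, 'b': 3}, the gradient of a dict input
# comes back as a tuple over the dict's values:
# (df/da, df/db) = (x['b'], x['a']) = (3, 2), matching the assertion above.
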
def test_variable_mix_scalar_mul_grad_all():
    """
    Feature: Set Constants mutable.
    Description: Get gradient with respect to mixed scalar inputs including a dict and a tuple.
    Expectation: Get the correct gradients.
    """
    class Net(nn.Cell):
        def construct(self, x, y):
            return x['a'] * x['b'] * y[0]

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.grad_op = GradOperation(get_all=True)

        def construct(self, x, y):
            gradient_function = self.grad_op(self.net)
            return gradient_function(x, y)

    x = Variable({'a': 2, 'b': 3})
    y = Variable((4, 5))
    output = GradNet(Net())(x, y)
    assert output == ((12, 8), (6, 0))

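
# The expected value ((12, 8), (6, 0)) follows from f = x['a'] * x['b'] * y[0]
# with x = {'a': 2, 'b': 3} and y = (4, 5): df/dx = (3 * 4, 2 * 4) = (12, 8),
# and df/dy = (2 * 3, 0) = (6, 0), since y[1] never appears in f.
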
def test_tuple_inputs_compile_phase():
    """
    Feature: Set Constants mutable.
    Description: Test whether compiling twice with tuple inputs yields the same compilation phase.
    Expectation: The phases are the same when the tuple inputs are wrapped by Variable.
    """
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.matmul = P.MatMul()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

        def construct(self, tuple_input):
            x = tuple_input[0]
            y = tuple_input[1]
            x = x * self.z
            out = self.matmul(x, y)
            return out

    x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
    y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
    p = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
    q = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
    net = Net()
    _cell_graph_executor = _CellGraphExecutor()
    phase1, _ = _cell_graph_executor.compile(net, (x, y))
    phase2, _ = _cell_graph_executor.compile(net, (p, q))
    assert phase1 != phase2
    phase1, _ = _cell_graph_executor.compile(net, Variable((x, y)))
    phase2, _ = _cell_graph_executor.compile(net, Variable((p, q)))
    assert phase1 == phase2
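
# A constant tuple input is folded into the graph, so compiling with different
# tensor values produces two distinct phases (two separate compilations); wrapping
# the tuples in Variable marks them mutable, so both calls reuse one compiled phase.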