forked from mindspore-Ecosystem/mindspore
!62 support grad on python function with variable arguments
Merge pull request !62 from amongo/SupportGradOnVarArgs
This commit is contained in:
commit
cf54ecfe6e
|
@ -1199,51 +1199,6 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList& args_spec
|
||||||
return ret_graph;
|
return ret_graph;
|
||||||
}
|
}
|
||||||
|
|
||||||
FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) {
|
|
||||||
// slice a tensor
|
|
||||||
// args: tensor, slice or slice tuple
|
|
||||||
const std::string op_name = std::string("UnpackCall");
|
|
||||||
size_t arg_length = args_spec_list.size();
|
|
||||||
if (arg_length < 2) {
|
|
||||||
MS_LOG(EXCEPTION) << "" << op_name << " requires at least two args, but got " << arg_length << ".";
|
|
||||||
}
|
|
||||||
|
|
||||||
(void)abstract::CheckArg<AbstractFunction>(op_name, args_spec_list, 0);
|
|
||||||
FuncGraphPtr ret_graph = std::make_shared<FuncGraph>();
|
|
||||||
ret_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true);
|
|
||||||
|
|
||||||
AnfNodePtr fnNode = ret_graph->add_parameter();
|
|
||||||
std::vector<AnfNodePtr> elems;
|
|
||||||
elems.push_back(fnNode);
|
|
||||||
for (size_t index = 1; index < arg_length; index++) {
|
|
||||||
MS_EXCEPTION_IF_NULL(args_spec_list[index]);
|
|
||||||
if (args_spec_list[index]->isa<AbstractTuple>()) {
|
|
||||||
AbstractTuplePtr arg_tuple = dyn_cast<AbstractTuple>(args_spec_list[index]);
|
|
||||||
AnfNodePtr para_tuple = ret_graph->add_parameter();
|
|
||||||
for (size_t i = 0; i < arg_tuple->size(); ++i) {
|
|
||||||
elems.push_back(
|
|
||||||
ret_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), para_tuple, NewValueNode(SizeToInt(i))}));
|
|
||||||
}
|
|
||||||
} else if (args_spec_list[index]->isa<AbstractDictionary>()) {
|
|
||||||
AbstractDictionaryPtr arg_dict = dyn_cast<AbstractDictionary>(args_spec_list[index]);
|
|
||||||
AnfNodePtr para_dict = ret_graph->add_parameter();
|
|
||||||
auto dict_elems = arg_dict->elements();
|
|
||||||
(void)std::transform(
|
|
||||||
dict_elems.begin(), dict_elems.end(), std::back_inserter(elems),
|
|
||||||
[ret_graph, para_dict](const AbstractAttribute& item) {
|
|
||||||
return ret_graph->NewCNode(
|
|
||||||
{NewValueNode(prim::kPrimMakeKeywordArg), NewValueNode(item.first),
|
|
||||||
ret_graph->NewCNode({NewValueNode(prim::kPrimDictGetItem), para_dict, NewValueNode(item.first)})});
|
|
||||||
});
|
|
||||||
} else {
|
|
||||||
MS_LOG(EXCEPTION) << "" << op_name << " require args should be tuple or dict, but got "
|
|
||||||
<< args_spec_list[index]->ToString();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ret_graph->set_output(ret_graph->NewCNode(elems));
|
|
||||||
return ret_graph;
|
|
||||||
}
|
|
||||||
|
|
||||||
REGISTER_PYBIND_DEFINE(
|
REGISTER_PYBIND_DEFINE(
|
||||||
TupleAdd_, ([](const py::module* m) {
|
TupleAdd_, ([](const py::module* m) {
|
||||||
(void)py::class_<TupleAdd, MetaFuncGraph, std::shared_ptr<TupleAdd>>(*m, "TupleAdd_").def(py::init<std::string&>());
|
(void)py::class_<TupleAdd, MetaFuncGraph, std::shared_ptr<TupleAdd>>(*m, "TupleAdd_").def(py::init<std::string&>());
|
||||||
|
@ -1258,10 +1213,5 @@ REGISTER_PYBIND_DEFINE(TensorSlice_, ([](const py::module* m) {
|
||||||
(void)py::class_<TensorSlice, MetaFuncGraph, std::shared_ptr<TensorSlice>>(*m, "TensorSlice_")
|
(void)py::class_<TensorSlice, MetaFuncGraph, std::shared_ptr<TensorSlice>>(*m, "TensorSlice_")
|
||||||
.def(py::init<std::string&>());
|
.def(py::init<std::string&>());
|
||||||
}));
|
}));
|
||||||
|
|
||||||
REGISTER_PYBIND_DEFINE(UnpackCall_, ([](const py::module* m) {
|
|
||||||
(void)py::class_<UnpackCall, MetaFuncGraph, std::shared_ptr<UnpackCall>>(*m, "UnpackCall_")
|
|
||||||
.def(py::init<std::string&>());
|
|
||||||
}));
|
|
||||||
} // namespace prim
|
} // namespace prim
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -29,6 +29,7 @@
|
||||||
#include "operator/composite/zip_operation.h"
|
#include "operator/composite/zip_operation.h"
|
||||||
#include "operator/composite/list_append_operation.h"
|
#include "operator/composite/list_append_operation.h"
|
||||||
#include "operator/composite/do_signature.h"
|
#include "operator/composite/do_signature.h"
|
||||||
|
#include "operator/composite/unpack_call.h"
|
||||||
#include "pipeline/static_analysis/static_analysis.h"
|
#include "pipeline/static_analysis/static_analysis.h"
|
||||||
#include "utils/misc.h"
|
#include "utils/misc.h"
|
||||||
#include "utils/any.h"
|
#include "utils/any.h"
|
||||||
|
@ -154,7 +155,7 @@ class GradOperation : public MetaFuncGraph {
|
||||||
FuncGraphPtr GetGrad(AnfNodePtr ptrNode, const AnfNodePtr& weights, const std::vector<AnfNodePtr>& ptrParams,
|
FuncGraphPtr GetGrad(AnfNodePtr ptrNode, const AnfNodePtr& weights, const std::vector<AnfNodePtr>& ptrParams,
|
||||||
bool applyJ = false);
|
bool applyJ = false);
|
||||||
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override;
|
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override;
|
||||||
|
bool sens_param() const { return sens_param_; }
|
||||||
bool get_all_;
|
bool get_all_;
|
||||||
bool get_by_list_;
|
bool get_by_list_;
|
||||||
bool sens_param_;
|
bool sens_param_;
|
||||||
|
@ -208,17 +209,6 @@ class TensorSlice : public MetaFuncGraph {
|
||||||
};
|
};
|
||||||
using TensorSlicePtr = std::shared_ptr<TensorSlice>;
|
using TensorSlicePtr = std::shared_ptr<TensorSlice>;
|
||||||
|
|
||||||
// Expand the tuple and dict parameters generated when parsing the function call,
|
|
||||||
// and generate positional parameters and key-value pairs for function.
|
|
||||||
class UnpackCall : public MetaFuncGraph {
|
|
||||||
public:
|
|
||||||
explicit UnpackCall(const std::string& name) : MetaFuncGraph(name) {}
|
|
||||||
~UnpackCall() override = default;
|
|
||||||
MS_DECLARE_PARENT(UnpackCall, MetaFuncGraph)
|
|
||||||
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override;
|
|
||||||
friend bool operator==(const UnpackCall& lhs, const UnpackCall& rhs) { return lhs.name_ == rhs.name_; }
|
|
||||||
};
|
|
||||||
using UnpackCallPtr = std::shared_ptr<UnpackCall>;
|
|
||||||
} // namespace prim
|
} // namespace prim
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,94 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "operator/composite/unpack_call.h"
|
||||||
|
#include <algorithm>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "./common.h"
|
||||||
|
#include "pipeline/static_analysis/abstract_value.h"
|
||||||
|
#include "pipeline/static_analysis/dshape.h"
|
||||||
|
#include "pipeline/static_analysis/param_validator.h"
|
||||||
|
#include "operator/cc_implementations.h"
|
||||||
|
#include "ir/anf.h"
|
||||||
|
#include "optimizer/opt.h"
|
||||||
|
#include "utils/symbolic.h"
|
||||||
|
#include "pybind_api/api_register.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
// namespace to support composite operators definition
|
||||||
|
namespace prim {
|
||||||
|
using mindspore::abstract::AbstractAttribute;
|
||||||
|
using mindspore::abstract::AbstractBase;
|
||||||
|
using mindspore::abstract::AbstractDictionary;
|
||||||
|
using mindspore::abstract::AbstractDictionaryPtr;
|
||||||
|
using mindspore::abstract::AbstractFunction;
|
||||||
|
using mindspore::abstract::AbstractKeywordArg;
|
||||||
|
using mindspore::abstract::AbstractTuple;
|
||||||
|
using mindspore::abstract::AbstractTuplePtr;
|
||||||
|
|
||||||
|
FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) {
|
||||||
|
// slice a tensor
|
||||||
|
// args: tensor, slice or slice tuple
|
||||||
|
const std::string op_name = std::string("UnpackCall");
|
||||||
|
size_t arg_length = args_spec_list.size();
|
||||||
|
if (arg_length < 2) {
|
||||||
|
MS_LOG(EXCEPTION) << op_name << " requires at least two args, but got " << arg_length << ".";
|
||||||
|
}
|
||||||
|
|
||||||
|
(void)abstract::CheckArg<AbstractFunction>(op_name, args_spec_list, 0);
|
||||||
|
auto ret_graph = std::make_shared<FuncGraph>();
|
||||||
|
ret_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true);
|
||||||
|
|
||||||
|
AnfNodePtr fnNode = ret_graph->add_parameter();
|
||||||
|
std::vector<AnfNodePtr> elems;
|
||||||
|
elems.push_back(fnNode);
|
||||||
|
for (size_t index = 1; index < arg_length; index++) {
|
||||||
|
MS_EXCEPTION_IF_NULL(args_spec_list[index]);
|
||||||
|
if (args_spec_list[index]->isa<AbstractTuple>()) {
|
||||||
|
auto arg_tuple = args_spec_list[index]->cast<AbstractTuplePtr>();
|
||||||
|
AnfNodePtr para_tuple = ret_graph->add_parameter();
|
||||||
|
for (size_t i = 0; i < arg_tuple->size(); ++i) {
|
||||||
|
elems.push_back(
|
||||||
|
ret_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), para_tuple, NewValueNode(SizeToInt(i))}));
|
||||||
|
}
|
||||||
|
} else if (args_spec_list[index]->isa<AbstractDictionary>()) {
|
||||||
|
AbstractDictionaryPtr arg_dict = args_spec_list[index]->cast<AbstractDictionaryPtr>();
|
||||||
|
AnfNodePtr para_dict = ret_graph->add_parameter();
|
||||||
|
auto dict_elems = arg_dict->elements();
|
||||||
|
(void)std::transform(dict_elems.begin(), dict_elems.end(), std::back_inserter(elems),
|
||||||
|
[ret_graph, para_dict](const AbstractAttribute& item) {
|
||||||
|
auto dict_get_item = ret_graph->NewCNode(
|
||||||
|
{NewValueNode(prim::kPrimDictGetItem), para_dict, NewValueNode(item.first)});
|
||||||
|
return ret_graph->NewCNode(
|
||||||
|
{NewValueNode(prim::kPrimMakeKeywordArg), NewValueNode(item.first), dict_get_item});
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
MS_LOG(EXCEPTION) << op_name << " require args should be tuple or dict, but got "
|
||||||
|
<< args_spec_list[index]->ToString();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ret_graph->set_output(ret_graph->NewCNode(elems));
|
||||||
|
return ret_graph;
|
||||||
|
}
|
||||||
|
|
||||||
|
REGISTER_PYBIND_DEFINE(UnpackCall_, ([](const py::module* m) {
|
||||||
|
(void)py::class_<UnpackCall, MetaFuncGraph, std::shared_ptr<UnpackCall>>(*m, "UnpackCall_")
|
||||||
|
.def(py::init<std::string&>());
|
||||||
|
}));
|
||||||
|
|
||||||
|
} // namespace prim
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,54 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_OPERATOR_COMPOSITE_UNPACK_CALL_H_
|
||||||
|
#define MINDSPORE_CCSRC_OPERATOR_COMPOSITE_UNPACK_CALL_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <utility>
|
||||||
|
#include <map>
|
||||||
|
#include <set>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "pipeline/static_analysis/static_analysis.h"
|
||||||
|
#include "utils/misc.h"
|
||||||
|
#include "utils/any.h"
|
||||||
|
#include "ir/dtype.h"
|
||||||
|
#include "ir/meta_func_graph.h"
|
||||||
|
#include "common/utils.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
// namespace to support composite operators definition
|
||||||
|
namespace prim {
|
||||||
|
|
||||||
|
// Expand the tuple and dict parameters generated when parsing the function call,
|
||||||
|
// and generate positional parameters and key-value pairs for function.
|
||||||
|
class UnpackCall : public MetaFuncGraph {
|
||||||
|
public:
|
||||||
|
explicit UnpackCall(const std::string& name) : MetaFuncGraph(name) {}
|
||||||
|
~UnpackCall() override = default;
|
||||||
|
MS_DECLARE_PARENT(UnpackCall, MetaFuncGraph)
|
||||||
|
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override;
|
||||||
|
friend bool operator==(const UnpackCall& lhs, const UnpackCall& rhs) { return lhs.name_ == rhs.name_; }
|
||||||
|
};
|
||||||
|
using UnpackCallPtr = std::shared_ptr<UnpackCall>;
|
||||||
|
|
||||||
|
} // namespace prim
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CCSRC_OPERATOR_COMPOSITE_UNPACK_CALL_H_
|
|
@ -246,6 +246,21 @@ class DoSignaturePrimitive : public Primitive {
|
||||||
ValuePtr function_;
|
ValuePtr function_;
|
||||||
};
|
};
|
||||||
using DoSignaturePrimitivePtr = std::shared_ptr<DoSignaturePrimitive>;
|
using DoSignaturePrimitivePtr = std::shared_ptr<DoSignaturePrimitive>;
|
||||||
|
|
||||||
|
class UnpackGraphPrimitive : public Primitive {
|
||||||
|
public:
|
||||||
|
explicit UnpackGraphPrimitive(const std::string& name, const bool& with_sens, const bool& need_unpack_args)
|
||||||
|
: Primitive("UnpackGraph"), with_sens_in_args_(with_sens), need_unpack_args_(need_unpack_args) {}
|
||||||
|
~UnpackGraphPrimitive() override = default;
|
||||||
|
MS_DECLARE_PARENT(UnpackGraphPrimitive, Primitive)
|
||||||
|
bool with_sens_in_args() const { return with_sens_in_args_; }
|
||||||
|
bool need_unpack_args() const { return need_unpack_args_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool with_sens_in_args_;
|
||||||
|
bool need_unpack_args_;
|
||||||
|
};
|
||||||
|
using UnpackGraphPrimitivePtr = std::shared_ptr<UnpackGraphPrimitive>;
|
||||||
} // namespace prim
|
} // namespace prim
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -39,6 +39,7 @@
|
||||||
#include "optimizer/irpass/specialize_transform.h"
|
#include "optimizer/irpass/specialize_transform.h"
|
||||||
#include "optimizer/irpass/incorporate_getitem.h"
|
#include "optimizer/irpass/incorporate_getitem.h"
|
||||||
#include "optimizer/irpass/incorporate_call.h"
|
#include "optimizer/irpass/incorporate_call.h"
|
||||||
|
#include "optimizer/irpass/grad_var_prepare.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
|
@ -123,6 +124,11 @@ ResolveIRPassLib::ResolveIRPassLib() {
|
||||||
resolver_resolve_ = MakeSubstitution(ResolverResolve(), "resolver_resolve", prim::kPrimResolve);
|
resolver_resolve_ = MakeSubstitution(ResolverResolve(), "resolver_resolve", prim::kPrimResolve);
|
||||||
resolver_getattr_ = MakeSubstitution(ResolverGetattr(), "resolver_getattr", prim::kPrimGetAttr);
|
resolver_getattr_ = MakeSubstitution(ResolverGetattr(), "resolver_getattr", prim::kPrimGetAttr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
InferenceOptPrepareLib::InferenceOptPrepareLib() {
|
||||||
|
grad_var_prepare_ = MakeSubstitution(GradVarPrepare(), "grad_var_prepare", IsCNode);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace irpass
|
} // namespace irpass
|
||||||
} // namespace opt
|
} // namespace opt
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -102,6 +102,13 @@ class ResolveIRPassLib {
|
||||||
SubstitutionPtr resolver_getattr_;
|
SubstitutionPtr resolver_getattr_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class InferenceOptPrepareLib {
|
||||||
|
public:
|
||||||
|
InferenceOptPrepareLib();
|
||||||
|
~InferenceOptPrepareLib() = default;
|
||||||
|
SubstitutionPtr grad_var_prepare_;
|
||||||
|
};
|
||||||
|
|
||||||
// predicate functions
|
// predicate functions
|
||||||
inline bool IsNode(const AnfNodePtr &) { return true; }
|
inline bool IsNode(const AnfNodePtr &) { return true; }
|
||||||
|
|
||||||
|
@ -151,6 +158,7 @@ inline bool IsCNodeDup(const AnfNodePtr &node) {
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace irpass
|
} // namespace irpass
|
||||||
} // namespace opt
|
} // namespace opt
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -0,0 +1,144 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "optimizer/irpass/grad_var_prepare.h"
|
||||||
|
#include <vector>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "operator/composite/composite.h"
|
||||||
|
#include "operator/ops.h"
|
||||||
|
#include "optimizer/irpass.h"
|
||||||
|
#include "optimizer/optimizer.h"
|
||||||
|
#include "ir/visitor.h"
|
||||||
|
#include "ir/func_graph.h"
|
||||||
|
#include "ir/func_graph_cloner.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
namespace irpass {
|
||||||
|
|
||||||
|
static AnfNodePtr GenerateUnpackGraphNode(std::vector<AnfNodePtr> inputs_y, FuncGraphPtr func_graph,
|
||||||
|
AnfNodePtr func_node, bool is_unpack, bool sens_param) {
|
||||||
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
|
MS_EXCEPTION_IF_NULL(func_node);
|
||||||
|
std::vector<AnfNodePtr> nodes;
|
||||||
|
AnfNodePtr unpack_graph_node = nullptr;
|
||||||
|
if (is_unpack) {
|
||||||
|
auto unpack_graph = std::make_shared<prim::UnpackGraphPrimitive>("unpack_graph", sens_param, true);
|
||||||
|
nodes.push_back(NewValueNode(unpack_graph));
|
||||||
|
nodes.push_back(func_node);
|
||||||
|
// {unpackcall, {GradOperation, ...}, args...}
|
||||||
|
std::transform(inputs_y.begin() + 2, inputs_y.end(), std::back_inserter(nodes),
|
||||||
|
[](const AnfNodePtr& node) { return node; });
|
||||||
|
unpack_graph_node = func_graph->NewCNode(nodes);
|
||||||
|
} else {
|
||||||
|
auto unpack_graph = std::make_shared<prim::UnpackGraphPrimitive>("unpack_graph", sens_param, false);
|
||||||
|
nodes.push_back(NewValueNode(unpack_graph));
|
||||||
|
nodes.push_back(func_node);
|
||||||
|
// {{GradOperation, ...}, args...}
|
||||||
|
std::transform(inputs_y.begin() + 1, inputs_y.end(), std::back_inserter(nodes),
|
||||||
|
[](const AnfNodePtr& node) { return node; });
|
||||||
|
unpack_graph_node = func_graph->NewCNode(nodes);
|
||||||
|
}
|
||||||
|
return unpack_graph_node;
|
||||||
|
}
|
||||||
|
|
||||||
|
// get metagraph of value node
|
||||||
|
MetaFuncGraphPtr GetMetaFuncGraphOfValueNode(const AnfNodePtr& node) {
|
||||||
|
ValuePtr value;
|
||||||
|
if (IsValueNode<prim::DoSignaturePrimitive>(node)) {
|
||||||
|
value = GetValueNode(node)->cast<prim::DoSignaturePrimitivePtr>()->function();
|
||||||
|
} else {
|
||||||
|
value = GetValueNode(node);
|
||||||
|
}
|
||||||
|
if (value == nullptr) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return value->cast<MetaFuncGraphPtr>();
|
||||||
|
}
|
||||||
|
|
||||||
|
// check if node is a specific metafuncgraph op
|
||||||
|
bool IsMetaFuncGraph(const AnfNodePtr& node, const MetaFuncGraphPtr meta_func_graph) {
|
||||||
|
if (node != nullptr) {
|
||||||
|
auto meta_func_graph_ptr = GetMetaFuncGraphOfValueNode(node);
|
||||||
|
if (meta_func_graph_ptr == nullptr) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (meta_func_graph_ptr->type_name() == meta_func_graph->type_name()) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// {{GradOperation, g, w}, Ys}
|
||||||
|
// {UnPackCall, {GradOperation, g, w}, Ys}
|
||||||
|
AnfNodePtr GradVarPrepare::operator()(const OptimizerPtr&, const AnfNodePtr& node) {
|
||||||
|
if (!node->isa<CNode>() || node->func_graph() == nullptr) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// {{...}, Ys}
|
||||||
|
auto inputs_y = node->cast<CNodePtr>()->inputs();
|
||||||
|
std::vector<AnfNodePtr> inputs_x;
|
||||||
|
if (IsCNode(inputs_y[0])) {
|
||||||
|
inputs_x = inputs_y[0]->cast<CNodePtr>()->inputs();
|
||||||
|
} else if (IsMetaFuncGraph(inputs_y[0], unpack_op_) && IsCNode(inputs_y[1])) {
|
||||||
|
inputs_x = inputs_y[1]->cast<CNodePtr>()->inputs();
|
||||||
|
} else {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// {{...}, Xs}
|
||||||
|
if (inputs_x.size() < 2) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// {GradOperation, g, w} or {GradOperation, g}
|
||||||
|
if (!IsMetaFuncGraph(inputs_x[0], grad_op_)) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto meta_func = GetMetaFuncGraphOfValueNode(inputs_x[0]);
|
||||||
|
if (meta_func == nullptr) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
auto grad_op_ptr = meta_func->cast<prim::GradOperationPtr>();
|
||||||
|
auto func_node = inputs_x[1];
|
||||||
|
if (!IsValueNode<FuncGraph>(func_node)) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
AnfNodePtr unpack_graph_node =
|
||||||
|
GenerateUnpackGraphNode(inputs_y, node->cast<CNodePtr>()->func_graph(), func_node,
|
||||||
|
IsMetaFuncGraph(inputs_y[0], unpack_op_), grad_op_ptr->sens_param());
|
||||||
|
// constuct new grad_opration
|
||||||
|
inputs_x[1] = unpack_graph_node;
|
||||||
|
auto grad_op_cnode = node->func_graph()->NewCNode(inputs_x);
|
||||||
|
if (IsMetaFuncGraph(inputs_y[0], unpack_op_)) {
|
||||||
|
inputs_y[1] = grad_op_cnode;
|
||||||
|
} else {
|
||||||
|
inputs_y[0] = grad_op_cnode;
|
||||||
|
}
|
||||||
|
auto cnode = node->func_graph()->NewCNode(inputs_y);
|
||||||
|
return cnode;
|
||||||
|
}
|
||||||
|
} // namespace irpass
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,55 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_GRAD_VAR_PREPARE_H_
|
||||||
|
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_GRAD_VAR_PREPARE_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "operator/composite/composite.h"
|
||||||
|
#include "operator/ops.h"
|
||||||
|
#include "optimizer/irpass.h"
|
||||||
|
#include "optimizer/optimizer.h"
|
||||||
|
#include "ir/visitor.h"
|
||||||
|
#include "ir/func_graph.h"
|
||||||
|
#include "ir/func_graph_cloner.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
namespace irpass {
|
||||||
|
|
||||||
|
// {{GradOperation, g, w}, Ys}
|
||||||
|
// {UnPackCall, {GradOperation, g, w}, Ys}
|
||||||
|
class GradVarPrepare : public AnfVisitor {
|
||||||
|
public:
|
||||||
|
GradVarPrepare()
|
||||||
|
: grad_op_(std::make_shared<prim::GradOperation>("grad")),
|
||||||
|
unpack_op_(std::make_shared<prim::UnpackCall>("unpack_call")) {}
|
||||||
|
~GradVarPrepare() override = default;
|
||||||
|
|
||||||
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
MetaFuncGraphPtr grad_op_;
|
||||||
|
MetaFuncGraphPtr unpack_op_;
|
||||||
|
};
|
||||||
|
} // namespace irpass
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_GRAD_VAR_PREPARE_H_
|
|
@ -175,10 +175,10 @@ bool CombineLikeGraphs(const ResourcePtr&) {
|
||||||
|
|
||||||
bool SymbolResolveAction(const ResourcePtr& res) {
|
bool SymbolResolveAction(const ResourcePtr& res) {
|
||||||
if (res->manager() == nullptr) {
|
if (res->manager() == nullptr) {
|
||||||
MS_LOG(EXCEPTION) << "Resolve error.";
|
MS_LOG(EXCEPTION) << "SymbolResolve error, manager is null";
|
||||||
}
|
}
|
||||||
if (res->func_graph() == nullptr) {
|
if (res->func_graph() == nullptr) {
|
||||||
MS_LOG(EXCEPTION) << "Resolve error";
|
MS_LOG(EXCEPTION) << "SymbolResolve error, graph is null";
|
||||||
}
|
}
|
||||||
FuncGraphPtr func_graph = res->func_graph();
|
FuncGraphPtr func_graph = res->func_graph();
|
||||||
auto succ = parse::ResolveFuncGraph(func_graph, res);
|
auto succ = parse::ResolveFuncGraph(func_graph, res);
|
||||||
|
@ -194,6 +194,16 @@ bool SymbolResolveAction(const ResourcePtr& res) {
|
||||||
return succ;
|
return succ;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool InferenceOptPrepareAction(const ResourcePtr& res) {
|
||||||
|
if (res->manager() == nullptr) {
|
||||||
|
MS_LOG(EXCEPTION) << "InferenceOptPrepare error, manager is null.";
|
||||||
|
}
|
||||||
|
if (res->func_graph() == nullptr) {
|
||||||
|
MS_LOG(EXCEPTION) << "InferenceOptPrepare error, graph is null.";
|
||||||
|
}
|
||||||
|
return InferenceOptPreparePass(res);
|
||||||
|
}
|
||||||
|
|
||||||
bool AbstractSpecializeAction(const ResourcePtr& res) {
|
bool AbstractSpecializeAction(const ResourcePtr& res) {
|
||||||
if (res->func_graph() == nullptr) {
|
if (res->func_graph() == nullptr) {
|
||||||
MS_LOG(EXCEPTION) << "AbstractSpecialize error";
|
MS_LOG(EXCEPTION) << "AbstractSpecialize error";
|
||||||
|
@ -331,7 +341,7 @@ static std::vector<ActionItem> CommonPipeline() {
|
||||||
// Resolve the python func
|
// Resolve the python func
|
||||||
actions.emplace_back(std::make_pair("symbol_resolve", SymbolResolveAction));
|
actions.emplace_back(std::make_pair("symbol_resolve", SymbolResolveAction));
|
||||||
actions.emplace_back(std::make_pair("combine_like_graphs", CombineLikeGraphs));
|
actions.emplace_back(std::make_pair("combine_like_graphs", CombineLikeGraphs));
|
||||||
|
actions.emplace_back(std::make_pair("inference_opt_prepare", InferenceOptPrepareAction));
|
||||||
// Evaluate type and shape, and specialize
|
// Evaluate type and shape, and specialize
|
||||||
actions.emplace_back(std::make_pair("abstract_specialize", AbstractSpecializeAction));
|
actions.emplace_back(std::make_pair("abstract_specialize", AbstractSpecializeAction));
|
||||||
|
|
||||||
|
|
|
@ -160,6 +160,13 @@ OptPassGroupMap GetControlPhases(const opt::irpass::OptimizeIRPassLib& irpass) {
|
||||||
return map;
|
return map;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
OptPassGroupMap GetInferenceOptPreparePhases() {
|
||||||
|
opt::irpass::InferenceOptPrepareLib irpass;
|
||||||
|
auto grad_var_prepare = opt::OptPassConfig({irpass.grad_var_prepare_});
|
||||||
|
opt::OptPassGroupMap prepare_map({{"inference_opt_prep", grad_var_prepare}});
|
||||||
|
return prepare_map;
|
||||||
|
}
|
||||||
|
|
||||||
OptPassGroupMap GetPreparePhases(const opt::irpass::OptimizeIRPassLib& irpass) {
|
OptPassGroupMap GetPreparePhases(const opt::irpass::OptimizeIRPassLib& irpass) {
|
||||||
opt::OptPassConfig prepare_group = opt::OptPassConfig({irpass.print_tuple_wrapper_});
|
opt::OptPassConfig prepare_group = opt::OptPassConfig({irpass.print_tuple_wrapper_});
|
||||||
OptPassGroupMap map({{"prepare_group", prepare_group}});
|
OptPassGroupMap map({{"prepare_group", prepare_group}});
|
||||||
|
@ -239,6 +246,16 @@ bool ValidatePass(const ResourcePtr& res) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool InferenceOptPreparePass(const ResourcePtr& res) {
|
||||||
|
FuncGraphPtr func_graph = res->func_graph();
|
||||||
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
|
abstract::AbstractBasePtrList args_spec = res->args_spec();
|
||||||
|
auto prepare_map = GetInferenceOptPreparePhases();
|
||||||
|
auto infer_opt_prepare = opt::Optimizer::MakeOptimizer("inference_prepare", res, prepare_map);
|
||||||
|
(void)infer_opt_prepare->step(func_graph, args_spec, false);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
|
std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
|
||||||
{"opt_a", OptPassAGroup},
|
{"opt_a", OptPassAGroup},
|
||||||
{"opt_b", OptPassBGroup},
|
{"opt_b", OptPassBGroup},
|
||||||
|
|
|
@ -34,7 +34,7 @@ bool CconvPass(const ResourcePtr& res);
|
||||||
bool ValidatePass(const ResourcePtr& res);
|
bool ValidatePass(const ResourcePtr& res);
|
||||||
bool ConvertPrepareAdapt(const ResourcePtr& res);
|
bool ConvertPrepareAdapt(const ResourcePtr& res);
|
||||||
bool AddControlDependPass(const ResourcePtr& res);
|
bool AddControlDependPass(const ResourcePtr& res);
|
||||||
|
bool InferenceOptPreparePass(const ResourcePtr& res);
|
||||||
void ReclaimOptimizer();
|
void ReclaimOptimizer();
|
||||||
} // namespace pipeline
|
} // namespace pipeline
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -133,6 +133,7 @@ class FuncGraphAbstractClosure : public AbstractFuncAtom {
|
||||||
FuncGraphPtr func_graph_;
|
FuncGraphPtr func_graph_;
|
||||||
AnalysisContextPtr context_;
|
AnalysisContextPtr context_;
|
||||||
};
|
};
|
||||||
|
using FuncGraphAbstractClosurePtr = std::shared_ptr<FuncGraphAbstractClosure>;
|
||||||
|
|
||||||
class MetaFuncGraphAbstractClosure : public AbstractFuncAtom {
|
class MetaFuncGraphAbstractClosure : public AbstractFuncAtom {
|
||||||
public:
|
public:
|
||||||
|
|
|
@ -41,7 +41,7 @@ AnalysisContextPtr AnalysisContext::NewFuncGraphContext(const FuncGraphPtr &func
|
||||||
} else {
|
} else {
|
||||||
oss << "nullptr";
|
oss << "nullptr";
|
||||||
}
|
}
|
||||||
MS_LOG(EXCEPTION) << "" << oss.str() << " NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info());
|
MS_LOG(EXCEPTION) << oss.str() << " NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info());
|
||||||
}
|
}
|
||||||
return NewContext(parent_context, func_graph, args_spec_list);
|
return NewContext(parent_context, func_graph, args_spec_list);
|
||||||
}
|
}
|
||||||
|
|
|
@ -180,6 +180,85 @@ AbstractBasePtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const Config
|
||||||
return engine->ForwardConfig(out_conf, fn_conf);
|
return engine->ForwardConfig(out_conf, fn_conf);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static AbstractBasePtrList GetUnpackGraphSpecArgsList(AbstractBasePtrList args_spec_list, bool need_unpack) {
|
||||||
|
// arg[0] is the func graph to unpack, ignore it
|
||||||
|
AbstractBasePtrList sepcialize_args_before_unpack(args_spec_list.begin() + 1, args_spec_list.end());
|
||||||
|
AbstractBasePtrList graph_sepcialize_args;
|
||||||
|
if (need_unpack) {
|
||||||
|
for (size_t index = 0; index < sepcialize_args_before_unpack.size(); index++) {
|
||||||
|
MS_EXCEPTION_IF_NULL(sepcialize_args_before_unpack[index]);
|
||||||
|
if (sepcialize_args_before_unpack[index]->isa<AbstractTuple>()) {
|
||||||
|
AbstractTuplePtr arg_tuple = sepcialize_args_before_unpack[index]->cast<AbstractTuplePtr>();
|
||||||
|
std::transform(arg_tuple->elements().begin(), arg_tuple->elements().end(),
|
||||||
|
std::back_inserter(graph_sepcialize_args), [](AbstractBasePtr abs) { return abs; });
|
||||||
|
} else if (sepcialize_args_before_unpack[index]->isa<AbstractDictionary>()) {
|
||||||
|
AbstractDictionaryPtr arg_dict = sepcialize_args_before_unpack[index]->cast<AbstractDictionaryPtr>();
|
||||||
|
auto dict_elems = arg_dict->elements();
|
||||||
|
(void)std::transform(
|
||||||
|
dict_elems.begin(), dict_elems.end(), std::back_inserter(graph_sepcialize_args),
|
||||||
|
[](const AbstractAttribute &item) { return std::make_shared<AbstractKeywordArg>(item.first, item.second); });
|
||||||
|
} else {
|
||||||
|
MS_LOG(EXCEPTION) << "UnpackGraph require args should be tuple or dict, but got "
|
||||||
|
<< sepcialize_args_before_unpack[index]->ToString();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
graph_sepcialize_args = sepcialize_args_before_unpack;
|
||||||
|
}
|
||||||
|
return graph_sepcialize_args;
|
||||||
|
}
|
||||||
|
|
||||||
|
AbstractBasePtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
||||||
|
AnfNodeConfigPtr out_conf) {
|
||||||
|
if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
|
||||||
|
MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
|
||||||
|
}
|
||||||
|
if (!prim_->isa<prim::UnpackGraphPrimitive>()) {
|
||||||
|
MS_LOG(EXCEPTION) << "Primitive should be UnpackGraphPrimitive, but got " << prim_->ToString();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto unpack_graph = prim_->cast<prim::UnpackGraphPrimitivePtr>();
|
||||||
|
auto out_node = out_conf->node()->cast<CNodePtr>();
|
||||||
|
const auto &out_node_inputs = out_node->inputs();
|
||||||
|
if (out_node->inputs().size() == 0 || (out_node_inputs.size() - 1) != args_conf_list.size()) {
|
||||||
|
MS_LOG(EXCEPTION) << "UnpackGraphPrimitive"
|
||||||
|
<< " args size should equal to inputs size minus 1, but args size " << args_conf_list.size()
|
||||||
|
<< ", inputs size " << out_node_inputs.size();
|
||||||
|
}
|
||||||
|
AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()};
|
||||||
|
AbstractBasePtrList args_spec_list;
|
||||||
|
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
|
||||||
|
[](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue(); });
|
||||||
|
// get the forward graph
|
||||||
|
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
|
||||||
|
AbstractFunctionPtr fn = args_spec_list[0]->cast<AbstractFunctionPtr>();
|
||||||
|
if (fn == nullptr) {
|
||||||
|
MS_LOG(EXCEPTION) << "UnpackGraphPrimitive arg0 must be AbstractFunction, but " << args_spec_list[0]->ToString();
|
||||||
|
}
|
||||||
|
auto real_fn = fn->cast<FuncGraphAbstractClosurePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(real_fn);
|
||||||
|
FuncGraphPtr forward_graph = real_fn->func_graph();
|
||||||
|
MS_EXCEPTION_IF_NULL(forward_graph);
|
||||||
|
AbstractBasePtrList graph_sepcialize_args =
|
||||||
|
GetUnpackGraphSpecArgsList(args_spec_list, unpack_graph->need_unpack_args());
|
||||||
|
|
||||||
|
AbstractBasePtrList graph_sepcialize_args_without_sens;
|
||||||
|
(void)std::transform(graph_sepcialize_args.begin(),
|
||||||
|
graph_sepcialize_args.end() - (unpack_graph->with_sens_in_args() ? 1 : 0),
|
||||||
|
std::back_inserter(graph_sepcialize_args_without_sens), [](AbstractBasePtr abs) { return abs; });
|
||||||
|
auto new_graph = forward_graph->GenerateGraph(graph_sepcialize_args_without_sens);
|
||||||
|
engine->func_graph_manager()->AddFuncGraph(new_graph);
|
||||||
|
ScopePtr scope = kDefaultScope;
|
||||||
|
if (out_conf != nullptr) {
|
||||||
|
scope = out_conf->node()->scope();
|
||||||
|
}
|
||||||
|
ScopeGuard scope_guard(scope);
|
||||||
|
AnfNodePtr new_vnode = NewValueNode(new_graph);
|
||||||
|
AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_vnode, out_conf->context());
|
||||||
|
|
||||||
|
return engine->ForwardConfig(out_conf, fn_conf);
|
||||||
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
py::object BuildValue(const ValuePtr &value_ptr) {
|
py::object BuildValue(const ValuePtr &value_ptr) {
|
||||||
if (value_ptr == nullptr) {
|
if (value_ptr == nullptr) {
|
||||||
|
|
|
@ -87,6 +87,21 @@ class DoSignatureEvaluator : public Evaluator {
|
||||||
PrimitivePtr prim_;
|
PrimitivePtr prim_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class UnpackGraphEvaluator : public Evaluator {
|
||||||
|
public:
|
||||||
|
explicit UnpackGraphEvaluator(const PrimitivePtr primitive) : Evaluator("UnpackGraphEvaluator"), prim_(primitive) {}
|
||||||
|
~UnpackGraphEvaluator() override = default;
|
||||||
|
AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs,
|
||||||
|
AnfNodeConfigPtr out_config = nullptr) override;
|
||||||
|
|
||||||
|
AbstractBasePtr Infer(AnalysisEnginePtr, const AbstractBasePtrList &) override {
|
||||||
|
MS_LOG(EXCEPTION) << "Infer() should not be called, Run() method should be called";
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
PrimitivePtr prim_;
|
||||||
|
};
|
||||||
|
|
||||||
bool IsInWhiteList(PrimitivePtr primitive);
|
bool IsInWhiteList(PrimitivePtr primitive);
|
||||||
StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive);
|
StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive);
|
||||||
|
|
||||||
|
|
|
@ -289,6 +289,10 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr
|
||||||
evaluator = std::make_shared<DoSignatureEvaluator>(prim);
|
evaluator = std::make_shared<DoSignatureEvaluator>(prim);
|
||||||
return evaluator;
|
return evaluator;
|
||||||
}
|
}
|
||||||
|
if (prim->isa<prim::UnpackGraphPrimitive>()) {
|
||||||
|
evaluator = std::make_shared<UnpackGraphEvaluator>(prim);
|
||||||
|
return evaluator;
|
||||||
|
}
|
||||||
if (prim->HasPyEvaluator()) {
|
if (prim->HasPyEvaluator()) {
|
||||||
auto prim_py = dyn_cast<PrimitivePy>(prim);
|
auto prim_py = dyn_cast<PrimitivePy>(prim);
|
||||||
if (prim_py != nullptr) {
|
if (prim_py != nullptr) {
|
||||||
|
|
|
@ -19,6 +19,8 @@ from mindspore.nn import Cell
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
import mindspore.ops.composite as C
|
import mindspore.ops.composite as C
|
||||||
from mindspore.common.api import _executor
|
from mindspore.common.api import _executor
|
||||||
|
from mindspore.common.parameter import ParameterTuple
|
||||||
|
from mindspore.common import dtype as mstype
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE)
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
|
|
||||||
|
@ -34,3 +36,152 @@ def test_net_vargs_expand():
|
||||||
sens = Tensor(np.random.normal(0, 1, [3, 4, 5]).astype(np.float32))
|
sens = Tensor(np.random.normal(0, 1, [3, 4, 5]).astype(np.float32))
|
||||||
net = AddNet()
|
net = AddNet()
|
||||||
out = C.grad_all_with_sens(net, net.trainable_params())(x, y, sens)
|
out = C.grad_all_with_sens(net, net.trainable_params())(x, y, sens)
|
||||||
|
|
||||||
|
class VarNet(Cell):
|
||||||
|
def __init__(self, net):
|
||||||
|
super(VarNet, self).__init__()
|
||||||
|
self.b = Parameter(Tensor(np.ones([3, 4, 5]), dtype=mstype.float32), "b", requires_grad=True)
|
||||||
|
self.w = Parameter(Tensor(np.ones([3, 4, 5]), dtype=mstype.float32), "w", requires_grad=True)
|
||||||
|
self.net = net
|
||||||
|
def construct(self, *args):
|
||||||
|
return self.net(*args)*self.w + self.b
|
||||||
|
|
||||||
|
class SecondNet(Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(SecondNet, self).__init__()
|
||||||
|
self.b2 = Parameter(Tensor(np.ones([3, 4, 5]), dtype=mstype.float32), "b2", requires_grad=True)
|
||||||
|
def construct(self, *args):
|
||||||
|
res = args[0] + args[1]
|
||||||
|
return res + self.b2
|
||||||
|
def test_all_var_args_grad_with_sens():
|
||||||
|
""""test grad_by_list_with_sens with all var args input"""
|
||||||
|
class GradNet(Cell):
|
||||||
|
def __init__(self, net):
|
||||||
|
super(GradNet, self).__init__()
|
||||||
|
self.weights = ParameterTuple(net.trainable_params())
|
||||||
|
self.net = net
|
||||||
|
def construct(self, *inputs):
|
||||||
|
return C.grad_by_list_with_sens(self.net, self.weights)(*inputs)
|
||||||
|
x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
|
||||||
|
y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
|
||||||
|
sens = Tensor(1.0, dtype=mstype.float32)
|
||||||
|
net = VarNet(SecondNet())
|
||||||
|
grad_net = GradNet(net)
|
||||||
|
out = grad_net(x, y, sens)
|
||||||
|
|
||||||
|
def test_grad_list_var_args():
|
||||||
|
class GradNet(Cell):
|
||||||
|
def __init__(self, net):
|
||||||
|
super(GradNet, self).__init__()
|
||||||
|
self.weights = ParameterTuple(net.trainable_params())
|
||||||
|
self.net = net
|
||||||
|
def construct(self, *inputs):
|
||||||
|
return C.grad_by_list(self.net, self.weights)(*inputs)
|
||||||
|
x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
|
||||||
|
y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
|
||||||
|
net = VarNet(SecondNet())
|
||||||
|
grad_net = GradNet(net)
|
||||||
|
out = grad_net(x, y)
|
||||||
|
|
||||||
|
def test_grad_all_var_args():
|
||||||
|
class GradNet(Cell):
|
||||||
|
def __init__(self, net):
|
||||||
|
super(GradNet, self).__init__()
|
||||||
|
self.weights = ParameterTuple(net.trainable_params())
|
||||||
|
self.net = net
|
||||||
|
def construct(self, *inputs):
|
||||||
|
return C.grad_all(self.net)(*inputs)
|
||||||
|
x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
|
||||||
|
y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
|
||||||
|
net = VarNet(SecondNet())
|
||||||
|
grad_net = GradNet(net)
|
||||||
|
out = grad_net(x, y)
|
||||||
|
|
||||||
|
def test_grad_all_var_args_with_sens():
|
||||||
|
class GradNet(Cell):
|
||||||
|
def __init__(self, net):
|
||||||
|
super(GradNet, self).__init__()
|
||||||
|
self.weights = ParameterTuple(net.trainable_params())
|
||||||
|
self.net = net
|
||||||
|
def construct(self, *inputs):
|
||||||
|
return C.grad_all_with_sens(self.net)(*inputs)
|
||||||
|
x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
|
||||||
|
y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
|
||||||
|
sens = Tensor(1.0, dtype=mstype.float32)
|
||||||
|
net = VarNet(SecondNet())
|
||||||
|
grad_net = GradNet(net)
|
||||||
|
out = grad_net(x, y, sens)
|
||||||
|
|
||||||
|
def test_grad_var_args_with_sens():
|
||||||
|
class GradNet(Cell):
|
||||||
|
def __init__(self, net):
|
||||||
|
super(GradNet, self).__init__()
|
||||||
|
self.weights = ParameterTuple(net.trainable_params())
|
||||||
|
self.net = net
|
||||||
|
def construct(self, *inputs):
|
||||||
|
return C.grad_with_sens(self.net)(*inputs)
|
||||||
|
x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
|
||||||
|
y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
|
||||||
|
sens = Tensor(1.0, dtype=mstype.float32)
|
||||||
|
net = VarNet(SecondNet())
|
||||||
|
grad_net = GradNet(net)
|
||||||
|
out = grad_net(x, y, sens)
|
||||||
|
|
||||||
|
def test_var_args_grad():
|
||||||
|
class VarNet(Cell):
|
||||||
|
def __init__(self, net):
|
||||||
|
super(VarNet, self).__init__()
|
||||||
|
self.b = Parameter(Tensor(np.ones([3, 4, 5]), dtype=mstype.float32), "b", requires_grad=True)
|
||||||
|
self.net = net
|
||||||
|
def construct(self, *args):
|
||||||
|
return self.net(*args) + self.b
|
||||||
|
|
||||||
|
class SecondNet(Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(SecondNet, self).__init__()
|
||||||
|
self.b2 = Parameter(Tensor(np.ones([3, 4, 5]), dtype=mstype.float32), "b2", requires_grad=True)
|
||||||
|
def construct(self, *args):
|
||||||
|
res = args[0] + args[1]
|
||||||
|
return res + self.b2
|
||||||
|
class GradNet(Cell):
|
||||||
|
def __init__(self, net):
|
||||||
|
super(GradNet, self).__init__()
|
||||||
|
self.net = net
|
||||||
|
self.weights = ParameterTuple(net.trainable_params())
|
||||||
|
def construct(self, x, y, sens):
|
||||||
|
return C.grad_by_list_with_sens(self.net, self.weights)(x, y, sens)
|
||||||
|
x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
|
||||||
|
y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
|
||||||
|
sens = Tensor(1.0, dtype=mstype.float32)
|
||||||
|
net = VarNet(SecondNet())
|
||||||
|
grad_net = GradNet(net)
|
||||||
|
out = grad_net(x, y, sens)
|
||||||
|
|
||||||
|
|
||||||
|
def test_var_args_positional():
|
||||||
|
""""test grad_all with var args in inner graph"""
|
||||||
|
class VarNet(Cell):
|
||||||
|
def __init__(self, net):
|
||||||
|
super(VarNet, self).__init__()
|
||||||
|
self.net = net
|
||||||
|
def construct(self, x, y):
|
||||||
|
return self.net(x, y)*x
|
||||||
|
|
||||||
|
class SecondNet(Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(SecondNet, self).__init__()
|
||||||
|
def construct(self, *args):
|
||||||
|
return args[0] + args[1]
|
||||||
|
|
||||||
|
class GradNet(Cell):
|
||||||
|
def __init__(self, net):
|
||||||
|
super(GradNet, self).__init__()
|
||||||
|
self.net = net
|
||||||
|
self.weights = ParameterTuple(net.trainable_params())
|
||||||
|
def construct(self, x, y):
|
||||||
|
return C.grad_all(self.net)(x, y)
|
||||||
|
x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
|
||||||
|
y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
|
||||||
|
net = VarNet(SecondNet())
|
||||||
|
grad_net = GradNet(net)
|
||||||
|
out = grad_net(x, y)
|
||||||
|
|
Loading…
Reference in New Issue