!62 support grad on python function with variable arguments

Merge pull request !62 from amongo/SupportGradOnVarArgs
This commit is contained in:
mindspore-ci-bot 2020-04-02 11:21:06 +08:00 committed by Gitee
commit cf54ecfe6e
18 changed files with 660 additions and 67 deletions

View File

@ -1199,51 +1199,6 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList& args_spec
return ret_graph; return ret_graph;
} }
FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) {
// slice a tensor
// args: tensor, slice or slice tuple
const std::string op_name = std::string("UnpackCall");
size_t arg_length = args_spec_list.size();
if (arg_length < 2) {
MS_LOG(EXCEPTION) << "" << op_name << " requires at least two args, but got " << arg_length << ".";
}
(void)abstract::CheckArg<AbstractFunction>(op_name, args_spec_list, 0);
FuncGraphPtr ret_graph = std::make_shared<FuncGraph>();
ret_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true);
AnfNodePtr fnNode = ret_graph->add_parameter();
std::vector<AnfNodePtr> elems;
elems.push_back(fnNode);
for (size_t index = 1; index < arg_length; index++) {
MS_EXCEPTION_IF_NULL(args_spec_list[index]);
if (args_spec_list[index]->isa<AbstractTuple>()) {
AbstractTuplePtr arg_tuple = dyn_cast<AbstractTuple>(args_spec_list[index]);
AnfNodePtr para_tuple = ret_graph->add_parameter();
for (size_t i = 0; i < arg_tuple->size(); ++i) {
elems.push_back(
ret_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), para_tuple, NewValueNode(SizeToInt(i))}));
}
} else if (args_spec_list[index]->isa<AbstractDictionary>()) {
AbstractDictionaryPtr arg_dict = dyn_cast<AbstractDictionary>(args_spec_list[index]);
AnfNodePtr para_dict = ret_graph->add_parameter();
auto dict_elems = arg_dict->elements();
(void)std::transform(
dict_elems.begin(), dict_elems.end(), std::back_inserter(elems),
[ret_graph, para_dict](const AbstractAttribute& item) {
return ret_graph->NewCNode(
{NewValueNode(prim::kPrimMakeKeywordArg), NewValueNode(item.first),
ret_graph->NewCNode({NewValueNode(prim::kPrimDictGetItem), para_dict, NewValueNode(item.first)})});
});
} else {
MS_LOG(EXCEPTION) << "" << op_name << " require args should be tuple or dict, but got "
<< args_spec_list[index]->ToString();
}
}
ret_graph->set_output(ret_graph->NewCNode(elems));
return ret_graph;
}
REGISTER_PYBIND_DEFINE( REGISTER_PYBIND_DEFINE(
TupleAdd_, ([](const py::module* m) { TupleAdd_, ([](const py::module* m) {
(void)py::class_<TupleAdd, MetaFuncGraph, std::shared_ptr<TupleAdd>>(*m, "TupleAdd_").def(py::init<std::string&>()); (void)py::class_<TupleAdd, MetaFuncGraph, std::shared_ptr<TupleAdd>>(*m, "TupleAdd_").def(py::init<std::string&>());
@ -1258,10 +1213,5 @@ REGISTER_PYBIND_DEFINE(TensorSlice_, ([](const py::module* m) {
(void)py::class_<TensorSlice, MetaFuncGraph, std::shared_ptr<TensorSlice>>(*m, "TensorSlice_") (void)py::class_<TensorSlice, MetaFuncGraph, std::shared_ptr<TensorSlice>>(*m, "TensorSlice_")
.def(py::init<std::string&>()); .def(py::init<std::string&>());
})); }));
REGISTER_PYBIND_DEFINE(UnpackCall_, ([](const py::module* m) {
(void)py::class_<UnpackCall, MetaFuncGraph, std::shared_ptr<UnpackCall>>(*m, "UnpackCall_")
.def(py::init<std::string&>());
}));
} // namespace prim } // namespace prim
} // namespace mindspore } // namespace mindspore

View File

@ -29,6 +29,7 @@
#include "operator/composite/zip_operation.h" #include "operator/composite/zip_operation.h"
#include "operator/composite/list_append_operation.h" #include "operator/composite/list_append_operation.h"
#include "operator/composite/do_signature.h" #include "operator/composite/do_signature.h"
#include "operator/composite/unpack_call.h"
#include "pipeline/static_analysis/static_analysis.h" #include "pipeline/static_analysis/static_analysis.h"
#include "utils/misc.h" #include "utils/misc.h"
#include "utils/any.h" #include "utils/any.h"
@ -154,7 +155,7 @@ class GradOperation : public MetaFuncGraph {
FuncGraphPtr GetGrad(AnfNodePtr ptrNode, const AnfNodePtr& weights, const std::vector<AnfNodePtr>& ptrParams, FuncGraphPtr GetGrad(AnfNodePtr ptrNode, const AnfNodePtr& weights, const std::vector<AnfNodePtr>& ptrParams,
bool applyJ = false); bool applyJ = false);
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override; FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override;
bool sens_param() const { return sens_param_; }
bool get_all_; bool get_all_;
bool get_by_list_; bool get_by_list_;
bool sens_param_; bool sens_param_;
@ -208,17 +209,6 @@ class TensorSlice : public MetaFuncGraph {
}; };
using TensorSlicePtr = std::shared_ptr<TensorSlice>; using TensorSlicePtr = std::shared_ptr<TensorSlice>;
// Expand the tuple and dict parameters generated when parsing the function call,
// and generate positional parameters and key-value pairs for function.
class UnpackCall : public MetaFuncGraph {
public:
explicit UnpackCall(const std::string& name) : MetaFuncGraph(name) {}
~UnpackCall() override = default;
MS_DECLARE_PARENT(UnpackCall, MetaFuncGraph)
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override;
friend bool operator==(const UnpackCall& lhs, const UnpackCall& rhs) { return lhs.name_ == rhs.name_; }
};
using UnpackCallPtr = std::shared_ptr<UnpackCall>;
} // namespace prim } // namespace prim
} // namespace mindspore } // namespace mindspore

View File

@ -0,0 +1,94 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "operator/composite/unpack_call.h"
#include <algorithm>
#include <utility>
#include "./common.h"
#include "pipeline/static_analysis/abstract_value.h"
#include "pipeline/static_analysis/dshape.h"
#include "pipeline/static_analysis/param_validator.h"
#include "operator/cc_implementations.h"
#include "ir/anf.h"
#include "optimizer/opt.h"
#include "utils/symbolic.h"
#include "pybind_api/api_register.h"
namespace mindspore {
// namespace to support composite operators definition
namespace prim {
using mindspore::abstract::AbstractAttribute;
using mindspore::abstract::AbstractBase;
using mindspore::abstract::AbstractDictionary;
using mindspore::abstract::AbstractDictionaryPtr;
using mindspore::abstract::AbstractFunction;
using mindspore::abstract::AbstractKeywordArg;
using mindspore::abstract::AbstractTuple;
using mindspore::abstract::AbstractTuplePtr;
FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) {
// slice a tensor
// args: tensor, slice or slice tuple
const std::string op_name = std::string("UnpackCall");
size_t arg_length = args_spec_list.size();
if (arg_length < 2) {
MS_LOG(EXCEPTION) << op_name << " requires at least two args, but got " << arg_length << ".";
}
(void)abstract::CheckArg<AbstractFunction>(op_name, args_spec_list, 0);
auto ret_graph = std::make_shared<FuncGraph>();
ret_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true);
AnfNodePtr fnNode = ret_graph->add_parameter();
std::vector<AnfNodePtr> elems;
elems.push_back(fnNode);
for (size_t index = 1; index < arg_length; index++) {
MS_EXCEPTION_IF_NULL(args_spec_list[index]);
if (args_spec_list[index]->isa<AbstractTuple>()) {
auto arg_tuple = args_spec_list[index]->cast<AbstractTuplePtr>();
AnfNodePtr para_tuple = ret_graph->add_parameter();
for (size_t i = 0; i < arg_tuple->size(); ++i) {
elems.push_back(
ret_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), para_tuple, NewValueNode(SizeToInt(i))}));
}
} else if (args_spec_list[index]->isa<AbstractDictionary>()) {
AbstractDictionaryPtr arg_dict = args_spec_list[index]->cast<AbstractDictionaryPtr>();
AnfNodePtr para_dict = ret_graph->add_parameter();
auto dict_elems = arg_dict->elements();
(void)std::transform(dict_elems.begin(), dict_elems.end(), std::back_inserter(elems),
[ret_graph, para_dict](const AbstractAttribute& item) {
auto dict_get_item = ret_graph->NewCNode(
{NewValueNode(prim::kPrimDictGetItem), para_dict, NewValueNode(item.first)});
return ret_graph->NewCNode(
{NewValueNode(prim::kPrimMakeKeywordArg), NewValueNode(item.first), dict_get_item});
});
} else {
MS_LOG(EXCEPTION) << op_name << " require args should be tuple or dict, but got "
<< args_spec_list[index]->ToString();
}
}
ret_graph->set_output(ret_graph->NewCNode(elems));
return ret_graph;
}
REGISTER_PYBIND_DEFINE(UnpackCall_, ([](const py::module* m) {
(void)py::class_<UnpackCall, MetaFuncGraph, std::shared_ptr<UnpackCall>>(*m, "UnpackCall_")
.def(py::init<std::string&>());
}));
} // namespace prim
} // namespace mindspore

View File

@ -0,0 +1,54 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_OPERATOR_COMPOSITE_UNPACK_CALL_H_
#define MINDSPORE_CCSRC_OPERATOR_COMPOSITE_UNPACK_CALL_H_
#include <vector>
#include <string>
#include <unordered_map>
#include <utility>
#include <map>
#include <set>
#include <memory>
#include "pipeline/static_analysis/static_analysis.h"
#include "utils/misc.h"
#include "utils/any.h"
#include "ir/dtype.h"
#include "ir/meta_func_graph.h"
#include "common/utils.h"
namespace mindspore {
// namespace to support composite operators definition
namespace prim {
// Expand the tuple and dict parameters generated when parsing the function call,
// and generate positional parameters and key-value pairs for function.
class UnpackCall : public MetaFuncGraph {
public:
explicit UnpackCall(const std::string& name) : MetaFuncGraph(name) {}
~UnpackCall() override = default;
MS_DECLARE_PARENT(UnpackCall, MetaFuncGraph)
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override;
friend bool operator==(const UnpackCall& lhs, const UnpackCall& rhs) { return lhs.name_ == rhs.name_; }
};
using UnpackCallPtr = std::shared_ptr<UnpackCall>;
} // namespace prim
} // namespace mindspore
#endif // MINDSPORE_CCSRC_OPERATOR_COMPOSITE_UNPACK_CALL_H_

View File

@ -246,6 +246,21 @@ class DoSignaturePrimitive : public Primitive {
ValuePtr function_; ValuePtr function_;
}; };
using DoSignaturePrimitivePtr = std::shared_ptr<DoSignaturePrimitive>; using DoSignaturePrimitivePtr = std::shared_ptr<DoSignaturePrimitive>;
class UnpackGraphPrimitive : public Primitive {
public:
explicit UnpackGraphPrimitive(const std::string& name, const bool& with_sens, const bool& need_unpack_args)
: Primitive("UnpackGraph"), with_sens_in_args_(with_sens), need_unpack_args_(need_unpack_args) {}
~UnpackGraphPrimitive() override = default;
MS_DECLARE_PARENT(UnpackGraphPrimitive, Primitive)
bool with_sens_in_args() const { return with_sens_in_args_; }
bool need_unpack_args() const { return need_unpack_args_; }
private:
bool with_sens_in_args_;
bool need_unpack_args_;
};
using UnpackGraphPrimitivePtr = std::shared_ptr<UnpackGraphPrimitive>;
} // namespace prim } // namespace prim
} // namespace mindspore } // namespace mindspore

View File

@ -39,6 +39,7 @@
#include "optimizer/irpass/specialize_transform.h" #include "optimizer/irpass/specialize_transform.h"
#include "optimizer/irpass/incorporate_getitem.h" #include "optimizer/irpass/incorporate_getitem.h"
#include "optimizer/irpass/incorporate_call.h" #include "optimizer/irpass/incorporate_call.h"
#include "optimizer/irpass/grad_var_prepare.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
@ -123,6 +124,11 @@ ResolveIRPassLib::ResolveIRPassLib() {
resolver_resolve_ = MakeSubstitution(ResolverResolve(), "resolver_resolve", prim::kPrimResolve); resolver_resolve_ = MakeSubstitution(ResolverResolve(), "resolver_resolve", prim::kPrimResolve);
resolver_getattr_ = MakeSubstitution(ResolverGetattr(), "resolver_getattr", prim::kPrimGetAttr); resolver_getattr_ = MakeSubstitution(ResolverGetattr(), "resolver_getattr", prim::kPrimGetAttr);
} }
InferenceOptPrepareLib::InferenceOptPrepareLib() {
grad_var_prepare_ = MakeSubstitution(GradVarPrepare(), "grad_var_prepare", IsCNode);
}
} // namespace irpass } // namespace irpass
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore

View File

@ -102,6 +102,13 @@ class ResolveIRPassLib {
SubstitutionPtr resolver_getattr_; SubstitutionPtr resolver_getattr_;
}; };
class InferenceOptPrepareLib {
public:
InferenceOptPrepareLib();
~InferenceOptPrepareLib() = default;
SubstitutionPtr grad_var_prepare_;
};
// predicate functions // predicate functions
inline bool IsNode(const AnfNodePtr &) { return true; } inline bool IsNode(const AnfNodePtr &) { return true; }
@ -151,6 +158,7 @@ inline bool IsCNodeDup(const AnfNodePtr &node) {
} }
return false; return false;
} }
} // namespace irpass } // namespace irpass
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore

View File

@ -0,0 +1,144 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "optimizer/irpass/grad_var_prepare.h"
#include <vector>
#include <algorithm>
#include <unordered_map>
#include <memory>
#include "operator/composite/composite.h"
#include "operator/ops.h"
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
#include "ir/visitor.h"
#include "ir/func_graph.h"
#include "ir/func_graph_cloner.h"
namespace mindspore {
namespace opt {
namespace irpass {
static AnfNodePtr GenerateUnpackGraphNode(std::vector<AnfNodePtr> inputs_y, FuncGraphPtr func_graph,
AnfNodePtr func_node, bool is_unpack, bool sens_param) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(func_node);
std::vector<AnfNodePtr> nodes;
AnfNodePtr unpack_graph_node = nullptr;
if (is_unpack) {
auto unpack_graph = std::make_shared<prim::UnpackGraphPrimitive>("unpack_graph", sens_param, true);
nodes.push_back(NewValueNode(unpack_graph));
nodes.push_back(func_node);
// {unpackcall, {GradOperation, ...}, args...}
std::transform(inputs_y.begin() + 2, inputs_y.end(), std::back_inserter(nodes),
[](const AnfNodePtr& node) { return node; });
unpack_graph_node = func_graph->NewCNode(nodes);
} else {
auto unpack_graph = std::make_shared<prim::UnpackGraphPrimitive>("unpack_graph", sens_param, false);
nodes.push_back(NewValueNode(unpack_graph));
nodes.push_back(func_node);
// {{GradOperation, ...}, args...}
std::transform(inputs_y.begin() + 1, inputs_y.end(), std::back_inserter(nodes),
[](const AnfNodePtr& node) { return node; });
unpack_graph_node = func_graph->NewCNode(nodes);
}
return unpack_graph_node;
}
// get metagraph of value node
MetaFuncGraphPtr GetMetaFuncGraphOfValueNode(const AnfNodePtr& node) {
ValuePtr value;
if (IsValueNode<prim::DoSignaturePrimitive>(node)) {
value = GetValueNode(node)->cast<prim::DoSignaturePrimitivePtr>()->function();
} else {
value = GetValueNode(node);
}
if (value == nullptr) {
return nullptr;
}
return value->cast<MetaFuncGraphPtr>();
}
// check if node is a specific metafuncgraph op
bool IsMetaFuncGraph(const AnfNodePtr& node, const MetaFuncGraphPtr meta_func_graph) {
if (node != nullptr) {
auto meta_func_graph_ptr = GetMetaFuncGraphOfValueNode(node);
if (meta_func_graph_ptr == nullptr) {
return false;
}
if (meta_func_graph_ptr->type_name() == meta_func_graph->type_name()) {
return true;
}
}
return false;
}
// {{GradOperation, g, w}, Ys}
// {UnPackCall, {GradOperation, g, w}, Ys}
AnfNodePtr GradVarPrepare::operator()(const OptimizerPtr&, const AnfNodePtr& node) {
if (!node->isa<CNode>() || node->func_graph() == nullptr) {
return nullptr;
}
// {{...}, Ys}
auto inputs_y = node->cast<CNodePtr>()->inputs();
std::vector<AnfNodePtr> inputs_x;
if (IsCNode(inputs_y[0])) {
inputs_x = inputs_y[0]->cast<CNodePtr>()->inputs();
} else if (IsMetaFuncGraph(inputs_y[0], unpack_op_) && IsCNode(inputs_y[1])) {
inputs_x = inputs_y[1]->cast<CNodePtr>()->inputs();
} else {
return nullptr;
}
// {{...}, Xs}
if (inputs_x.size() < 2) {
return nullptr;
}
// {GradOperation, g, w} or {GradOperation, g}
if (!IsMetaFuncGraph(inputs_x[0], grad_op_)) {
return nullptr;
}
auto meta_func = GetMetaFuncGraphOfValueNode(inputs_x[0]);
if (meta_func == nullptr) {
return nullptr;
}
auto grad_op_ptr = meta_func->cast<prim::GradOperationPtr>();
auto func_node = inputs_x[1];
if (!IsValueNode<FuncGraph>(func_node)) {
return nullptr;
}
AnfNodePtr unpack_graph_node =
GenerateUnpackGraphNode(inputs_y, node->cast<CNodePtr>()->func_graph(), func_node,
IsMetaFuncGraph(inputs_y[0], unpack_op_), grad_op_ptr->sens_param());
// constuct new grad_opration
inputs_x[1] = unpack_graph_node;
auto grad_op_cnode = node->func_graph()->NewCNode(inputs_x);
if (IsMetaFuncGraph(inputs_y[0], unpack_op_)) {
inputs_y[1] = grad_op_cnode;
} else {
inputs_y[0] = grad_op_cnode;
}
auto cnode = node->func_graph()->NewCNode(inputs_y);
return cnode;
}
} // namespace irpass
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,55 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_GRAD_VAR_PREPARE_H_
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_GRAD_VAR_PREPARE_H_
#include <vector>
#include <algorithm>
#include <unordered_map>
#include <memory>
#include "operator/composite/composite.h"
#include "operator/ops.h"
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
#include "ir/visitor.h"
#include "ir/func_graph.h"
#include "ir/func_graph_cloner.h"
namespace mindspore {
namespace opt {
namespace irpass {
// {{GradOperation, g, w}, Ys}
// {UnPackCall, {GradOperation, g, w}, Ys}
class GradVarPrepare : public AnfVisitor {
public:
GradVarPrepare()
: grad_op_(std::make_shared<prim::GradOperation>("grad")),
unpack_op_(std::make_shared<prim::UnpackCall>("unpack_call")) {}
~GradVarPrepare() override = default;
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
private:
MetaFuncGraphPtr grad_op_;
MetaFuncGraphPtr unpack_op_;
};
} // namespace irpass
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_GRAD_VAR_PREPARE_H_

View File

@ -175,10 +175,10 @@ bool CombineLikeGraphs(const ResourcePtr&) {
bool SymbolResolveAction(const ResourcePtr& res) { bool SymbolResolveAction(const ResourcePtr& res) {
if (res->manager() == nullptr) { if (res->manager() == nullptr) {
MS_LOG(EXCEPTION) << "Resolve error."; MS_LOG(EXCEPTION) << "SymbolResolve error, manager is null";
} }
if (res->func_graph() == nullptr) { if (res->func_graph() == nullptr) {
MS_LOG(EXCEPTION) << "Resolve error"; MS_LOG(EXCEPTION) << "SymbolResolve error, graph is null";
} }
FuncGraphPtr func_graph = res->func_graph(); FuncGraphPtr func_graph = res->func_graph();
auto succ = parse::ResolveFuncGraph(func_graph, res); auto succ = parse::ResolveFuncGraph(func_graph, res);
@ -194,6 +194,16 @@ bool SymbolResolveAction(const ResourcePtr& res) {
return succ; return succ;
} }
bool InferenceOptPrepareAction(const ResourcePtr& res) {
if (res->manager() == nullptr) {
MS_LOG(EXCEPTION) << "InferenceOptPrepare error, manager is null.";
}
if (res->func_graph() == nullptr) {
MS_LOG(EXCEPTION) << "InferenceOptPrepare error, graph is null.";
}
return InferenceOptPreparePass(res);
}
bool AbstractSpecializeAction(const ResourcePtr& res) { bool AbstractSpecializeAction(const ResourcePtr& res) {
if (res->func_graph() == nullptr) { if (res->func_graph() == nullptr) {
MS_LOG(EXCEPTION) << "AbstractSpecialize error"; MS_LOG(EXCEPTION) << "AbstractSpecialize error";
@ -331,7 +341,7 @@ static std::vector<ActionItem> CommonPipeline() {
// Resolve the python func // Resolve the python func
actions.emplace_back(std::make_pair("symbol_resolve", SymbolResolveAction)); actions.emplace_back(std::make_pair("symbol_resolve", SymbolResolveAction));
actions.emplace_back(std::make_pair("combine_like_graphs", CombineLikeGraphs)); actions.emplace_back(std::make_pair("combine_like_graphs", CombineLikeGraphs));
actions.emplace_back(std::make_pair("inference_opt_prepare", InferenceOptPrepareAction));
// Evaluate type and shape, and specialize // Evaluate type and shape, and specialize
actions.emplace_back(std::make_pair("abstract_specialize", AbstractSpecializeAction)); actions.emplace_back(std::make_pair("abstract_specialize", AbstractSpecializeAction));

View File

@ -160,6 +160,13 @@ OptPassGroupMap GetControlPhases(const opt::irpass::OptimizeIRPassLib& irpass) {
return map; return map;
} }
OptPassGroupMap GetInferenceOptPreparePhases() {
opt::irpass::InferenceOptPrepareLib irpass;
auto grad_var_prepare = opt::OptPassConfig({irpass.grad_var_prepare_});
opt::OptPassGroupMap prepare_map({{"inference_opt_prep", grad_var_prepare}});
return prepare_map;
}
OptPassGroupMap GetPreparePhases(const opt::irpass::OptimizeIRPassLib& irpass) { OptPassGroupMap GetPreparePhases(const opt::irpass::OptimizeIRPassLib& irpass) {
opt::OptPassConfig prepare_group = opt::OptPassConfig({irpass.print_tuple_wrapper_}); opt::OptPassConfig prepare_group = opt::OptPassConfig({irpass.print_tuple_wrapper_});
OptPassGroupMap map({{"prepare_group", prepare_group}}); OptPassGroupMap map({{"prepare_group", prepare_group}});
@ -239,6 +246,16 @@ bool ValidatePass(const ResourcePtr& res) {
return true; return true;
} }
bool InferenceOptPreparePass(const ResourcePtr& res) {
FuncGraphPtr func_graph = res->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
abstract::AbstractBasePtrList args_spec = res->args_spec();
auto prepare_map = GetInferenceOptPreparePhases();
auto infer_opt_prepare = opt::Optimizer::MakeOptimizer("inference_prepare", res, prepare_map);
(void)infer_opt_prepare->step(func_graph, args_spec, false);
return true;
}
std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStructuresPass}, std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
{"opt_a", OptPassAGroup}, {"opt_a", OptPassAGroup},
{"opt_b", OptPassBGroup}, {"opt_b", OptPassBGroup},

View File

@ -34,7 +34,7 @@ bool CconvPass(const ResourcePtr& res);
bool ValidatePass(const ResourcePtr& res); bool ValidatePass(const ResourcePtr& res);
bool ConvertPrepareAdapt(const ResourcePtr& res); bool ConvertPrepareAdapt(const ResourcePtr& res);
bool AddControlDependPass(const ResourcePtr& res); bool AddControlDependPass(const ResourcePtr& res);
bool InferenceOptPreparePass(const ResourcePtr& res);
void ReclaimOptimizer(); void ReclaimOptimizer();
} // namespace pipeline } // namespace pipeline
} // namespace mindspore } // namespace mindspore

View File

@ -133,6 +133,7 @@ class FuncGraphAbstractClosure : public AbstractFuncAtom {
FuncGraphPtr func_graph_; FuncGraphPtr func_graph_;
AnalysisContextPtr context_; AnalysisContextPtr context_;
}; };
using FuncGraphAbstractClosurePtr = std::shared_ptr<FuncGraphAbstractClosure>;
class MetaFuncGraphAbstractClosure : public AbstractFuncAtom { class MetaFuncGraphAbstractClosure : public AbstractFuncAtom {
public: public:

View File

@ -41,7 +41,7 @@ AnalysisContextPtr AnalysisContext::NewFuncGraphContext(const FuncGraphPtr &func
} else { } else {
oss << "nullptr"; oss << "nullptr";
} }
MS_LOG(EXCEPTION) << "" << oss.str() << " NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info()); MS_LOG(EXCEPTION) << oss.str() << " NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info());
} }
return NewContext(parent_context, func_graph, args_spec_list); return NewContext(parent_context, func_graph, args_spec_list);
} }

View File

@ -180,6 +180,85 @@ AbstractBasePtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const Config
return engine->ForwardConfig(out_conf, fn_conf); return engine->ForwardConfig(out_conf, fn_conf);
} }
static AbstractBasePtrList GetUnpackGraphSpecArgsList(AbstractBasePtrList args_spec_list, bool need_unpack) {
// arg[0] is the func graph to unpack, ignore it
AbstractBasePtrList sepcialize_args_before_unpack(args_spec_list.begin() + 1, args_spec_list.end());
AbstractBasePtrList graph_sepcialize_args;
if (need_unpack) {
for (size_t index = 0; index < sepcialize_args_before_unpack.size(); index++) {
MS_EXCEPTION_IF_NULL(sepcialize_args_before_unpack[index]);
if (sepcialize_args_before_unpack[index]->isa<AbstractTuple>()) {
AbstractTuplePtr arg_tuple = sepcialize_args_before_unpack[index]->cast<AbstractTuplePtr>();
std::transform(arg_tuple->elements().begin(), arg_tuple->elements().end(),
std::back_inserter(graph_sepcialize_args), [](AbstractBasePtr abs) { return abs; });
} else if (sepcialize_args_before_unpack[index]->isa<AbstractDictionary>()) {
AbstractDictionaryPtr arg_dict = sepcialize_args_before_unpack[index]->cast<AbstractDictionaryPtr>();
auto dict_elems = arg_dict->elements();
(void)std::transform(
dict_elems.begin(), dict_elems.end(), std::back_inserter(graph_sepcialize_args),
[](const AbstractAttribute &item) { return std::make_shared<AbstractKeywordArg>(item.first, item.second); });
} else {
MS_LOG(EXCEPTION) << "UnpackGraph require args should be tuple or dict, but got "
<< sepcialize_args_before_unpack[index]->ToString();
}
}
} else {
graph_sepcialize_args = sepcialize_args_before_unpack;
}
return graph_sepcialize_args;
}
AbstractBasePtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
AnfNodeConfigPtr out_conf) {
if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
}
if (!prim_->isa<prim::UnpackGraphPrimitive>()) {
MS_LOG(EXCEPTION) << "Primitive should be UnpackGraphPrimitive, but got " << prim_->ToString();
}
auto unpack_graph = prim_->cast<prim::UnpackGraphPrimitivePtr>();
auto out_node = out_conf->node()->cast<CNodePtr>();
const auto &out_node_inputs = out_node->inputs();
if (out_node->inputs().size() == 0 || (out_node_inputs.size() - 1) != args_conf_list.size()) {
MS_LOG(EXCEPTION) << "UnpackGraphPrimitive"
<< " args size should equal to inputs size minus 1, but args size " << args_conf_list.size()
<< ", inputs size " << out_node_inputs.size();
}
AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()};
AbstractBasePtrList args_spec_list;
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
[](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue(); });
// get the forward graph
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
AbstractFunctionPtr fn = args_spec_list[0]->cast<AbstractFunctionPtr>();
if (fn == nullptr) {
MS_LOG(EXCEPTION) << "UnpackGraphPrimitive arg0 must be AbstractFunction, but " << args_spec_list[0]->ToString();
}
auto real_fn = fn->cast<FuncGraphAbstractClosurePtr>();
MS_EXCEPTION_IF_NULL(real_fn);
FuncGraphPtr forward_graph = real_fn->func_graph();
MS_EXCEPTION_IF_NULL(forward_graph);
AbstractBasePtrList graph_sepcialize_args =
GetUnpackGraphSpecArgsList(args_spec_list, unpack_graph->need_unpack_args());
AbstractBasePtrList graph_sepcialize_args_without_sens;
(void)std::transform(graph_sepcialize_args.begin(),
graph_sepcialize_args.end() - (unpack_graph->with_sens_in_args() ? 1 : 0),
std::back_inserter(graph_sepcialize_args_without_sens), [](AbstractBasePtr abs) { return abs; });
auto new_graph = forward_graph->GenerateGraph(graph_sepcialize_args_without_sens);
engine->func_graph_manager()->AddFuncGraph(new_graph);
ScopePtr scope = kDefaultScope;
if (out_conf != nullptr) {
scope = out_conf->node()->scope();
}
ScopeGuard scope_guard(scope);
AnfNodePtr new_vnode = NewValueNode(new_graph);
AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_vnode, out_conf->context());
return engine->ForwardConfig(out_conf, fn_conf);
}
namespace { namespace {
py::object BuildValue(const ValuePtr &value_ptr) { py::object BuildValue(const ValuePtr &value_ptr) {
if (value_ptr == nullptr) { if (value_ptr == nullptr) {

View File

@ -87,6 +87,21 @@ class DoSignatureEvaluator : public Evaluator {
PrimitivePtr prim_; PrimitivePtr prim_;
}; };
class UnpackGraphEvaluator : public Evaluator {
public:
explicit UnpackGraphEvaluator(const PrimitivePtr primitive) : Evaluator("UnpackGraphEvaluator"), prim_(primitive) {}
~UnpackGraphEvaluator() override = default;
AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs,
AnfNodeConfigPtr out_config = nullptr) override;
AbstractBasePtr Infer(AnalysisEnginePtr, const AbstractBasePtrList &) override {
MS_LOG(EXCEPTION) << "Infer() should not be called, Run() method should be called";
}
private:
PrimitivePtr prim_;
};
bool IsInWhiteList(PrimitivePtr primitive); bool IsInWhiteList(PrimitivePtr primitive);
StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive); StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive);

View File

@ -289,6 +289,10 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr
evaluator = std::make_shared<DoSignatureEvaluator>(prim); evaluator = std::make_shared<DoSignatureEvaluator>(prim);
return evaluator; return evaluator;
} }
if (prim->isa<prim::UnpackGraphPrimitive>()) {
evaluator = std::make_shared<UnpackGraphEvaluator>(prim);
return evaluator;
}
if (prim->HasPyEvaluator()) { if (prim->HasPyEvaluator()) {
auto prim_py = dyn_cast<PrimitivePy>(prim); auto prim_py = dyn_cast<PrimitivePy>(prim);
if (prim_py != nullptr) { if (prim_py != nullptr) {

View File

@ -19,6 +19,8 @@ from mindspore.nn import Cell
from mindspore.ops import operations as P from mindspore.ops import operations as P
import mindspore.ops.composite as C import mindspore.ops.composite as C
from mindspore.common.api import _executor from mindspore.common.api import _executor
from mindspore.common.parameter import ParameterTuple
from mindspore.common import dtype as mstype
context.set_context(mode=context.GRAPH_MODE) context.set_context(mode=context.GRAPH_MODE)
@ -34,3 +36,152 @@ def test_net_vargs_expand():
sens = Tensor(np.random.normal(0, 1, [3, 4, 5]).astype(np.float32)) sens = Tensor(np.random.normal(0, 1, [3, 4, 5]).astype(np.float32))
net = AddNet() net = AddNet()
out = C.grad_all_with_sens(net, net.trainable_params())(x, y, sens) out = C.grad_all_with_sens(net, net.trainable_params())(x, y, sens)
class VarNet(Cell):
def __init__(self, net):
super(VarNet, self).__init__()
self.b = Parameter(Tensor(np.ones([3, 4, 5]), dtype=mstype.float32), "b", requires_grad=True)
self.w = Parameter(Tensor(np.ones([3, 4, 5]), dtype=mstype.float32), "w", requires_grad=True)
self.net = net
def construct(self, *args):
return self.net(*args)*self.w + self.b
class SecondNet(Cell):
def __init__(self):
super(SecondNet, self).__init__()
self.b2 = Parameter(Tensor(np.ones([3, 4, 5]), dtype=mstype.float32), "b2", requires_grad=True)
def construct(self, *args):
res = args[0] + args[1]
return res + self.b2
def test_all_var_args_grad_with_sens():
""""test grad_by_list_with_sens with all var args input"""
class GradNet(Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.weights = ParameterTuple(net.trainable_params())
self.net = net
def construct(self, *inputs):
return C.grad_by_list_with_sens(self.net, self.weights)(*inputs)
x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
sens = Tensor(1.0, dtype=mstype.float32)
net = VarNet(SecondNet())
grad_net = GradNet(net)
out = grad_net(x, y, sens)
def test_grad_list_var_args():
class GradNet(Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.weights = ParameterTuple(net.trainable_params())
self.net = net
def construct(self, *inputs):
return C.grad_by_list(self.net, self.weights)(*inputs)
x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
net = VarNet(SecondNet())
grad_net = GradNet(net)
out = grad_net(x, y)
def test_grad_all_var_args():
class GradNet(Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.weights = ParameterTuple(net.trainable_params())
self.net = net
def construct(self, *inputs):
return C.grad_all(self.net)(*inputs)
x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
net = VarNet(SecondNet())
grad_net = GradNet(net)
out = grad_net(x, y)
def test_grad_all_var_args_with_sens():
class GradNet(Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.weights = ParameterTuple(net.trainable_params())
self.net = net
def construct(self, *inputs):
return C.grad_all_with_sens(self.net)(*inputs)
x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
sens = Tensor(1.0, dtype=mstype.float32)
net = VarNet(SecondNet())
grad_net = GradNet(net)
out = grad_net(x, y, sens)
def test_grad_var_args_with_sens():
class GradNet(Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.weights = ParameterTuple(net.trainable_params())
self.net = net
def construct(self, *inputs):
return C.grad_with_sens(self.net)(*inputs)
x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
sens = Tensor(1.0, dtype=mstype.float32)
net = VarNet(SecondNet())
grad_net = GradNet(net)
out = grad_net(x, y, sens)
def test_var_args_grad():
class VarNet(Cell):
def __init__(self, net):
super(VarNet, self).__init__()
self.b = Parameter(Tensor(np.ones([3, 4, 5]), dtype=mstype.float32), "b", requires_grad=True)
self.net = net
def construct(self, *args):
return self.net(*args) + self.b
class SecondNet(Cell):
def __init__(self):
super(SecondNet, self).__init__()
self.b2 = Parameter(Tensor(np.ones([3, 4, 5]), dtype=mstype.float32), "b2", requires_grad=True)
def construct(self, *args):
res = args[0] + args[1]
return res + self.b2
class GradNet(Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
self.weights = ParameterTuple(net.trainable_params())
def construct(self, x, y, sens):
return C.grad_by_list_with_sens(self.net, self.weights)(x, y, sens)
x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
sens = Tensor(1.0, dtype=mstype.float32)
net = VarNet(SecondNet())
grad_net = GradNet(net)
out = grad_net(x, y, sens)
def test_var_args_positional():
""""test grad_all with var args in inner graph"""
class VarNet(Cell):
def __init__(self, net):
super(VarNet, self).__init__()
self.net = net
def construct(self, x, y):
return self.net(x, y)*x
class SecondNet(Cell):
def __init__(self):
super(SecondNet, self).__init__()
def construct(self, *args):
return args[0] + args[1]
class GradNet(Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
self.weights = ParameterTuple(net.trainable_params())
def construct(self, x, y):
return C.grad_all(self.net)(x, y)
x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
net = VarNet(SecondNet())
grad_net = GradNet(net)
out = grad_net(x, y)