forked from mindspore-Ecosystem/mindspore
vmap unpack graph
This commit is contained in:
parent
90accbb2b8
commit
04fe405289
|
@ -20,7 +20,7 @@
|
|||
#include "frontend/optimizer/irpass/cast_eliminate.h"
|
||||
#include "frontend/optimizer/irpass/convert.h"
|
||||
#include "frontend/optimizer/irpass/environ_eliminate.h"
|
||||
#include "frontend/optimizer/irpass/grad_var_prepare.h"
|
||||
#include "frontend/optimizer/irpass/meta_fg_var_prepare.h"
|
||||
#include "frontend/optimizer/irpass/taylor_eliminate.h"
|
||||
#include "frontend/optimizer/irpass/inline.h"
|
||||
#include "frontend/optimizer/irpass/updatestate_eliminate.h"
|
||||
|
@ -288,7 +288,7 @@ ResolveIRPassLib::ResolveIRPassLib() {
|
|||
}
|
||||
|
||||
InferenceOptPrepareLib::InferenceOptPrepareLib() {
|
||||
grad_var_prepare_ = MakeSubstitution(std::make_shared<GradVarPrepare>(), "grad_var_prepare", IsCNode);
|
||||
meta_fg_var_prepare_ = MakeSubstitution(std::make_shared<MetaFgVarPrepare>(), "meta_fg_var_prepare", IsCNode);
|
||||
}
|
||||
} // namespace irpass
|
||||
} // namespace opt
|
||||
|
|
|
@ -184,7 +184,7 @@ class InferenceOptPrepareLib {
|
|||
public:
|
||||
InferenceOptPrepareLib();
|
||||
~InferenceOptPrepareLib() = default;
|
||||
SubstitutionPtr grad_var_prepare_;
|
||||
SubstitutionPtr meta_fg_var_prepare_;
|
||||
};
|
||||
|
||||
// predicate functions
|
||||
|
|
|
@ -1,54 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_GRAD_VAR_PREPARE_H_
|
||||
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_GRAD_VAR_PREPARE_H_
|
||||
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
|
||||
#include "utils/hash_map.h"
|
||||
#include "frontend/operator/composite/composite.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
#include "frontend/optimizer/irpass.h"
|
||||
#include "frontend/optimizer/optimizer.h"
|
||||
#include "frontend/optimizer/anf_visitor.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "ir/func_graph_cloner.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace irpass {
|
||||
// {{GradOperation, g, w}, Ys}
|
||||
// {UnPackCall, {GradOperation, g, w}, Ys}
|
||||
class GradVarPrepare : public AnfVisitor {
|
||||
public:
|
||||
GradVarPrepare()
|
||||
: grad_op_(std::make_shared<prim::GradOperation>("grad")),
|
||||
unpack_op_(std::make_shared<prim::UnpackCall>("unpack_call")) {}
|
||||
~GradVarPrepare() override = default;
|
||||
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
|
||||
|
||||
private:
|
||||
MetaFuncGraphPtr grad_op_;
|
||||
MetaFuncGraphPtr unpack_op_;
|
||||
};
|
||||
} // namespace irpass
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_GRAD_VAR_PREPARE_H_
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -14,7 +14,7 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "frontend/optimizer/irpass/grad_var_prepare.h"
|
||||
#include "frontend/optimizer/irpass/meta_fg_var_prepare.h"
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
|
@ -29,6 +29,15 @@
|
|||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace irpass {
|
||||
// Get meta_fg_ops registration white list
|
||||
static const std::vector<MatcherPtr> &GetMetaFgOps() {
|
||||
static const std::vector<MatcherPtr> meta_fg_ops{
|
||||
std::make_shared<MetaFgMatcher<prim::GradOperation>>(),
|
||||
std::make_shared<MetaFgMatcher<prim::VmapOperation>>(),
|
||||
};
|
||||
return meta_fg_ops;
|
||||
}
|
||||
|
||||
static AnfNodePtr GenerateUnpackGraphNode(const AnfNodePtr &origin_node, std::vector<AnfNodePtr> inputs_y,
|
||||
const AnfNodePtr &func_node, bool is_unpack, bool sens_param) {
|
||||
MS_EXCEPTION_IF_NULL(func_node);
|
||||
|
@ -40,7 +49,7 @@ static AnfNodePtr GenerateUnpackGraphNode(const AnfNodePtr &origin_node, std::ve
|
|||
auto unpack_graph = std::make_shared<prim::UnpackGraphPrimitive>(sens_param, true);
|
||||
nodes.push_back(NewValueNode(unpack_graph));
|
||||
nodes.push_back(func_node);
|
||||
// {unpackcall, {GradOperation, ...}, args...}
|
||||
// {unpackcall, {GradOperation, ...}, args...} and other {unpackcall, {meta_fg_opration, ...}, args...}
|
||||
const size_t inputs_begin_index = 2;
|
||||
(void)std::transform(inputs_y.begin() + inputs_begin_index, inputs_y.end(), std::back_inserter(nodes),
|
||||
[](const AnfNodePtr &node) { return node; });
|
||||
|
@ -49,7 +58,7 @@ static AnfNodePtr GenerateUnpackGraphNode(const AnfNodePtr &origin_node, std::ve
|
|||
auto unpack_graph = std::make_shared<prim::UnpackGraphPrimitive>(sens_param, false);
|
||||
nodes.push_back(NewValueNode(unpack_graph));
|
||||
nodes.push_back(func_node);
|
||||
// {{GradOperation, ...}, args...}
|
||||
// {{GradOperation, ...}, args...} and other {{meta_fg_opration, ...}, args...}
|
||||
const size_t inputs_begin_index = 1;
|
||||
(void)std::transform(inputs_y.begin() + inputs_begin_index, inputs_y.end(), std::back_inserter(nodes),
|
||||
[](const AnfNodePtr &node) { return node; });
|
||||
|
@ -58,7 +67,6 @@ static AnfNodePtr GenerateUnpackGraphNode(const AnfNodePtr &origin_node, std::ve
|
|||
return unpack_graph_node;
|
||||
}
|
||||
|
||||
// get metagraph of value node
|
||||
MetaFuncGraphPtr GetMetaFuncGraphOfValueNode(const AnfNodePtr &node) {
|
||||
ValuePtr value;
|
||||
if (IsValueNode<prim::DoSignaturePrimitive>(node)) {
|
||||
|
@ -72,24 +80,29 @@ MetaFuncGraphPtr GetMetaFuncGraphOfValueNode(const AnfNodePtr &node) {
|
|||
return value->cast<MetaFuncGraphPtr>();
|
||||
}
|
||||
|
||||
// check if node is a specific metafuncgraph op
|
||||
bool IsMetaFuncGraph(const AnfNodePtr &node, const MetaFuncGraphPtr meta_func_graph) {
|
||||
if (node != nullptr) {
|
||||
auto meta_func_graph_ptr = GetMetaFuncGraphOfValueNode(node);
|
||||
if (meta_func_graph_ptr == nullptr) {
|
||||
return false;
|
||||
}
|
||||
// check if node is a specific meta_fg_opration that registered in the meta_fg_ops
|
||||
bool CheckMetaFgOps(const AnfNodePtr &node) {
|
||||
if (node == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (meta_func_graph_ptr->type_name() == meta_func_graph->type_name()) {
|
||||
auto meta_func_graph_ptr = GetMetaFuncGraphOfValueNode(node);
|
||||
if (meta_func_graph_ptr == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto &meta_fg_ops = GetMetaFgOps();
|
||||
for (auto meta_fg_op : meta_fg_ops) {
|
||||
if (meta_fg_op->Match(meta_func_graph_ptr)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// {{GradOperation, g, w}, Ys}
|
||||
// {UnPackCall, {GradOperation, g, w}, Ys}
|
||||
AnfNodePtr GradVarPrepare::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
|
||||
// {{GradOperation, g, w}, Ys}, {UnPackCall, {GradOperation, g, w}, Ys},
|
||||
// and other {{meta_fg_opration, ...}, ...} or {UnPackCall, {meta_fg_opration, ...}, ...}
|
||||
AnfNodePtr MetaFgVarPrepare::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (cnode == nullptr) {
|
||||
return nullptr;
|
||||
|
@ -104,41 +117,47 @@ AnfNodePtr GradVarPrepare::operator()(const OptimizerPtr &, const AnfNodePtr &no
|
|||
std::vector<AnfNodePtr> inputs_x;
|
||||
if (IsCNode(inputs_y[0])) {
|
||||
inputs_x = inputs_y[0]->cast<CNodePtr>()->inputs();
|
||||
} else if (IsMetaFuncGraph(inputs_y[0], unpack_op_) && IsCNode(inputs_y[1])) {
|
||||
} else if (unpack_op_->Match(inputs_y[0]) && IsCNode(inputs_y[1])) {
|
||||
inputs_x = inputs_y[1]->cast<CNodePtr>()->inputs();
|
||||
} else {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// {{...}, Xs}
|
||||
if (inputs_x.size() < 2) {
|
||||
const size_t inputs_x_minimum_size = 2;
|
||||
if (inputs_x.size() < inputs_x_minimum_size) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// {GradOperation, g, w} or {GradOperation, g}
|
||||
if (!IsMetaFuncGraph(inputs_x[0], grad_op_)) {
|
||||
if (!CheckMetaFgOps(inputs_x[0])) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto meta_func = GetMetaFuncGraphOfValueNode(inputs_x[0]);
|
||||
if (meta_func == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
auto grad_op_ptr = meta_func->cast<prim::GradOperationPtr>();
|
||||
auto func_node = inputs_x[1];
|
||||
if (!IsValueNode<FuncGraph>(func_node)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const bool is_unpack = IsMetaFuncGraph(inputs_y[0], unpack_op_);
|
||||
const bool sens_param = grad_op_ptr->sens_param();
|
||||
const bool is_unpack = unpack_op_->Match(inputs_y[0]);
|
||||
|
||||
// For general meta_fg_opration, ‘sens_param’ is not involved, and that of GradOperation obtained specifically.
|
||||
bool sens_param = false;
|
||||
if (grad_op_->Match(inputs_x[0])) {
|
||||
auto meta_func = GetMetaFuncGraphOfValueNode(inputs_x[0]);
|
||||
if (meta_func == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
auto grad_op_ptr = meta_func->cast<prim::GradOperationPtr>();
|
||||
sens_param = grad_op_ptr->sens_param();
|
||||
}
|
||||
|
||||
inputs_x[1] = GenerateUnpackGraphNode(node, inputs_y, func_node, is_unpack, sens_param);
|
||||
// construct new grad_opration
|
||||
auto grad_op_cnode = func_graph->NewCNodeBefore(node, inputs_x);
|
||||
if (IsMetaFuncGraph(inputs_y[0], unpack_op_)) {
|
||||
inputs_y[1] = grad_op_cnode;
|
||||
// construct new meta_fg_opration
|
||||
auto meta_fg_op_cnode = func_graph->NewCNodeBefore(node, inputs_x);
|
||||
if (unpack_op_->Match(inputs_y[0])) {
|
||||
inputs_y[1] = meta_fg_op_cnode;
|
||||
} else {
|
||||
inputs_y[0] = grad_op_cnode;
|
||||
inputs_y[0] = meta_fg_op_cnode;
|
||||
}
|
||||
return func_graph->NewCNodeBefore(node, inputs_y);
|
||||
}
|
|
@ -0,0 +1,91 @@
|
|||
/**
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_META_FG_VAR_PREPARE_H_
|
||||
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_META_FG_VAR_PREPARE_H_
|
||||
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
|
||||
#include "utils/hash_map.h"
|
||||
#include "frontend/operator/composite/composite.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
#include "frontend/optimizer/irpass.h"
|
||||
#include "frontend/optimizer/optimizer.h"
|
||||
#include "frontend/optimizer/anf_visitor.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "ir/func_graph_cloner.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace irpass {
|
||||
// get metagraph of value node
|
||||
MetaFuncGraphPtr GetMetaFuncGraphOfValueNode(const AnfNodePtr &node);
|
||||
|
||||
class Matcher {
|
||||
public:
|
||||
Matcher() {}
|
||||
virtual ~Matcher() = default;
|
||||
|
||||
virtual bool Match(const MetaFuncGraphPtr &meta_fg_ptr) const = 0;
|
||||
virtual bool Match(const AnfNodePtr &node) const = 0;
|
||||
};
|
||||
using MatcherPtr = std::shared_ptr<Matcher>;
|
||||
|
||||
// MetaFgMatcher is used to check whether the object is a specific meta_fg_opration
|
||||
template <typename T>
|
||||
class MetaFgMatcher : public Matcher {
|
||||
public:
|
||||
MetaFgMatcher() {}
|
||||
~MetaFgMatcher() override = default;
|
||||
|
||||
bool Match(const MetaFuncGraphPtr &meta_fg_ptr) const override { return meta_fg_ptr->isa<T>(); }
|
||||
|
||||
bool Match(const AnfNodePtr &node) const override {
|
||||
if (node == nullptr) {
|
||||
return false;
|
||||
}
|
||||
auto meta_func_graph_ptr = GetMetaFuncGraphOfValueNode(node);
|
||||
if (meta_func_graph_ptr == nullptr) {
|
||||
return false;
|
||||
}
|
||||
return meta_func_graph_ptr->isa<T>();
|
||||
}
|
||||
};
|
||||
|
||||
// Complete the preparation of MetaFuncGraph's variables.
|
||||
// 1) Handle the varying number of arguments of the MetaFuncGraph.
|
||||
// eg.grad(fn)(*args) or vmap(fn)(*args), where fn(*args).
|
||||
// 2) Handle the case of the sens_param of GradOperation customized by users.
|
||||
// eg.GradOperation(sens_param = True)(net)(*real_inputs, sense_para_inputs)
|
||||
class MetaFgVarPrepare : public AnfVisitor {
|
||||
public:
|
||||
MetaFgVarPrepare()
|
||||
: grad_op_(std::make_shared<MetaFgMatcher<prim::GradOperation>>()),
|
||||
unpack_op_(std::make_shared<MetaFgMatcher<prim::UnpackCall>>()) {}
|
||||
~MetaFgVarPrepare() override = default;
|
||||
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
|
||||
|
||||
private:
|
||||
MatcherPtr grad_op_;
|
||||
MatcherPtr unpack_op_;
|
||||
};
|
||||
} // namespace irpass
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_META_FG_VAR_PREPARE_H_
|
|
@ -542,8 +542,8 @@ OptPassGroupMap GetOptPynativeGradEpiloguePhases(const opt::irpass::OptimizeIRPa
|
|||
|
||||
OptPassGroupMap GetInferenceOptPreparePhases() {
|
||||
opt::irpass::InferenceOptPrepareLib irpass;
|
||||
auto grad_var_prepare = opt::OptPassConfig({irpass.grad_var_prepare_});
|
||||
opt::OptPassGroupMap prepare_map({{"inference_opt_prep", grad_var_prepare}});
|
||||
auto meta_fg_var_prepare = opt::OptPassConfig({irpass.meta_fg_var_prepare_});
|
||||
opt::OptPassGroupMap prepare_map({{"inference_opt_prep", meta_fg_var_prepare}});
|
||||
return prepare_map;
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue