vmap unpack graph

Erpim 2022-04-21 16:45:58 +08:00
parent 90accbb2b8
commit 04fe405289
6 changed files with 147 additions and 91 deletions

View File

@@ -20,7 +20,7 @@
#include "frontend/optimizer/irpass/cast_eliminate.h"
#include "frontend/optimizer/irpass/convert.h"
#include "frontend/optimizer/irpass/environ_eliminate.h"
#include "frontend/optimizer/irpass/grad_var_prepare.h"
#include "frontend/optimizer/irpass/meta_fg_var_prepare.h"
#include "frontend/optimizer/irpass/taylor_eliminate.h"
#include "frontend/optimizer/irpass/inline.h"
#include "frontend/optimizer/irpass/updatestate_eliminate.h"
@@ -288,7 +288,7 @@ ResolveIRPassLib::ResolveIRPassLib() {
}
InferenceOptPrepareLib::InferenceOptPrepareLib() {
grad_var_prepare_ = MakeSubstitution(std::make_shared<GradVarPrepare>(), "grad_var_prepare", IsCNode);
meta_fg_var_prepare_ = MakeSubstitution(std::make_shared<MetaFgVarPrepare>(), "meta_fg_var_prepare", IsCNode);
}
} // namespace irpass
} // namespace opt

View File

@@ -184,7 +184,7 @@ class InferenceOptPrepareLib {
public:
InferenceOptPrepareLib();
~InferenceOptPrepareLib() = default;
SubstitutionPtr grad_var_prepare_;
SubstitutionPtr meta_fg_var_prepare_;
};
// predicate functions

View File

@@ -1,54 +0,0 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_GRAD_VAR_PREPARE_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_GRAD_VAR_PREPARE_H_
#include <vector>
#include <algorithm>
#include <memory>
#include "utils/hash_map.h"
#include "frontend/operator/composite/composite.h"
#include "frontend/operator/ops.h"
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/anf_visitor.h"
#include "ir/func_graph.h"
#include "ir/func_graph_cloner.h"
namespace mindspore {
namespace opt {
namespace irpass {
// {{GradOperation, g, w}, Ys}
// {UnPackCall, {GradOperation, g, w}, Ys}
class GradVarPrepare : public AnfVisitor {
public:
GradVarPrepare()
: grad_op_(std::make_shared<prim::GradOperation>("grad")),
unpack_op_(std::make_shared<prim::UnpackCall>("unpack_call")) {}
~GradVarPrepare() override = default;
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
private:
MetaFuncGraphPtr grad_op_;
MetaFuncGraphPtr unpack_op_;
};
} // namespace irpass
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_GRAD_VAR_PREPARE_H_

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -14,7 +14,7 @@
* limitations under the License.
*/
#include "frontend/optimizer/irpass/grad_var_prepare.h"
#include "frontend/optimizer/irpass/meta_fg_var_prepare.h"
#include <vector>
#include <algorithm>
#include <memory>
@@ -29,6 +29,15 @@
namespace mindspore {
namespace opt {
namespace irpass {
// Get the registration whitelist of meta_fg ops
static const std::vector<MatcherPtr> &GetMetaFgOps() {
static const std::vector<MatcherPtr> meta_fg_ops{
std::make_shared<MetaFgMatcher<prim::GradOperation>>(),
std::make_shared<MetaFgMatcher<prim::VmapOperation>>(),
};
return meta_fg_ops;
}
static AnfNodePtr GenerateUnpackGraphNode(const AnfNodePtr &origin_node, std::vector<AnfNodePtr> inputs_y,
const AnfNodePtr &func_node, bool is_unpack, bool sens_param) {
MS_EXCEPTION_IF_NULL(func_node);
@@ -40,7 +49,7 @@ static AnfNodePtr GenerateUnpackGraphNode(const AnfNodePtr &origin_node, std::ve
auto unpack_graph = std::make_shared<prim::UnpackGraphPrimitive>(sens_param, true);
nodes.push_back(NewValueNode(unpack_graph));
nodes.push_back(func_node);
// {unpackcall, {GradOperation, ...}, args...}
// {unpackcall, {GradOperation, ...}, args...} and other {unpackcall, {meta_fg_operation, ...}, args...}
const size_t inputs_begin_index = 2;
(void)std::transform(inputs_y.begin() + inputs_begin_index, inputs_y.end(), std::back_inserter(nodes),
[](const AnfNodePtr &node) { return node; });
@@ -49,7 +58,7 @@ static AnfNodePtr GenerateUnpackGraphNode(const AnfNodePtr &origin_node, std::ve
auto unpack_graph = std::make_shared<prim::UnpackGraphPrimitive>(sens_param, false);
nodes.push_back(NewValueNode(unpack_graph));
nodes.push_back(func_node);
// {{GradOperation, ...}, args...}
// {{GradOperation, ...}, args...} and other {{meta_fg_operation, ...}, args...}
const size_t inputs_begin_index = 1;
(void)std::transform(inputs_y.begin() + inputs_begin_index, inputs_y.end(), std::back_inserter(nodes),
[](const AnfNodePtr &node) { return node; });
@@ -58,7 +67,6 @@ static AnfNodePtr GenerateUnpackGraphNode(const AnfNodePtr &origin_node, std::ve
return unpack_graph_node;
}
// get metagraph of value node
MetaFuncGraphPtr GetMetaFuncGraphOfValueNode(const AnfNodePtr &node) {
ValuePtr value;
if (IsValueNode<prim::DoSignaturePrimitive>(node)) {
@@ -72,24 +80,29 @@ MetaFuncGraphPtr GetMetaFuncGraphOfValueNode(const AnfNodePtr &node) {
return value->cast<MetaFuncGraphPtr>();
}
// check if node is a specific metafuncgraph op
bool IsMetaFuncGraph(const AnfNodePtr &node, const MetaFuncGraphPtr meta_func_graph) {
if (node != nullptr) {
// check whether the node is a specific meta_fg_operation registered in meta_fg_ops
bool CheckMetaFgOps(const AnfNodePtr &node) {
if (node == nullptr) {
return false;
}
auto meta_func_graph_ptr = GetMetaFuncGraphOfValueNode(node);
if (meta_func_graph_ptr == nullptr) {
return false;
}
if (meta_func_graph_ptr->type_name() == meta_func_graph->type_name()) {
const auto &meta_fg_ops = GetMetaFgOps();
for (const auto &meta_fg_op : meta_fg_ops) {
if (meta_fg_op->Match(meta_func_graph_ptr)) {
return true;
}
}
return false;
}
// {{GradOperation, g, w}, Ys}
// {UnPackCall, {GradOperation, g, w}, Ys}
AnfNodePtr GradVarPrepare::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
// {{GradOperation, g, w}, Ys}, {UnPackCall, {GradOperation, g, w}, Ys},
// and other {{meta_fg_operation, ...}, ...} or {UnPackCall, {meta_fg_operation, ...}, ...}
AnfNodePtr MetaFgVarPrepare::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
auto cnode = node->cast<CNodePtr>();
if (cnode == nullptr) {
return nullptr;
@@ -104,41 +117,47 @@ AnfNodePtr GradVarPrepare::operator()(const OptimizerPtr &, const AnfNodePtr &no
std::vector<AnfNodePtr> inputs_x;
if (IsCNode(inputs_y[0])) {
inputs_x = inputs_y[0]->cast<CNodePtr>()->inputs();
} else if (IsMetaFuncGraph(inputs_y[0], unpack_op_) && IsCNode(inputs_y[1])) {
} else if (unpack_op_->Match(inputs_y[0]) && IsCNode(inputs_y[1])) {
inputs_x = inputs_y[1]->cast<CNodePtr>()->inputs();
} else {
return nullptr;
}
// {{...}, Xs}
if (inputs_x.size() < 2) {
const size_t inputs_x_minimum_size = 2;
if (inputs_x.size() < inputs_x_minimum_size) {
return nullptr;
}
// {GradOperation, g, w} or {GradOperation, g}
if (!IsMetaFuncGraph(inputs_x[0], grad_op_)) {
if (!CheckMetaFgOps(inputs_x[0])) {
return nullptr;
}
auto meta_func = GetMetaFuncGraphOfValueNode(inputs_x[0]);
if (meta_func == nullptr) {
return nullptr;
}
auto grad_op_ptr = meta_func->cast<prim::GradOperationPtr>();
auto func_node = inputs_x[1];
if (!IsValueNode<FuncGraph>(func_node)) {
return nullptr;
}
const bool is_unpack = IsMetaFuncGraph(inputs_y[0], unpack_op_);
const bool sens_param = grad_op_ptr->sens_param();
const bool is_unpack = unpack_op_->Match(inputs_y[0]);
// sens_param is not involved for a general meta_fg_operation; for GradOperation it is obtained from the operation itself.
bool sens_param = false;
if (grad_op_->Match(inputs_x[0])) {
auto meta_func = GetMetaFuncGraphOfValueNode(inputs_x[0]);
if (meta_func == nullptr) {
return nullptr;
}
auto grad_op_ptr = meta_func->cast<prim::GradOperationPtr>();
sens_param = grad_op_ptr->sens_param();
}
inputs_x[1] = GenerateUnpackGraphNode(node, inputs_y, func_node, is_unpack, sens_param);
// construct new grad_opration
auto grad_op_cnode = func_graph->NewCNodeBefore(node, inputs_x);
if (IsMetaFuncGraph(inputs_y[0], unpack_op_)) {
inputs_y[1] = grad_op_cnode;
// construct the new meta_fg_operation
auto meta_fg_op_cnode = func_graph->NewCNodeBefore(node, inputs_x);
if (unpack_op_->Match(inputs_y[0])) {
inputs_y[1] = meta_fg_op_cnode;
} else {
inputs_y[0] = grad_op_cnode;
inputs_y[0] = meta_fg_op_cnode;
}
return func_graph->NewCNodeBefore(node, inputs_y);
}
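
For orientation, the graph-shape effect of the rewritten pass can be sketched with plain strings standing in for ANF nodes. This is a minimal standalone sketch, not the real MindSpore node construction; the names vmap, fn, and UnpackGraph are illustrative only:

```cpp
// Toy illustration (strings standing in for ANF nodes) of the rewrite done by
// MetaFgVarPrepare: the function handed to a meta-graph operation is wrapped
// in an UnpackGraph node that also receives the call arguments, so the graph
// can later be specialized to the actual *args.
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

using Node = std::string;

// Stand-in for building a CNode from an input list: renders "{a, b, c}".
Node MakeCNode(const std::vector<Node> &inputs) {
  std::ostringstream oss;
  oss << "{";
  for (size_t i = 0; i < inputs.size(); ++i) {
    if (i > 0) {
      oss << ", ";
    }
    oss << inputs[i];
  }
  oss << "}";
  return oss.str();
}

int main() {
  // Original pattern {{vmap, fn}, x, y}, i.e. vmap(fn)(x, y).
  std::vector<Node> inputs_x = {"vmap", "fn"};                   // inner meta-fg call
  std::vector<Node> inputs_y = {MakeCNode(inputs_x), "x", "y"};  // outer call

  // Non-UnpackCall branch of GenerateUnpackGraphNode: {UnpackGraph, fn, x, y}.
  std::vector<Node> unpack = {"UnpackGraph", inputs_x[1]};
  unpack.insert(unpack.end(), inputs_y.begin() + 1, inputs_y.end());

  // Replace fn with the UnpackGraph node and rebuild both calls.
  inputs_x[1] = MakeCNode(unpack);
  inputs_y[0] = MakeCNode(inputs_x);

  std::cout << MakeCNode(inputs_y) << std::endl;
  // Prints: {{vmap, {UnpackGraph, fn, x, y}}, x, y}
  return 0;
}
```

In the UnpackCall variant of the pattern, the same wrapping applies, except that the call arguments start at index 2 of inputs_y and the rebuilt meta-graph call is written back to inputs_y[1] instead of inputs_y[0].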

View File

@@ -0,0 +1,91 @@
/**
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_META_FG_VAR_PREPARE_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_META_FG_VAR_PREPARE_H_
#include <vector>
#include <algorithm>
#include <memory>
#include "utils/hash_map.h"
#include "frontend/operator/composite/composite.h"
#include "frontend/operator/ops.h"
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/anf_visitor.h"
#include "ir/func_graph.h"
#include "ir/func_graph_cloner.h"
namespace mindspore {
namespace opt {
namespace irpass {
// get metagraph of value node
MetaFuncGraphPtr GetMetaFuncGraphOfValueNode(const AnfNodePtr &node);
class Matcher {
public:
Matcher() {}
virtual ~Matcher() = default;
virtual bool Match(const MetaFuncGraphPtr &meta_fg_ptr) const = 0;
virtual bool Match(const AnfNodePtr &node) const = 0;
};
using MatcherPtr = std::shared_ptr<Matcher>;
// MetaFgMatcher is used to check whether the object is a specific meta_fg_operation
template <typename T>
class MetaFgMatcher : public Matcher {
public:
MetaFgMatcher() {}
~MetaFgMatcher() override = default;
bool Match(const MetaFuncGraphPtr &meta_fg_ptr) const override { return meta_fg_ptr->isa<T>(); }
bool Match(const AnfNodePtr &node) const override {
if (node == nullptr) {
return false;
}
auto meta_func_graph_ptr = GetMetaFuncGraphOfValueNode(node);
if (meta_func_graph_ptr == nullptr) {
return false;
}
return meta_func_graph_ptr->isa<T>();
}
};
// Complete the preparation of a MetaFuncGraph's variables:
// 1) Handle the varying number of arguments passed to the MetaFuncGraph,
// e.g. grad(fn)(*args) or vmap(fn)(*args), where fn accepts *args.
// 2) Handle the sens_param of GradOperation when it is customized by the user,
// e.g. GradOperation(sens_param=True)(net)(*real_inputs, sens_param_inputs).
class MetaFgVarPrepare : public AnfVisitor {
public:
MetaFgVarPrepare()
: grad_op_(std::make_shared<MetaFgMatcher<prim::GradOperation>>()),
unpack_op_(std::make_shared<MetaFgMatcher<prim::UnpackCall>>()) {}
~MetaFgVarPrepare() override = default;
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
private:
MatcherPtr grad_op_;
MatcherPtr unpack_op_;
};
} // namespace irpass
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_META_FG_VAR_PREPARE_H_
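
The Matcher/MetaFgMatcher pair generalizes the old single-type check (comparing type_name() against one MetaFuncGraph) into a whitelist of recognized meta-graph operations, which is what lets vmap share this pass with grad. Below is a minimal self-contained sketch of that design; the placeholder types and the dynamic_pointer_cast check stand in for the real MindSpore classes and their isa<T>() test:

```cpp
// Standalone sketch of the Matcher/MetaFgMatcher whitelist design.
// MetaFg, GradOperation, and VmapOperation are placeholder types here,
// not the real MindSpore classes.
#include <iostream>
#include <memory>
#include <vector>

struct MetaFg {
  virtual ~MetaFg() = default;
};
struct GradOperation : MetaFg {};
struct VmapOperation : MetaFg {};

using MetaFgPtr = std::shared_ptr<MetaFg>;

class Matcher {
 public:
  virtual ~Matcher() = default;
  virtual bool Match(const MetaFgPtr &meta_fg) const = 0;
};
using MatcherPtr = std::shared_ptr<Matcher>;

// Matches one concrete meta-graph operation by its dynamic type,
// mirroring MetaFgMatcher<T>::Match, which uses isa<T>().
template <typename T>
class MetaFgMatcher : public Matcher {
 public:
  bool Match(const MetaFgPtr &meta_fg) const override {
    return std::dynamic_pointer_cast<T>(meta_fg) != nullptr;
  }
};

// Whitelist of recognized operations, analogous to GetMetaFgOps().
const std::vector<MatcherPtr> &GetMetaFgOps() {
  static const std::vector<MatcherPtr> ops{
      std::make_shared<MetaFgMatcher<GradOperation>>(),
      std::make_shared<MetaFgMatcher<VmapOperation>>(),
  };
  return ops;
}

// Analogous to CheckMetaFgOps(): true if any registered matcher accepts it.
bool CheckMetaFgOps(const MetaFgPtr &meta_fg) {
  if (meta_fg == nullptr) {
    return false;
  }
  for (const auto &matcher : GetMetaFgOps()) {
    if (matcher->Match(meta_fg)) {
      return true;
    }
  }
  return false;
}

int main() {
  std::cout << CheckMetaFgOps(std::make_shared<VmapOperation>()) << std::endl;  // 1
  std::cout << CheckMetaFgOps(std::make_shared<MetaFg>()) << std::endl;         // 0
  return 0;
}
```

Supporting another meta-graph operation then only means adding one MetaFgMatcher entry to the whitelist, with no change to MetaFgVarPrepare itself.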

View File

@@ -542,8 +542,8 @@ OptPassGroupMap GetOptPynativeGradEpiloguePhases(const opt::irpass::OptimizeIRPa
OptPassGroupMap GetInferenceOptPreparePhases() {
opt::irpass::InferenceOptPrepareLib irpass;
auto grad_var_prepare = opt::OptPassConfig({irpass.grad_var_prepare_});
opt::OptPassGroupMap prepare_map({{"inference_opt_prep", grad_var_prepare}});
auto meta_fg_var_prepare = opt::OptPassConfig({irpass.meta_fg_var_prepare_});
opt::OptPassGroupMap prepare_map({{"inference_opt_prep", meta_fg_var_prepare}});
return prepare_map;
}