grad op for pynative

This commit is contained in:
zhousiyi 2021-02-20 01:06:36 +00:00 committed by chujinjin
parent f9c4cd0cba
commit 0d6a5ccfe4
7 changed files with 334 additions and 2 deletions

View File

@ -64,6 +64,8 @@ AnfNodePtr Adjoint::primal() { return primal_; }
AnfNodePtr Adjoint::dout() { return dout_hole_; }
AnfNodePtr Adjoint::RealDout() { return dout_; }
// Records a cnode that consumes this adjoint's dout (and the input index at
// which it consumes it), so the dout hole can later be patched at every use site.
void Adjoint::RegisterDoutUser(const CNodePtr &user, size_t index) {
  // Construct the pair in place; wrapping the args in std::make_pair builds a
  // temporary and defeats the point of emplace_back.
  dout_user_.emplace_back(user, index);
}

View File

@ -35,6 +35,7 @@ class Adjoint {
void UpdateK(const AnfNodePtr &k);
void RegisterKUser(const CNodePtr &user, size_t index);
AnfNodePtr dout();
AnfNodePtr RealDout();
void AccumulateDout(const AnfNodePtr &dout_factor);
void RegisterDoutUser(const CNodePtr &user, size_t index);
void CallDoutHole();

View File

@ -147,9 +147,9 @@ class KPrim {
bprop_registry_meta_.clear();
bprop_registry_.clear();
}
FuncGraphPtr GetBprop(const PrimitivePtr &prim);
private:
FuncGraphPtr GetBprop(const PrimitivePtr &prim);
FuncGraphPtr GetFprop(const PrimitivePtr &prim);
FuncGraphPtr FakeBprop(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources);
FuncGraphPtr BpropCut(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources);

View File

@ -38,7 +38,6 @@
namespace mindspore {
namespace ad {
using PatternListType = std::initializer_list<BaseRef>;
KPrim g_k_prims;
FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim) {

View File

@ -0,0 +1,196 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include <string>
#include <utility>
#include "ir/anf.h"
#include "frontend/optimizer/ad/adjoint.h"
#include "frontend/optimizer/ad/dfunctor.h"
#include "frontend/optimizer/ad/kpynative.h"
#include "frontend/operator/ops.h"
#include "utils/symbolic.h"
#include "utils/primitive_utils.h"
#include "utils/ms_context.h"
#include "utils/info.h"
#include "debug/trace.h"
namespace mindspore {
namespace ad {
extern KPrim g_k_prims;
// Concrete recorder behind the KPynativeCell handle: accumulates the reverse
// graph ("tape") op by op as the cell executes in pynative mode.
class KPynativeCellImpl : public KPynativeCell {
 public:
  // One tape parameter is added per cell input; weight and sens parameters are
  // appended later in Finish().
  explicit KPynativeCellImpl(const AnfNodePtrList &cell_inputs) : cell_inputs_(cell_inputs) {
    tape_ = std::make_shared<FuncGraph>();
    for (size_t i = 0; i < cell_inputs.size(); ++i) {
      tape_->add_parameter();
    }
  }
  ~KPynativeCellImpl() override = default;
  // Records one executed op (cnode with its runtime arg values and output value)
  // and immediately wires its bprop application into the tape.
  bool KPynativeOp(const CNodePtr &c_node, const ValuePtrList &op_args, const ValuePtr &out);
  // Seals the tape: adds weight/sens parameters, resolves all dout holes and
  // returns the finished gradient FuncGraph.
  FuncGraphPtr Finish(const AnfNodePtrList &weights, bool grad_inputs, bool grad_weights);

 private:
  // The reverse graph being built.
  FuncGraphPtr tape_;
  // Adjoint (dout bookkeeping) for every node seen so far.
  std::unordered_map<AnfNodePtr, AdjointPtr> anfnode_to_adjoin_;
  AnfNodePtrList cell_inputs_;
  // Last cnode of this Cell, may be a primitive op or cell with user defined bprop.
  AnfNodePtr last_node_;
  // Distributes the outputs of bprop_app to the adjoints of cnode_primal's inputs.
  bool BackPropagate(const CNodePtr &cnode_primal, const CNodePtr &bprop_app);
};
using KPynativeCellImplPtr = std::shared_ptr<KPynativeCellImpl>;
using KPynativeCellImplPtr = std::shared_ptr<KPynativeCellImpl>;
// Opens a pynative gradient recording session for a cell with the given inputs.
KPynativeCellPtr GradPynativeCellBegin(const AnfNodePtrList &cell_inputs) {
  auto recorder = std::make_shared<KPynativeCellImpl>(cell_inputs);
  return recorder;
}
// Seals the recording session and returns the gradient FuncGraph.
// weights: weight parameters whose gradients may be requested;
// grad_inputs/grad_weights: which gradient groups to include in the output tuple.
FuncGraphPtr GradPynativeCellEnd(const KPynativeCellPtr &k_cell, const AnfNodePtrList &weights, bool grad_inputs,
                                 bool grad_weights) {
  auto k_cell_impl = std::dynamic_pointer_cast<KPynativeCellImpl>(k_cell);
  // dynamic_pointer_cast yields nullptr for a foreign KPynativeCell subclass;
  // guard before dereferencing.
  MS_EXCEPTION_IF_NULL(k_cell_impl);
  return k_cell_impl->Finish(weights, grad_inputs, grad_weights);
}
// Seals the tape: appends weight and sens parameters, resolves every adjoint's
// dout hole, builds the output tuple of requested gradients (inputs first, then
// weights) and rebinds the primal nodes to the tape's own parameters.
FuncGraphPtr KPynativeCellImpl::Finish(const AnfNodePtrList &weights, bool grad_inputs, bool grad_weights) {
  // Weight parameters follow the cell-input parameters added in the constructor.
  for (size_t i = 0; i < weights.size(); ++i) {
    tape_->add_parameter();
  }
  // The trailing parameter is the sensitivity (initial gradient) of the output.
  auto sens_param = tape_->add_parameter();
  auto last_node_adjoint_iter = anfnode_to_adjoin_.find(last_node_);
  if (last_node_adjoint_iter == anfnode_to_adjoin_.end()) {
    MS_LOG(EXCEPTION) << "BackPropagate adjoint does not exist for input: " << last_node_->ToString();
  }
  // Set dout of last node to sens.
  last_node_adjoint_iter->second->AccumulateDout(sens_param);
  // Replace dout hole of all adjoints with their accumulated dout.
  // Iterate by const reference: copying the map entry would copy two shared_ptrs.
  for (const auto &adjoint_iter : anfnode_to_adjoin_) {
    adjoint_iter.second->CallDoutHole();
  }
  // Return the gradients as a tuple.
  AnfNodePtrList node_list{NewValueNode(prim::kPrimMakeTuple)};
  if (grad_inputs) {
    for (const auto &input : cell_inputs_) {
      auto input_adjoint_iter = anfnode_to_adjoin_.find(input);
      if (input_adjoint_iter == anfnode_to_adjoin_.end()) {
        MS_LOG(EXCEPTION) << "BackPropagate adjoint does not exist for input: " << input->ToString();
      }
      node_list.push_back(input_adjoint_iter->second->RealDout());
    }
  }
  if (grad_weights) {
    for (const auto &weight : weights) {
      auto input_adjoint_iter = anfnode_to_adjoin_.find(weight);
      if (input_adjoint_iter == anfnode_to_adjoin_.end()) {
        MS_LOG(EXCEPTION) << "BackPropagate adjoint does not exist for input: " << weight->ToString();
      }
      node_list.push_back(input_adjoint_iter->second->RealDout());
    }
  }
  auto tape_output = tape_->NewCNode(node_list);
  tape_->set_output(tape_output);
  // Replace the primal AnfNodes with the corresponding parameters of tape_.
  auto mng = MakeManager({tape_}, false);
  MS_EXCEPTION_IF_NULL(mng);
  auto tr = mng->Transact();
  const auto &parameters = tape_->parameters();
  for (size_t i = 0; i < cell_inputs_.size(); ++i) {
    tr.Replace(cell_inputs_[i], parameters[i]);
  }
  for (size_t i = 0; i < weights.size(); ++i) {
    tr.Replace(weights[i], parameters[cell_inputs_.size() + i]);
  }
  tr.Commit();
  return tape_;
}
bool GradPynativeOp(const KPynativeCellPtr &k_cell, const CNodePtr &c_node, const ValuePtrList &op_args,
const ValuePtr &out) {
auto k_cell_impl = std::dynamic_pointer_cast<KPynativeCellImpl>(k_cell);
return k_cell_impl->KPynativeOp(c_node, op_args, out);
}
bool KPynativeCellImpl::BackPropagate(const CNodePtr &cnode_primal, const CNodePtr &bprop_app) {
for (size_t i = 1; i < cnode_primal->size(); i++) {
auto din = tape_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), bprop_app, NewValueNode(SizeToLong(i - 1))});
auto input = cnode_primal->input(i);
// Backprop sens wrt inputs.
auto input_adjoint_iter = anfnode_to_adjoin_.find(input);
if (input_adjoint_iter == anfnode_to_adjoin_.end()) {
MS_LOG(EXCEPTION) << "BackPropagate adjoint does not exist input[" << i << "] " << input->ToString() << ".";
}
input_adjoint_iter->second->AccumulateDout(din);
}
return true;
}
bool KPynativeCellImpl::KPynativeOp(const CNodePtr &c_node, const ValuePtrList &op_args, const ValuePtr &out) {
MS_EXCEPTION_IF_NULL(c_node);
auto prim = GetCNodePrimitive(c_node);
if (prim == nullptr) {
MS_LOG(EXCEPTION) << "should be primitive, but: " << c_node->DebugString();
}
auto anfnode_adjoint_iter = anfnode_to_adjoin_.find(c_node);
if (anfnode_adjoint_iter != anfnode_to_adjoin_.end()) {
MS_LOG(EXCEPTION) << "CNode should be unique, but: " << c_node->DebugString();
}
// Book-keeping last cnode, as dout of this node will be given from outside;
last_node_ = c_node;
auto cnode_adjoint = std::make_shared<Adjoint>(c_node, NewValueNode(out), tape_);
anfnode_to_adjoin_.emplace(c_node, cnode_adjoint);
auto bprop_fg = g_k_prims.GetBprop(prim);
MS_EXCEPTION_IF_NULL(bprop_fg);
// Optimize the bprop_fg based on value.
auto optimized_bprop_fg = OptimizeBPropFuncGraph(bprop_fg, c_node, op_args, out);
AnfNodePtrList node_list{NewValueNode(optimized_bprop_fg)};
for (size_t i = 1; i < c_node->inputs().size(); ++i) {
auto inp_i = c_node->input(i);
auto anfnode_adjoint_iter = anfnode_to_adjoin_.find(inp_i);
if (anfnode_adjoint_iter == anfnode_to_adjoin_.end()) {
if (inp_i->isa<CNode>()) {
MS_LOG(EXCEPTION) << "cannot find adjoint for anfnode: " << inp_i->DebugString();
} else {
auto inp_i_adjoint = std::make_shared<Adjoint>(inp_i, NewValueNode(op_args[i - 1]), tape_);
anfnode_to_adjoin_.emplace(inp_i, inp_i_adjoint);
}
}
node_list.push_back(inp_i);
}
node_list.push_back(NewValueNode(out));
node_list.push_back(cnode_adjoint->dout());
auto bprop_app = tape_->NewCNode(node_list);
cnode_adjoint->RegisterDoutUser(bprop_app, node_list.size() - 1);
BackPropagate(c_node, bprop_app);
return true;
}
// Specializes the generic bprop graph for this particular call site.
// NOTE(review): currently a pass-through placeholder — c_node, op_args and out
// are not yet used; real value-based optimization is TODO.
FuncGraphPtr OptimizeBPropFuncGraph(const FuncGraphPtr &bprop_fg, const CNodePtr &c_node, const ValuePtrList &op_args,
                                    const ValuePtr &out) {
  return bprop_fg;
}
} // namespace ad
} // namespace mindspore

View File

@ -0,0 +1,45 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_KPYNATIVE_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_KPYNATIVE_H_
#include <memory>
#include "ir/anf.h"
#include "ir/func_graph.h"
namespace mindspore {
namespace ad {
// Opaque handle for a pynative-mode gradient recording session; the concrete
// recorder is defined in the corresponding .cc file.
class KPynativeCell {
 public:
  virtual ~KPynativeCell() = default;
};
using KPynativeCellPtr = std::shared_ptr<KPynativeCell>;
// Specializes a primitive's bprop graph for a concrete call site (cnode, its
// runtime argument values and output value).
FuncGraphPtr OptimizeBPropFuncGraph(const FuncGraphPtr &bprop_fg, const CNodePtr &c_node, const ValuePtrList &op_args,
                                    const ValuePtr &out);
// Starts recording gradients for a cell whose inputs are cell_inputs.
KPynativeCellPtr GradPynativeCellBegin(const AnfNodePtrList &cell_inputs);
// Stops recording and returns the gradient FuncGraph; grad_inputs/grad_weights
// select which gradient groups appear in the output tuple.
FuncGraphPtr GradPynativeCellEnd(const KPynativeCellPtr &k_cell, const AnfNodePtrList &weights, bool grad_inputs,
                                 bool grad_weights);
// Records one executed op (primal cnode, runtime input values, output value).
bool GradPynativeOp(const KPynativeCellPtr &k_cell, const CNodePtr &c_node, const ValuePtrList &op_args,
                    const ValuePtr &out);
} // namespace ad
} // namespace mindspore
#endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_KPYNATIVE_H_

View File

@ -0,0 +1,89 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include <unordered_map>
#include "frontend/optimizer/ad/kpynative.h"
#include "common/common_test.h"
#include "common/py_func_graph_fetcher.h"
#include "ir/manager.h"
#include "ir/value.h"
#include "ir/func_graph_cloner.h"
#include "utils/log_adapter.h"
#include "ir/graph_utils.h"
#include "pipeline/jit/resource.h"
#include "pipeline/jit/parse/parse.h"
#include "debug/anf_ir_utils.h"
#include "frontend/operator/ops.h"
namespace mindspore {
namespace ad {
// Fixture for exercising the pynative grad recorder: builds a small primal
// graph and replays its cnodes through GradPynativeOp to derive a bprop graph.
class TestKPynative : public UT::Common {
 public:
  pipeline::ResourcePtr resource = std::make_shared<pipeline::Resource>();

 protected:
  // Returns the abstract of a 2x2 float32 tensor, used for all args/outputs.
  AbstractBasePtr BuildArg() {
    std::vector<int64_t> shp = {2, 2};
    tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(kFloat32->type_id(), shp);
    auto abstract = tensor->ToAbstract();
    return abstract;
  }
  // Builds the primal graph g(x, y) = tensor_mul(x, y).
  // (Reference placement normalized to the file's `const T &` convention.)
  FuncGraphPtr BuildPrimalFuncGraph(const std::string &testCase) {
    auto g = std::make_shared<FuncGraph>();
    auto x = g->add_parameter();
    auto y = g->add_parameter();
    x->set_abstract(BuildArg());
    y->set_abstract(BuildArg());
    auto c_node = g->NewCNode({NewValueNode(prim::GetPythonOps("tensor_mul", "mindspore.ops.functional")), x, y});
    c_node->set_abstract(BuildArg());
    g->set_output(c_node);
    return g;
  }
  // Replays the primal graph's cnodes in topological order through the grad
  // recorder and returns the resulting bprop graph (input grads only).
  FuncGraphPtr BuildBpropFuncGraph(const FuncGraphPtr &primal_fg) {
    auto k_pynative_cell = GradPynativeCellBegin(primal_fg->parameters());
    auto node_list = TopoSort(primal_fg->output());
    // Iterate by const reference to avoid copying one AnfNodePtr per node.
    for (const auto &node : node_list) {
      if (node->isa<CNode>()) {
        auto c_node = node->cast<CNodePtr>();
        auto out = c_node->abstract()->GetValueTrack();
        ValuePtrList args;
        for (size_t i = 1; i < c_node->inputs().size(); ++i) {
          args.push_back(c_node->input(i)->abstract()->GetValueTrack());
        }
        // Fail fast if recording an op is rejected instead of ignoring it.
        ASSERT_TRUE(GradPynativeOp(k_pynative_cell, c_node, args, out));
      }
    }
    auto bprop_fg = GradPynativeCellEnd(k_pynative_cell, AnfNodePtrList{}, true, false);
    return bprop_fg;
  }
};
// Builds the primal graph, derives its bprop graph through the pynative grad
// recorder, and exports both IRs (named after each graph) for inspection.
TEST_F(TestKPynative, test_simple_add) {
  auto primal_fg = BuildPrimalFuncGraph("test_simple_add");
  resource->manager()->KeepRoots({primal_fg});
  ExportIR(primal_fg->ToString() + ".dat", "", primal_fg);
  auto bprop_fg = BuildBpropFuncGraph(primal_fg);
  resource->manager()->KeepRoots({bprop_fg});
  ExportIR(bprop_fg->ToString() + ".dat", "", bprop_fg);
}
} // namespace ad
} // namespace mindspore