forked from mindspore-Ecosystem/mindspore
!1160 add backward hook and custom bprop in pynative mode
Merge pull request !1160 from wangqiuliang/add-backward-hook-in-pynative-mode
Commit: 0e665616f0
@@ -102,7 +102,10 @@ def get_parse_method_of_class(obj, parse_method=None):
        method_name = parse_method
    else:
        if isinstance(obj, nn.Cell):
-            method_name = "construct"
+            if obj.enable_hook:
+                method_name = "_hook_construct"
+            else:
+                method_name = "construct"
    if method_name is not None:
        if hasattr(obj, method_name):
            method = getattr(obj, method_name)
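The parser-side switch above pairs with Cell._hook_construct, added later in this change: once a backward hook is registered on a cell, the method that gets parsed is a wrapper that applies the HookBackward operator around the original construct. A simplified sketch of that wrapper, restated here for orientation (the real definition appears in the Cell diff below):

def _hook_construct(self, inputs):
    inputs = self._backward_hook(inputs)   # HookBackward tag before the cell body
    inputs = self.construct(inputs)
    return self._backward_hook(inputs)     # and after it, so backward passes through the hook twice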
@@ -115,6 +115,7 @@ REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) {
    .def("get_attr_dict", &PrimitivePy::GetAttrDict, "get primitive attr")
    .def("set_prim_type", &PrimitivePy::set_prim_type, "Set primitive type.")
    .def("set_signatures", &PrimitivePy::set_signatures, "Set primitive inputs signature.")
    .def("register_hook", &PrimitivePy::set_hook, "Set primitive hook function.")
    .def("set_instance_name", &PrimitivePy::set_instance_name, "Set primitive instance name.");
  }));
}  // namespace mindspore
@@ -23,7 +23,6 @@
#include <string>
#include <tuple>

#include "pybind11/pybind11.h"
#include "pipeline/static_analysis/abstract_value.h"
#include "utils/misc.h"
#include "utils/log_adapter.h"
@@ -31,8 +30,6 @@
#include "ir/signature.h"
#include "parallel/ops_info/operator_info.h"

namespace py = pybind11;

namespace mindspore {
class PrimitivePy : public Primitive {
 public:
@@ -24,6 +24,9 @@
#include <tuple>

#include "ir/dtype/type.h"
#include "pybind11/pybind11.h"

namespace py = pybind11;

namespace mindspore {
// Supported meta type
@@ -73,6 +76,9 @@ class Primitive : public Named {
    return iter == attrs_.cend() ? nullptr : iter->second;
  }

  void set_hook(const py::function &hook) { hook_ = hook; }
  py::function hook() const { return hook_; }

  const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }

  // if Primitive has any attribute, for Primitives like scalar_add, return, etc, don't have any attribute.
@@ -103,6 +109,7 @@ class Primitive : public Named {

 private:
  std::string instance_name_;
  py::function hook_;
  bool is_base_;
  bool has_signature_;
  PrimType prim_type_;
@@ -213,6 +213,7 @@ const PrimitivePtr kPrimRelu = std::make_shared<Primitive>("ReLU");
const PrimitivePtr kPrimReluV2 = std::make_shared<Primitive>("ReLUV2");
const PrimitivePtr kPrimZerosLikeTensor = std::make_shared<Primitive>("zeros_like_tensor");
const PrimitivePtr kPrimFakeBprop = std::make_shared<Primitive>("fake_bprop");
const PrimitivePtr kPrimBpropCut = std::make_shared<Primitive>("bprop_cut");

// Other miscellaneous
const PrimitivePtr kPrimIdentity = std::make_shared<Primitive>("identity");
@@ -226,6 +227,7 @@ const PrimitivePtr kPrimGetRefKey = std::make_shared<Primitive>("get_ref_key");
const PrimitivePtr kPrimGetRefValue = std::make_shared<Primitive>("get_ref_value");
const PrimitivePtr kPrimGetRefOrigin = std::make_shared<Primitive>("get_ref_origin");
const PrimitivePtr kPrimInsertGradientOf = std::make_shared<Primitive>("InsertGradientOf");
const PrimitivePtr kPrimHookBackward = std::make_shared<Primitive>("HookBackward");
const PrimitivePtr kPrimPrintShapeType = std::make_shared<Primitive>("PrintShapeType");
const PrimitivePtr kPrimSameTypeShape = std::make_shared<Primitive>("SameTypeShape");
const PrimitivePtr kPrimCheckBprop = std::make_shared<Primitive>("CheckBprop");
@@ -218,6 +218,7 @@ extern const PrimitivePtr kPrimReluV2;
extern const PrimitivePtr kPrimActivation;
extern const PrimitivePtr kPrimZerosLikeTensor;
extern const PrimitivePtr kPrimFakeBprop;
extern const PrimitivePtr kPrimBpropCut;

// Other Miscellaneous
extern const PrimitivePtr kPrimIdentity;
@@ -232,6 +233,7 @@ extern const PrimitivePtr kPrimGetRefKey;
extern const PrimitivePtr kPrimGetRefValue;
extern const PrimitivePtr kPrimGetRefOrigin;
extern const PrimitivePtr kPrimInsertGradientOf;
extern const PrimitivePtr kPrimHookBackward;
extern const PrimitivePtr kPrimPrintShapeType;
extern const PrimitivePtr kPrimPrint;
extern const PrimitivePtr kPrimSameTypeShape;
@@ -285,6 +285,16 @@ AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr
  return args_spec_list[0]->Broaden();
}

AbstractBasePtr InferImplBpropCut(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                  const AbstractBasePtrList &args_spec_list) {
  // Inputs: a tensor.
  AbstractBasePtrList args_list;
  for (size_t i = 0; i < args_spec_list.size() - 2; i++) {
    args_list.push_back(args_spec_list[i]->Broaden());
  }
  return std::make_shared<AbstractTuple>(args_list);
}

AbstractBasePtr InferImplLayerNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                   const AbstractBasePtrList &args_spec_list) {
  // Inputs: three tensors(x, gamma, beta).
@@ -32,6 +32,7 @@
#include "operator/ops.h"
#include "operator/composite/composite.h"
#include "utils/symbolic.h"
#include "utils/context/ms_context.h"
#include "./common.h"

namespace mindspore {
@@ -125,6 +125,7 @@ class KPrim {
  FuncGraphPtr GetBprop(const PrimitivePtr &prim);
  FuncGraphPtr GetFprop(const PrimitivePtr &prim);
  FuncGraphPtr FakeBprop(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources);
  FuncGraphPtr BpropCut(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources);
  // Given a bprop rule, do the K mapping.
  template <typename T>
  FuncGraphPtr BpropToK(const T &primal, const FuncGraphPtr &bprop_g);
@@ -115,10 +115,15 @@ FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::R
  }

  bool is_faked_bprop = false;
-  auto bprop_fg = GetBprop(prim);
-  if (bprop_fg == nullptr) {
-    bprop_fg = FakeBprop(value_node, resources);
-    is_faked_bprop = true;
+  FuncGraphPtr bprop_fg = nullptr;
+  if (prim->Hash() == prim::kPrimHookBackward->Hash() && prim->name() == "HookBackward") {
+    bprop_fg = BpropCut(value_node, resources);
+  } else {
+    bprop_fg = GetBprop(prim);
+    if (bprop_fg == nullptr) {
+      bprop_fg = FakeBprop(value_node, resources);
+      is_faked_bprop = true;
+    }
  }

  auto expanded_fg = BpropToK(prim, bprop_fg);
@@ -206,6 +211,45 @@ FuncGraphPtr KPrim::KUserDefinedCellBprop(const FuncGraphPtr bprop_fg) {
  return expanded_fg;
}

FuncGraphPtr KPrim::BpropCut(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) {
  auto prim = GetValueNode<PrimitivePtr>(value_node);
  MS_EXCEPTION_IF_NULL(prim);
  auto &node_users = resources->manager()->node_users();

  auto &users = node_users[value_node];
  auto cnode = std::find_if(users.begin(), users.end(), [&prim](const std::pair<AnfNodePtr, int> &user) -> bool {
    return IsPrimitiveCNode(user.first, prim);
  });
  if (cnode == users.end()) {
    MS_LOG(EXCEPTION) << "Fail to find cnode.";
  }
  auto inputs_num = cnode->first->cast<CNodePtr>()->size() - 1;

  auto func_graph = std::make_shared<FuncGraph>();
  std::vector<AnfNodePtr> outputs;

  auto bprop_cut = std::make_shared<Primitive>("bprop_cut");
  bprop_cut->set_hook(prim->hook());
  auto cell_id = GetValue<std::string>(prim->GetAttr("cell_id"));
  if (cell_id != "") {
    (void)bprop_cut->AddAttr("cell_hook", MakeValue(true));
    (void)bprop_cut->AddAttr("cell_id", MakeValue(cell_id));
  }

  outputs.push_back(NewValueNode(bprop_cut));
  for (size_t i = 0; i < inputs_num; ++i) {
    auto param = func_graph->add_parameter();
    outputs.push_back(param);
  }
  auto p1 = func_graph->add_parameter();
  auto p2 = func_graph->add_parameter();
  outputs.push_back(p1);
  outputs.push_back(p2);

  func_graph->set_output(func_graph->NewCNode(outputs));
  return func_graph;
}

FuncGraphPtr KPrim::FakeBprop(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) {
  auto prim = value_node->value()->cast<PrimitivePtr>();
  MS_EXCEPTION_IF_NULL(prim);
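In effect, KPrim::BpropCut replaces the normal bprop expansion with a one-node graph: the bprop_cut primitive carrying the Python hook, taking one parameter per forward input plus the forward output and the incoming gradient. A rough, hypothetical Python-level picture for a two-input operation (illustrative only; the actual construction is the C++ above):

# Illustrative shape of the generated backward graph for an op with forward inputs (x, y):
def generated_bprop(x, y, out, dout):
    # bprop_cut is not expanded at compile time; FinalVM::RunHook calls the
    # attached Python function when the graph executes
    return bprop_cut(x, y, out, dout)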
@@ -49,9 +49,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
  arithmetic_simplify_ = MakeSubstitution(ArithmeticSimplify(), "arithmetic_simplify",
                                          {prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd,
                                           prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul});
-  special_op_eliminate_ = MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate",
-                                           {prim::kPrimInsertGradientOf, prim::kPrimPrintShapeType,
-                                            prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv});
+  special_op_eliminate_ =
+    MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate",
+                     {prim::kPrimInsertGradientOf, prim::kPrimHookBackward, prim::kPrimPrintShapeType,
+                      prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv});
  zero_like_fill_zero_ = MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLikeTensor);
  adjust_all_reduce_mul_add_ = MakeSubstitution(AdjustAllReduceMulAdd(), "adjust_all_reduce_mul_add", prim::kPrimAddN);
@@ -35,11 +35,13 @@ class SpecialOpEliminater {
 public:
  SpecialOpEliminater()
      : insert_gradient_of_(prim::kPrimInsertGradientOf),
        hook_backward_(prim::kPrimHookBackward),
        print_shape_type_(prim::kPrimPrintShapeType),
        get_ref_value_(prim::kPrimGetRefValue),
        mirror_(prim::kPrimMirror),
        virtual_div_(prim::kPrimVirtualDiv) {
    eliminaters_.emplace_back(insert_gradient_of_);
    eliminaters_.emplace_back(hook_backward_);
    eliminaters_.emplace_back(print_shape_type_);
    eliminaters_.emplace_back(get_ref_value_);
    eliminaters_.emplace_back(mirror_);
@@ -59,7 +61,7 @@ class SpecialOpEliminater {
  }

 private:
-  PrimEliminater insert_gradient_of_, print_shape_type_, get_ref_value_, mirror_, virtual_div_;
+  PrimEliminater insert_gradient_of_, hook_backward_, print_shape_type_, get_ref_value_, mirror_, virtual_div_;
  std::vector<TransformFuncType> eliminaters_{};
};
@@ -30,6 +30,7 @@
#include "operator/composite/composite.h"
#include "ir/func_graph_cloner.h"
#include "utils/symbolic.h"
#include "utils/context/ms_context.h"
#include "debug/trace.h"

namespace mindspore {
@@ -207,6 +208,35 @@ bool ConvertTensor(const py::object &obj, ValuePtr *const data) {
  return true;
}

FuncGraphPtr ConvertToBpropCut(py::object obj) {
  std::vector<std::string> results = data_converter::GetObjKey(obj);
  std::string obj_key = results[0];
  py::function bprop_func = py::getattr(obj, "bprop");

  FuncGraphPtr bprop_graph = std::make_shared<FuncGraph>();
  std::vector<AnfNodePtr> outputs;

  auto fake_bprop = std::make_shared<Primitive>("bprop_cut");
  fake_bprop->set_hook(bprop_func);
  (void)fake_bprop->AddAttr("bprop", MakeValue(true));
  outputs.push_back(NewValueNode(fake_bprop));

  py::object code_obj = py::getattr(bprop_func, "__code__");
  size_t inputs_num = py::cast<int>(py::getattr(code_obj, "co_argcount")) - 3;
  for (size_t i = 0; i < inputs_num; ++i) {
    auto param = bprop_graph->add_parameter();
    outputs.push_back(param);
  }
  auto p1 = bprop_graph->add_parameter();
  auto p2 = bprop_graph->add_parameter();
  outputs.push_back(p1);
  outputs.push_back(p2);

  bprop_graph->set_output(bprop_graph->NewCNode(outputs));
  data_converter::SetObjGraphValue(obj_key, bprop_graph);
  return bprop_graph;
}

bool ConvertOtherObj(py::object obj, ValuePtr *const data) {
  auto obj_type = data_converter::GetObjType(obj);
  MS_LOG(DEBUG) << "Converting the object(" << ((std::string)py::str(obj)) << ") detail type: " << obj_type << " ";
@@ -238,7 +268,13 @@ bool ConvertOtherObj(py::object obj, ValuePtr *const data) {
  }
  // if the cell object has specified bprop, it has user-defined bprop function parse and record it
  if (py::hasattr(obj, "bprop")) {
-    FuncGraphPtr bprop_graph = ConvertToFuncGraph(obj, PYTHON_MOD_GET_BPROP_METHOD);
+    FuncGraphPtr bprop_graph = nullptr;
+    bool enable_bprop_debug = py::cast<bool>(py::getattr(obj, "bprop_debug"));
+    if (enable_bprop_debug) {
+      bprop_graph = ConvertToBpropCut(obj);
+    } else {
+      bprop_graph = ConvertToFuncGraph(obj, PYTHON_MOD_GET_BPROP_METHOD);
+    }
    if (bprop_graph != nullptr) {
      (void)func_graph->transforms().insert(std::make_pair("bprop", FuncGraphTransform(bprop_graph)));
      (void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(func_graph)));
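The co_argcount - 3 computation above assumes a user-defined bprop of the form bprop(self, x1, ..., xn, out, dout): subtracting self, out, and dout leaves one graph parameter per forward input. A cell satisfying that contract (essentially the MulAdd case from the test added at the end of this change):

import mindspore.nn as nn

class MulAdd(nn.Cell):
    def construct(self, x, y):
        return 2 * x + y

    def bprop(self, x, y, out, dout):
        # co_argcount is 5 (self, x, y, out, dout), so inputs_num = 5 - 3 = 2
        return 3 * dout, 2 * y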
@@ -108,6 +108,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
    {prim::kPrimRelu, {InferImplRelu, true}},
    {prim::kPrimZerosLikeTensor, {InferImplZerosLikeTensor, true}},
    {prim::kPrimFakeBprop, {InferImplFakeBprop, false}},
    {prim::kPrimBpropCut, {InferImplBpropCut, true}},
    {prim::kPrimLayerNorm, {InferImplLayerNorm, true}},
    {prim::kPrimLayerNormGrad, {InferImplLayerNormGrad, true}},
    {prim::kPrimDropoutGenMask, {InferImplDropoutGenMask, true}},
@@ -210,6 +210,8 @@ AbstractBasePtr InferImplZerosLikeTensor(const AnalysisEnginePtr &, const Primit
                                         const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                   const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplBpropCut(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                  const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplLayerNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                   const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplLayerNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
@@ -64,6 +64,9 @@ LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst) {
  result.outputs = outputs;
  result.graph_id = kInvalidGraphId;
  auto graph_id = sess_->CompileGraph(lst, outputs);
  if (MsContext::GetInstance()->execution_mode() == kPynativeMode) {
    sess_->BuildGraph(graph_id);
  }
  if (MsContext::GetInstance()->precompile_only()) {
    MS_LOG(INFO) << "PrecompileOnly, stop run graph";
    return result;
@@ -40,9 +40,10 @@ using MapPrimTypeFuncGraph = std::map<PrimTypePair, FuncGraphPtr>;
using TypedPrimitiveAbstractClosurePtr = std::shared_ptr<abstract::TypedPrimitiveAbstractClosure>;

std::vector<PrimitivePtr> nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial, prim::kPrimSwitch,
-                                           prim::kPrimMakeTuple};
+                                           prim::kPrimMakeTuple, prim::kPrimBpropCut};
const std::vector<PrimitivePtr> &GetMsNonlinearOps() {
-  static const std::vector<PrimitivePtr> ms_nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial, prim::kPrimSwitch};
+  static const std::vector<PrimitivePtr> ms_nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial, prim::kPrimSwitch,
+                                                             prim::kPrimBpropCut};
  return ms_nonlinear_ops;
}

@@ -646,8 +647,13 @@ BackendPtr CreateBackend() {
    auto backend = std::make_shared<MsBackend>(name, target, device_id);
    std::string device_target = MsContext::GetInstance()->device_target();
    if (device_target == kAscendDevice) {
-      backend->set_is_multi_graph_sink(true);
-      context_ptr->set_is_multi_graph_sink(true);
+      if (MsContext::GetInstance()->execution_mode() == kPynativeMode) {
+        backend->set_is_multi_graph_sink(false);
+        context_ptr->set_is_multi_graph_sink(false);
+      } else {
+        backend->set_is_multi_graph_sink(true);
+        context_ptr->set_is_multi_graph_sink(true);
+      }
    }
    return backend;
  }
@@ -587,15 +587,65 @@ void FinalVM::InstPushPrim(const VectorRef &args) {
  VectorRef tuple;
  auto prim = utils::cast<PrimitivePtr>(args[0]);
  for (size_t i = 1; i < args.size(); ++i) {
-    auto index = utils::cast<int>(args[1]);
+    auto index = utils::cast<int>(args[i]);
    tuple.push_back(Ref(index));
  }

-  auto outs = RunOperation(prim, tuple);
-  Push(outs);
+  if (prim->name() == "bprop_cut") {
+    auto outs = RunHook(prim, tuple);
+    Push(outs);
+  } else {
+    auto outs = RunOperation(prim, tuple);
+    Push(outs);
+  }

  MS_LOG(DEBUG) << "End";
}

BaseRef FinalVM::RunHook(const PrimitivePtr &prim, const VectorRef &args) {
  py::tuple py_args = py::tuple(args.size());
  MS_LOG(DEBUG) << "input for operation:";
  size_t i = 0;
  for (auto &arg : args) {
    py_args[i] = BaseRefToPyData(arg);
    MS_LOG(DEBUG) << "arg: " << i << ":";
    i++;
  }
  py::object obj;
  bool is_bprop = prim->HasAttr("bprop");
  if (is_bprop) {
    py::function fn_bprop = prim->hook();
    obj = fn_bprop(*py_args);
    return obj;
  }
  bool is_cell = prim->HasAttr("cell_hook");
  if (is_cell) {
    std::string cell_id = GetValue<std::string>(prim->GetAttr("cell_id"));
    if (_hook_grad.find(cell_id) != _hook_grad.end()) {
      py::tuple hook_args = py::tuple(3);
      hook_args[0] = cell_id;
      hook_args[1] = _hook_grad[cell_id];
      hook_args[2] = py_args[2];
      py::function fn_hook = prim->hook();
      obj = fn_hook(*hook_args);
      if (py::isinstance<py::none>(obj)) {
        obj = py_args[2];
      }
      _hook_grad.erase(cell_id);
    } else {
      _hook_grad[cell_id] = py_args[2];
      obj = py_args[2];
    }
  } else {
    py::function fn_hook = prim->hook();
    obj = fn_hook(py_args[2]);
    if (py::isinstance<py::none>(obj)) {
      obj = py_args[2];
    }
  }
  obj = py::make_tuple(obj);
  return obj;
}

}  // namespace compile
}  // namespace mindspore
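RunHook dispatches to three kinds of Python callbacks, selected by the attributes on the bprop_cut primitive: a full custom bprop (attribute "bprop"), a cell-level hook (attribute "cell_hook", where the first gradient seen for a cell_id is cached in _hook_grad and the hook fires when the second arrives), and otherwise a plain variable hook that receives only the gradient. In Python terms the expected callback shapes, as exercised by the tests added at the end of this change, are:

# 1) custom bprop (a Cell defining bprop, with bprop_debug=True): bprop(self, *inputs, out, dout)
# 2) cell-level hook registered with Cell.register_backward_hook:
def cell_hook_function(cell_id, grad_input, grad_output):
    print(cell_id, grad_input, grad_output)

# 3) variable-level hook attached with P.HookBackward(hook_fn):
def var_hook_function(grad_out):
    print("grad:", grad_out)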
@@ -115,6 +115,7 @@ class FinalVM {
  void InstPushPrim(const VectorRef &args);
  void InstSwitchReturn(const VectorRef &args);
  void set_insts(const InstSet &value) { insts_ = value; }
  BaseRef RunHook(const PrimitivePtr &prim, const VectorRef &args);

 protected:
  BaseRef Ref(int i);

@@ -156,6 +157,7 @@ class FinalVM {
    {Instruction::kPrim, [this](const VectorRef &args) { InstPushPrim(args); }},
    {Instruction::kSwitchReturn, [this](const VectorRef &args) { InstSwitchReturn(args); }},
  };
  std::map<std::string, py::object> _hook_grad;
};

using FinalVMPtr = std::shared_ptr<FinalVM>;
@@ -24,6 +24,7 @@ from .._checkparam import _check_str_by_regular
from ..common.parameter import Parameter, ParameterTuple
from .._c_expression import init_backend
from ..ops.primitive import Primitive
from ..ops.operations import HookBackward
from ..parallel._tensor import _load_tensor_by_layout
from ..common.tensor import Tensor
@@ -75,6 +76,9 @@ class Cell:
        self._parallel_inputs_run = None
        if flags:
            self.add_flags(**flags)
        self._backward_hook = None
        self._enable_hook = False
        self._bprop_debug = False

    @property
    def create_time(self):
@@ -91,6 +95,16 @@ class Cell:
        """
        return self._param_prefix

    @property
    def bprop_debug(self):
        return self._bprop_debug

    @bprop_debug.setter
    def bprop_debug(self, value):
        if not isinstance(value, bool):
            raise TypeError("'bprop debug' value must be bool type.")
        self._bprop_debug = value

    def update_cell_prefix(self):
        """
        Update the all child cells' self.param_prefix.
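bprop_debug only matters for cells that define their own bprop: when it is True, the data converter (ConvertToBpropCut above) wires the bprop up as a bprop_cut node so it runs as a Python callback at execution time instead of being parsed into the graph. A minimal sketch, reusing the MulAdd cell from the test at the end of this change (with from mindspore.ops import composite as C):

mul_add = MulAdd()          # a Cell that defines both construct and bprop
mul_add.bprop_debug = True  # run the custom bprop in Python via bprop_cut
grads = C.grad_all(mul_add)(1, 2)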
@@ -728,3 +742,25 @@ class Cell:
        self._auto_parallel_mode = True
        self.add_flags(auto_parallel=True)
        self._get_construct_inputs_number_and_name()

    def _hook_construct(self, inputs):
        """Hook construct method to replace original construct method when hook function enabled."""
        inputs = self._backward_hook(inputs)
        inputs = self.construct(inputs)
        outputs = self._backward_hook(inputs)
        return outputs

    @property
    def enable_hook(self):
        """Whether the cell register hook function"""
        return self._enable_hook

    def register_backward_hook(self, fn):
        """
        Set the cell backward hook function.

        Args:
            fn (function): Specifies the hook function with grad as input.
        """
        self._backward_hook = HookBackward(fn, str(id(self)))
        self._enable_hook = True
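Putting the Python pieces together, a cell-level hook in PyNative mode can be registered as below; this is a minimal sketch mirroring the test added at the end of this change, and the hook signature is (cell_id, grad_input, grad_output):

import mindspore.nn as nn
from mindspore import context

context.set_context(mode=context.PYNATIVE_MODE)

def conv_hook(cell_id, grad_input, grad_output):
    # runs as a Python callback during backward; handy for checking gradient shapes
    print("cell:", cell_id)
    print("grad_input shape:", grad_input.asnumpy().shape)
    print("grad_output shape:", grad_output.asnumpy().shape)

conv = nn.Conv2d(6, 16, kernel_size=5)
conv.register_backward_hook(conv_hook)  # sets enable_hook, so _hook_construct is compiled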
@@ -33,7 +33,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast,
                       _MirrorOperator, ReduceOp, _VirtualDataset,
                       _VirtualDiv, _GetTensorSlice)
-from .debug_ops import (ImageSummary, InsertGradientOf, ScalarSummary,
+from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,
                        TensorSummary, HistogramSummary, Print)
from .control_ops import ControlDepend, GeSwitch, Merge
from .inner_ops import ScalarCast

@@ -155,6 +155,7 @@ __all__ = [
    'HistogramSummary',
    "Print",
    'InsertGradientOf',
    'HookBackward',
    'InvertPermutation',
    'Shape',
    'DropoutDoMask',
@@ -14,6 +14,7 @@
# ============================================================================

"""debug_ops"""
from types import FunctionType
from ..._checkparam import Validator as validator
from ...common import dtype as mstype
from ..primitive import prim_attr_register, PrimitiveWithInfer
@@ -241,6 +242,65 @@ class InsertGradientOf(PrimitiveWithInfer):
        return x_type


class HookBackward(PrimitiveWithInfer):
    """
    Used as tag to hook gradient in intermediate variables.

    Note:
        The hook function should have one input of gradient of the variable.
        hook function will be executed in python environment, while callback
        of InsertGradientOf will be parsed and added to the graph.

    Args:
        hook_fn (Function): Python function. hook function.

    Inputs:
        - **inputs** (Tensor) - The variable to hook.

    Examples:
        >>> def hook_fn(grad_out):
        >>>     print(grad_out)
        >>>
        >>> hook = P.HookBackward(hook_fn)
        >>>
        >>> def hook_test(x, y):
        >>>     z = x * y
        >>>     z = hook(z)
        >>>     z = z * y
        >>>     return z
        >>>
        >>> def backward(x, y):
        >>>     return C.grad_all(hook_test)(x, y)
        >>>
        >>> backward(1, 2)
    """

    def __init__(self, hook_fn, cell_id=""):
        super(HookBackward, self).__init__(self.__class__.__name__)
        self.add_prim_attr("cell_id", cell_id)
        self.init_attrs["cell_id"] = cell_id
        if not isinstance(hook_fn, FunctionType):
            raise TypeError("Hook function should be python function type.")
        self.register_hook(hook_fn)
        self.cell_id = cell_id

    def __call__(self, *inputs):
        """run in PyNative mode."""
        if len(inputs) == 1:
            return inputs[0]
        return inputs

    def infer_shape(self, *inputs_shape):
        if len(inputs_shape) == 1:
            return inputs_shape[0]
        return inputs_shape

    def infer_dtype(self, *inputs_type):
        if len(inputs_type) == 1:
            return inputs_type[0]
        return inputs_type


class Print(PrimitiveWithInfer):
    """
    Output tensor or string to stdout.
@@ -0,0 +1,133 @@
import numpy as np
import mindspore.nn as nn
import mindspore.ops.operations as P
from mindspore import context
from mindspore.ops import composite as C
from mindspore.common import dtype as mstype
from mindspore import context, Tensor, ParameterTuple
from mindspore.common.initializer import TruncatedNormal
from mindspore.nn import Dense, WithLossCell, SoftmaxCrossEntropyWithLogits, Momentum

context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")


def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
    """weight initial for conv layer"""
    weight = weight_variable()
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=kernel_size, stride=stride, padding=padding,
                     weight_init=weight, has_bias=False, pad_mode="valid")

def fc_with_initialize(input_channels, out_channels):
    """weight initial for fc layer"""
    weight = weight_variable()
    bias = weight_variable()
    return nn.Dense(input_channels, out_channels, weight, bias)

def weight_variable():
    """weight initial"""
    return TruncatedNormal(0.02)

def cell_hook_function(cell_id, grad_input, grad_output):
    print(cell_id)
    assert(grad_output.asnumpy().shape == (32, 6, 14, 14))
    assert(grad_input.asnumpy().shape == (32, 16, 10, 10))


def var_hook_function(grad_out):
    print("grad:", grad_out)
    assert(grad_out.asnumpy().shape == (32, 120))


class LeNet5(nn.Cell):
    """
    Lenet network
    Args:
        num_class (int): Num classes. Default: 10.

    Returns:
        Tensor, output tensor

    Examples:
        >>> LeNet(num_class=10)
    """
    def __init__(self, num_class=10):
        super(LeNet5, self).__init__()
        self.num_class = num_class
        self.batch_size = 32
        self.conv1 = conv(1, 6, 5)
        self.conv2 = conv(6, 16, 5)
        self.conv2.register_backward_hook(cell_hook_function)
        self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
        self.fc2 = fc_with_initialize(120, 84)
        self.fc3 = fc_with_initialize(84, self.num_class)
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.reshape = P.Reshape()
        self.hook = P.HookBackward(var_hook_function)

    def construct(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.reshape(x, (self.batch_size, -1))
        x = self.fc1(x)
        x = self.hook(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x


class GradWrap(nn.Cell):
    """ GradWrap definition """
    def __init__(self, network):
        super(GradWrap, self).__init__(auto_prefix=False)
        self.network = network
        self.weights = ParameterTuple(filter(lambda x: x.requires_grad, network.get_parameters()))

    def construct(self, x, label):
        weights = self.weights
        return C.GradOperation('get_by_list', get_by_list=True)(self.network, weights)(x, label)

def test_hook():
    net = LeNet5()
    optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.1, 0.9)
    criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=False)
    net_with_criterion = WithLossCell(net, criterion)
    train_network = GradWrap(net_with_criterion)
    train_network.set_train()

    input_data = Tensor(np.ones([net.batch_size, 1, 32, 32]).astype(np.float32) * 0.01)
    label = Tensor(np.ones([net.batch_size, net.num_class]).astype(np.float32))
    output = net(Tensor(input_data))
    loss_output = criterion(output, label)
    grads = train_network(input_data, label)
    success = optimizer(grads)
    print(loss_output.asnumpy().shape)


class MulAdd(nn.Cell):
    def __init__(self):
        super(MulAdd, self).__init__()

    def construct(self, x, y):
        return 2 * x + y

    def bprop(self, x, y, out, dout):
        assert(x == 1)
        assert(y == 2)
        assert(out == 4)
        assert(dout == 1)
        return 3 * dout, 2 * y

def test_custom_bprop():
    mul_add = MulAdd()
    mul_add.bprop_debug = True
    assert C.grad_all(mul_add)(1, 2) == (3, 4)