!2148 fix hook and bprop debug issue in pynative

Merge pull request !2148 from wangqiuliang/fix-hook-bprop-issue
This commit is contained in:
mindspore-ci-bot 2020-06-18 10:44:45 +08:00 committed by Gitee
commit 9ba6f61d01
18 changed files with 158 additions and 74 deletions

View File

@ -113,6 +113,24 @@ def bool_or(x, y):
"""Implement `bool_or`."""
return x or y
def vm_compare(*args):
"""Implement `vm_compare` for tensor."""
obj_str = args[-1]
if obj_str == "shape":
fn = getattr(args[0].asnumpy(), obj_str)
return fn
if len(args) == 2:
fn = getattr(args[0].asnumpy(), obj_str)
return Tensor(fn())
if isinstance(args[0], Tensor):
fn = getattr(args[0].asnumpy(), obj_str)
y = args[1].asnumpy() if isinstance(args[1], Tensor) else args[1]
else:
obj_str = "__r" + obj_str[2:]
fn = getattr(args[1].asnumpy(), obj_str)
y = args[0]
return Tensor(np.array(fn(y)))
def make_list(*xs):
"""Implement `make_list`."""

View File

@ -41,6 +41,35 @@ using TensorPtr = mindspore::tensor::TensorPtr;
using MetaTensor = mindspore::tensor::MetaTensor;
using MetaTensorPtr = mindspore::tensor::MetaTensorPtr;
FuncGraphPtr ConvertToBpropCut(const py::object &obj) {
std::vector<std::string> results = data_converter::GetObjKey(obj);
std::string obj_key = results[0];
py::function bprop_func = py::getattr(obj, CUSTOM_BPROP_NAME);
auto bprop_graph = std::make_shared<FuncGraph>();
std::vector<AnfNodePtr> outputs;
auto fake_bprop = std::make_shared<PrimitivePy>("bprop_cut", py::object());
fake_bprop->set_hook(bprop_func);
(void)fake_bprop->AddAttr(CUSTOM_BPROP_NAME, MakeValue(true));
outputs.push_back(NewValueNode(fake_bprop));
py::object code_obj = py::getattr(bprop_func, "__code__");
size_t inputs_num = py::cast<int>(py::getattr(code_obj, "co_argcount")) - 3;
for (size_t i = 0; i < inputs_num; ++i) {
auto param = bprop_graph->add_parameter();
outputs.push_back(param);
}
auto p1 = bprop_graph->add_parameter();
auto p2 = bprop_graph->add_parameter();
outputs.push_back(p1);
outputs.push_back(p2);
bprop_graph->set_output(bprop_graph->NewCNode(outputs));
data_converter::SetObjGraphValue(obj_key, bprop_graph);
return bprop_graph;
}
namespace {
bool ConvertTuple(const py::object &obj, ValuePtr *const data, bool use_signature) {
MS_LOG(DEBUG) << "Converting python tuple";
@ -231,35 +260,6 @@ bool ConvertSlice(const py::object &obj, ValuePtr *const data) {
return true;
}
FuncGraphPtr ConvertToBpropCut(py::object obj) {
std::vector<std::string> results = data_converter::GetObjKey(obj);
std::string obj_key = results[0];
py::function bprop_func = py::getattr(obj, "bprop");
FuncGraphPtr bprop_graph = std::make_shared<FuncGraph>();
std::vector<AnfNodePtr> outputs;
auto fake_bprop = std::make_shared<PrimitivePy>("bprop_cut", py::object());
fake_bprop->set_hook(bprop_func);
(void)fake_bprop->AddAttr("bprop", MakeValue(true));
outputs.push_back(NewValueNode(fake_bprop));
py::object code_obj = py::getattr(bprop_func, "__code__");
size_t inputs_num = py::cast<int>(py::getattr(code_obj, "co_argcount")) - 3;
for (size_t i = 0; i < inputs_num; ++i) {
auto param = bprop_graph->add_parameter();
outputs.push_back(param);
}
auto p1 = bprop_graph->add_parameter();
auto p2 = bprop_graph->add_parameter();
outputs.push_back(p1);
outputs.push_back(p2);
bprop_graph->set_output(bprop_graph->NewCNode(outputs));
data_converter::SetObjGraphValue(obj_key, bprop_graph);
return bprop_graph;
}
bool ConvertCellObjToFuncGraph(py::object obj, ValuePtr *const data) {
FuncGraphPtr func_graph = ConvertToFuncGraph(obj);
if (func_graph == nullptr) {
@ -267,7 +267,7 @@ bool ConvertCellObjToFuncGraph(py::object obj, ValuePtr *const data) {
return false;
}
// if the cell object has specified bprop, it has user-defined bprop function parse and record it
if (py::hasattr(obj, "bprop")) {
if (py::hasattr(obj, CUSTOM_BPROP_NAME)) {
FuncGraphPtr bprop_graph = nullptr;
bool enable_bprop_debug = py::cast<bool>(py::getattr(obj, "bprop_debug"));
if (enable_bprop_debug) {
@ -276,7 +276,7 @@ bool ConvertCellObjToFuncGraph(py::object obj, ValuePtr *const data) {
bprop_graph = ConvertToFuncGraph(obj, PYTHON_MOD_GET_BPROP_METHOD);
}
if (bprop_graph != nullptr) {
(void)func_graph->transforms().insert(std::make_pair("bprop", FuncGraphTransform(bprop_graph)));
(void)func_graph->transforms().insert(std::make_pair(CUSTOM_BPROP_NAME, FuncGraphTransform(bprop_graph)));
(void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(func_graph)));
func_graph->set_flags(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
}

View File

@ -51,6 +51,7 @@ void ClearObjectCache();
} // namespace data_converter
ClassPtr ParseDataClass(const py::object &cls_obj);
FuncGraphPtr ConvertToBpropCut(const py::object &obj);
void CleanDataClassToClassMap();

View File

@ -109,6 +109,7 @@ const char PYTHON_EXTERN_MINDSPORE_FLAG[] = "_mindspore_flags";
// define the parse constant
const int MAX_COMPARISON_OPS_SUPPORTED = 1;
const char CUSTOM_BPROP_NAME[] = "bprop";
// define the Namespace name
const char RESOLVE_NAMESPACE_NAME_AST[] = "Ast"; // for ast type namespace

View File

@ -45,7 +45,7 @@ enum PynativeStatusCode {
PYNATIVE_UNKNOWN_STATE = 0XFF
};
enum RunOpArgsEnum { PY_PRIM = 0, PY_NAME, PY_INPUTS, PY_INPUT_MASK, PY_ARGS_NUM };
enum RunOpArgsEnum { PY_PRIM = 0, PY_NAME, PY_INPUTS, PY_ARGS_NUM };
struct OpExecInfo {
PrimitivePyPtr py_primitive;

View File

@ -110,9 +110,15 @@ py::object GetTupleObj(const py::object &obj) {
return obj_tuple;
}
void ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *out_args) {
py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *out_args) {
auto &py_args = *out_args;
py::tuple input_mask(args.size());
for (size_t i = 0; i < args.size(); ++i) {
if (py::hasattr(args[i], "__parameter__")) {
input_mask[i] = true;
} else {
input_mask[i] = false;
}
py_args[i] = GetTupleObj(args[i]);
}
auto signature = prim->signatures();
@ -121,7 +127,7 @@ void ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *
[](const Signature &sig) { return sig.dtype; });
int empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue);
if (dtypes.size() == 0 || static_cast<int>(dtypes.size()) == empty_dtype_count) {
return;
return input_mask;
}
std::map<SignatureEnumDType, std::vector<size_t>> type_indexs;
for (size_t i = 0; i < dtypes.size(); ++i) {
@ -160,6 +166,7 @@ void ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *
continue;
}
}
return input_mask;
}
void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecInfo *const op_exec_info) {
@ -167,7 +174,7 @@ void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecIn
AbstractBasePtrList args_spec_list;
for (size_t i = 0; i < size; i++) {
ValuePtr input_value = PyAttrValue(py_args[i]);
if (input_value->isa<tensor::Tensor>()) {
if (!py::hasattr(prim->GetPyObj(), "const_value") && input_value->isa<tensor::Tensor>()) {
args_spec_list.emplace_back(abstract::FromValueInside(input_value, true));
} else {
args_spec_list.emplace_back(abstract::FromValueInside(input_value, false));
@ -179,7 +186,7 @@ void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecIn
OpExecInfoPtr GenerateOpExecInfo(const py::args &args) {
if (args.size() != PY_ARGS_NUM) {
MS_LOG(ERROR) << "Four args are needed by RunOp";
MS_LOG(ERROR) << "Three args are needed by RunOp";
return nullptr;
}
auto op_exec_info = std::make_shared<OpExecInfo>();
@ -195,14 +202,13 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args &args) {
size_t input_num = a.size();
op_exec_info->op_inputs = py::tuple(input_num);
ConvertInputs(prim, args[PY_INPUTS], &op_exec_info->op_inputs);
op_exec_info->inputs_mask = ConvertInputs(prim, args[PY_INPUTS], &op_exec_info->op_inputs);
// use python infer method
if (ignore_infer_prim.find(op_exec_info->op_name) == ignore_infer_prim.end()) {
PynativeInfer(prim, op_exec_info->op_inputs, op_exec_info.get());
}
op_exec_info->py_primitive = prim;
op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs");
op_exec_info->inputs_mask = args[PY_INPUT_MASK];
if (op_exec_info->op_inputs.size() != op_exec_info->inputs_mask.size()) {
MS_LOG(ERROR) << "Op:" << op_exec_info->op_name << " inputs size not equal op_mask";
return nullptr;
@ -488,14 +494,14 @@ py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecIn
return result;
}
AnfNodePtr PynativeExecutor::MakeCNode(const py::args &args, const py::tuple &out) {
AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, const py::args &args, const py::tuple &out) {
if (!grad_flag_ || graph_info_map_.size() == 0) {
return nullptr;
}
std::vector<AnfNodePtr> inputs;
auto prim = py::cast<PrimitivePyPtr>(args[PY_PRIM]);
auto prim = op_exec_info->py_primitive;
inputs.push_back(NewValueNode(prim));
py::tuple op_masks = args[PY_INPUT_MASK];
py::tuple op_masks = op_exec_info->inputs_mask;
py::list op_args = args[PY_INPUTS];
AbstractBasePtrList args_spec_list;
for (size_t i = 0; i < op_args.size(); i++) {
@ -584,7 +590,7 @@ py::tuple RunOp(const py::args &args) {
return err_ret;
}
auto node = PynativeExecutor::GetInstance()->MakeCNode(args, result);
auto node = PynativeExecutor::GetInstance()->MakeCNode(op_exec_info, args, result);
if (node != nullptr) {
node->set_abstract(op_exec_info->abstract);
MS_LOG(DEBUG) << "RunOp MakeCnode,new node is: " << node->DebugString();
@ -705,7 +711,7 @@ void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, c
}
cell_graph_map_[cell_id] = curr_g_;
auto out_id = GetId(out);
if (!graph_info_map_[curr_g_].obj_node_map.count(out_id)) {
if (!graph_info_map_[curr_g_].obj_node_map.count(out_id) && !graph_info_map_[curr_g_].param_map.count(out_id)) {
// cell construct return x, y
if (py::isinstance<py::tuple>(out)) {
std::vector<AnfNodePtr> args;
@ -727,12 +733,26 @@ void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, c
}
}
auto output_node = GetObjNode(out);
AnfNodePtr output_node;
if (graph_info_map_[curr_g_].param_map.count(out_id)) {
output_node = graph_info_map_[curr_g_].param_map[out_id];
} else {
output_node = GetObjNode(out);
}
curr_g_->set_output(output_node);
std::vector<AnfNodePtr> inputs;
inputs.push_back(NewValueNode(curr_g_));
MS_LOG(DEBUG) << "Current graph" << curr_g_->output()->DebugString();
resource_->manager()->AddFuncGraph(curr_g_);
// custom bprop debug
if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
MS_LOG(DEBUG) << "Use cell custom bprop function.";
FuncGraphPtr bprop_graph = parse::ConvertToBpropCut(cell);
if (bprop_graph != nullptr) {
(void)curr_g_->transforms().insert(std::make_pair(parse::CUSTOM_BPROP_NAME, FuncGraphTransform(bprop_graph)));
(void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(curr_g_)));
}
}
auto newfg = ad::Grad(curr_g_, resource_, curr_g_ == top_g_);
if (curr_g_ != top_g_) {
Popp();

View File

@ -44,7 +44,7 @@ py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
py::tuple RunOp(const py::args &args);
void ConvertInputs(const PrimitivePyPtr &prim, const py::list &py_args, py::tuple *out_args);
py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &py_args, py::tuple *out_args);
void ClearPyNativeSession();
@ -83,7 +83,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
void set_obj_node_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node, int index) {
graph_info_map_[g].obj_node_map[obj] = std::make_pair(node, index);
}
AnfNodePtr MakeCNode(const py::args &args, const py::tuple &out);
AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, const py::args &args, const py::tuple &out);
py::object Run(const py::tuple &args, const py::object &phase);
void Pushp();

View File

@ -16,6 +16,7 @@
"""Registry the relation."""
from collections import UserDict
from .. import context
class Registry(UserDict):
@ -27,9 +28,16 @@ class Registry(UserDict):
def get(self, obj_str):
"""Get the value by str."""
if isinstance(obj_str, str):
if not isinstance(obj_str, str):
raise TypeError("key for tensor registry must be string.")
if context.get_context("enable_ge"):
def wrap(*args):
new_args = list(args)
new_args.append(obj_str)
return self["vm_compare"](*new_args)
obj = wrap
else:
obj = self[obj_str]
return obj
tensor_operator_registry = Registry()

View File

@ -19,7 +19,6 @@ from .._c_expression import Tensor as Tensor_
from .._c_expression import MetaTensor
from .._checkparam import check_type, check_typename
from . import dtype as mstype
from .. import context
from ._register_for_tensor import tensor_operator_registry
__all__ = ['Tensor', 'MetaTensor']
@ -76,17 +75,19 @@ class Tensor(Tensor_):
return out
def __eq__(self, other):
if not isinstance(other, Tensor):
if not isinstance(other, (int, float, Tensor)):
return False
# The GE backend don't support single `Equal` operator execution.
# bool type is not supported for `Equal` operator in backend.
if context.get_context("enable_ge") or self.dtype == mstype.bool_ or other.dtype == mstype.bool_:
if self.dtype == mstype.bool_ or (isinstance(other, Tensor) and other.dtype == mstype.bool_):
return Tensor(np.array(self.asnumpy() == other.asnumpy()))
return tensor_operator_registry.get('__eq__')(self, other)
def __ne__(self, other):
if not isinstance(other, Tensor):
if not isinstance(other, (int, float, Tensor)):
return True
# bool type is not supported for `NotEqual` operator in backend.
if self.dtype == mstype.bool_ or (isinstance(other, Tensor) and other.dtype == mstype.bool_):
return Tensor(np.array(self.asnumpy() != other.asnumpy()))
return tensor_operator_registry.get('__ne__')(self, other)
def __hash__(self):
@ -105,7 +106,7 @@ class Tensor(Tensor_):
return out
def __radd__(self, other):
out = tensor_operator_registry.get('__add__')(other, self)
out = tensor_operator_registry.get('__add__')(self, other)
return out
def __imul__(self, other):
@ -113,15 +114,15 @@ class Tensor(Tensor_):
return out
def __rmul__(self, other):
out = tensor_operator_registry.get('__mul__')(other, self)
out = tensor_operator_registry.get('__mul__')(self, other)
return out
def __truediv__(self, other):
out = tensor_operator_registry.get('__div__')(self, other)
out = tensor_operator_registry.get('__truediv__')(self, other)
return out
def __rtruediv__(self, other):
out = tensor_operator_registry.get('__div__')(other, self)
out = tensor_operator_registry.get('__truediv__')(other, self)
return out
def __sub__(self, other):
@ -160,7 +161,7 @@ class Tensor(Tensor_):
return out
def __len__(self):
out = tensor_operator_registry.get('__shape__')(self)
out = tensor_operator_registry.get('shape')(self)
if not out:
return 1
return out[0]

View File

@ -819,4 +819,4 @@ class Cell:
"""
self._backward_hook = HookBackward(fn, self.cls_name + "(" + str(id(self)) + ")")
self._enable_hook = True
self.enable_hook = True

View File

@ -140,6 +140,11 @@ class SequentialCell(Cell):
def __len__(self):
return len(self._cells)
def set_grad(self, flag=True):
self.requires_grad = flag
for cell in self._cells.values():
cell.set_grad(flag)
def construct(self, input_data):
for cell in self.cell_list:
input_data = cell(input_data)
@ -246,5 +251,10 @@ class CellList(_CellListBase, Cell):
self._cells[str(len(self))] = cell
return self
def set_grad(self, flag=True):
self.requires_grad = flag
for cell in self._cells.values():
cell.set_grad(flag)
def construct(self, *inputs):
raise NotImplementedError

View File

@ -112,7 +112,7 @@ class GradOperation(GradOperation_):
grad_ = GradOperation('grad', self.get_all, self.get_by_list, self.sens_param)
if self.grad_fn is None or self.fn != fn:
if self.get_by_list:
if context.get_context("mode") == context.GRAPH_MODE or fn.bprop_debug:
if context.get_context("mode") == context.GRAPH_MODE:
@ms_function(obj=fn)
def after_grad(*args):
return grad_(fn, weights)(*args)

View File

@ -21,6 +21,7 @@ from mindspore.common._register_for_tensor import tensor_operator_registry
from .primitive import Primitive
from . import operations as P
from .operations import _grad_ops
from .._extends import builtin_operations as BP
typeof = Primitive('typeof')
hastype = Primitive('hastype')
@ -155,7 +156,7 @@ stop_gradient = Primitive("stop_gradient")
tensor_operator_registry.register('__add__', tensor_add)
tensor_operator_registry.register('__sub__', tensor_sub)
tensor_operator_registry.register('__mul__', tensor_mul)
tensor_operator_registry.register('__div__', tensor_div)
tensor_operator_registry.register('__truediv__', tensor_div)
#ms cannot support Tensor(True) compare
tensor_operator_registry.register('__eq__', equal)
tensor_operator_registry.register('__ne__', not_equal)
@ -164,4 +165,6 @@ tensor_operator_registry.register('__lt__', tensor_lt)
tensor_operator_registry.register('__le__', tensor_le)
tensor_operator_registry.register('__gt__', tensor_gt)
tensor_operator_registry.register('__ge__', tensor_ge)
tensor_operator_registry.register('__shape__', shape)
tensor_operator_registry.register('shape', shape)
#support GE backend for no compare operators
tensor_operator_registry.register('vm_compare', BP.vm_compare)

View File

@ -863,6 +863,8 @@ class TupleToArray(PrimitiveWithInfer):
args = list()
if isinstance(x, range):
args.append(tuple(x))
else:
args.append(x)
return _run_op(self, self.name, args)

View File

@ -341,13 +341,7 @@ def constexpr(fn=None, get_instance=True, name=None):
@_wrap_func
def _run_op(obj, op_name, args):
"""Single op execution function supported by ge in PyNative mode."""
op_mask = [0] * len(args)
op_inputs = []
for i, arg in enumerate(args):
if hasattr(arg, '__parameter__'):
op_mask[i] = 1
op_inputs.append(arg)
output = real_run_op(obj, op_name, args, tuple(op_mask))
output = real_run_op(obj, op_name, args)
if not output:
raise RuntimeError("Pynative run op %s failed!" % op_name)
if len(output) == 1:

View File

@ -63,8 +63,7 @@ OpExecInfoPtr ConstructOpExecInfo() {
auto conv_obj = prim::GetPythonOps("conv2d_prim", "gtest_input.pynative");
py::none py_none;
py::tuple op_mask = py::make_tuple(0, 1);
return GenerateOpExecInfo(py::make_tuple(conv_obj, op_name, op_inputs, op_mask));
return GenerateOpExecInfo(py::make_tuple(conv_obj, op_name, op_inputs));
}
TEST_F(TestPynativeExecute, TestRunOpInVM) {
@ -79,7 +78,7 @@ TEST_F(TestPynativeExecute, TestRunOp) {
py::none py_none;
auto op_exec_info_ptr = ConstructOpExecInfo();
py::tuple outputs = pynative::RunOp(py::make_tuple(op_exec_info_ptr->py_primitive, op_exec_info_ptr->op_name,
op_exec_info_ptr->op_inputs, op_exec_info_ptr->inputs_mask));
op_exec_info_ptr->op_inputs));
if (outputs.size() == 0) {
FAIL();
} else {

View File

@ -452,5 +452,5 @@ def test_tensor_operation():
assert np.all(res.asnumpy() == np.ones((3, 3)) * 2)
res = 8 / x
assert np.all(res.asnumpy() == np.ones((3, 3)) * 2)
with pytest.raises(TypeError):
with pytest.raises(ValueError):
res = x * (2, 3)

View File

@ -8,6 +8,9 @@ from mindspore.nn import WithLossCell, Momentum
from mindspore.ops import composite as C
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
cell_hook_done = False
var_hook_done = False
cell_bprop_done = False
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
@ -32,15 +35,35 @@ def weight_variable():
def cell_hook_function(cell_id, grad_input, grad_output):
print(cell_id)
global cell_hook_done
cell_hook_done = True
assert (grad_output[0].asnumpy().shape == (32, 6, 14, 14))
assert (grad_input[0].asnumpy().shape == (32, 16, 10, 10))
def var_hook_function(grad_out):
print("grad:", grad_out)
global var_hook_done
var_hook_done = True
assert (grad_out[0].asnumpy().shape == (32, 120))
class Block(nn.Cell):
def __init__(self):
super(Block, self).__init__()
self.relu = nn.ReLU()
def construct(self, x):
x = self.relu(x)
return x
def bprop(self, x, out, dout):
global cell_bprop_done
cell_bprop_done = True
grad = out.asnumpy() * dout.asnumpy()
grad = Tensor(grad)
return (grad,)
class LeNet5(nn.Cell):
"""
Lenet network
@ -59,6 +82,7 @@ class LeNet5(nn.Cell):
self.conv1 = conv(1, 6, 5)
self.conv2 = conv(6, 16, 5)
self.conv2.register_backward_hook(cell_hook_function)
self.block = Block()
self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
self.fc2 = fc_with_initialize(120, 84)
self.fc3 = fc_with_initialize(84, self.num_class)
@ -72,7 +96,7 @@ class LeNet5(nn.Cell):
x = self.relu(x)
x = self.max_pool2d(x)
x = self.conv2(x)
x = self.relu(x)
x = self.block(x)
x = self.max_pool2d(x)
x = self.reshape(x, (self.batch_size, -1))
x = self.fc1(x)
@ -110,6 +134,9 @@ def test_hook():
loss_output = criterion(output, label)
grads = train_network(input_data, label)
success = optimizer(grads)
assert cell_hook_done
assert var_hook_done
assert cell_bprop_done
print(loss_output.asnumpy().shape)