!33497 Support custom cell bprop when using Parameter in primal graph

Merge pull request !33497 from YuJianfeng/custom_bprop
i-robot 2022-04-27 12:30:02 +00:00 committed by Gitee
commit 406eaf37b1
7 changed files with 622 additions and 31 deletions
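
In short: a Cell whose construct reads a Parameter can now define a custom bprop; reading the Parameter inside the bprop itself is still rejected. A minimal usage sketch, condensed from the test cases added in this change (graph mode; the values are illustrative):

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, Parameter, context
from mindspore.ops import composite as C
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE)
grad_all = C.GradOperation(get_all=True)

class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.matmul = P.MatMul()
        # Parameter used in the primal (forward) graph: supported after this change.
        self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

    def construct(self, x, y):
        x = x * self.z
        return self.matmul(x, y)

    def bprop(self, x, y, out, dout):
        # Custom gradients; reading self.z here is still rejected (see the first file below).
        return x + x, y + y

x = Tensor(np.array([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], np.float32))
y = Tensor(np.array([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], np.float32))
dx, dy = grad_all(Net())(x, y)   # dx == x + x, dy == y + y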

View File

@ -539,9 +539,10 @@ FuncGraphPtr DFunctor::KUserDefined(const FuncGraphPtr &primal) {
FuncGraphPtr bprop_graph = bprop->second.func_graph();
resources_->manager()->AddFuncGraph(bprop_graph);
if (!bprop_graph->free_variables_nodes().empty() || !primal->free_variables_nodes().empty()) {
MS_LOG(EXCEPTION) << "The Cell with user defined 'bprop' function in scope " << primal->output()->scope()->name()
<< " does not support Parameter data type.\n"
(void)parse::ResolveFuncGraph(bprop_graph, resources_);
if (!bprop_graph->free_variables_nodes().empty()) {
MS_LOG(EXCEPTION) << "The user defined 'bprop' function in scope " << primal->output()->scope()->name()
<< " does not support using Parameter.\n"
<< trace::GetDebugInfo(bprop_graph->debug_info());
}
bprop_graph->set_flag(mindspore::kFuncGraphFlagBackPropEntry, true);
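
With the reordered check, the bprop graph is resolved first and the exception now fires only when the bprop graph itself captures free variables, i.e. the user-defined bprop reads a Parameter. A sketch of the case that still raises, mirroring test_custom_cell_bprop_with_parameter added later in this change:

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P

class BadNet(nn.Cell):
    def __init__(self):
        super(BadNet, self).__init__()
        self.matmul = P.MatMul()
        self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

    def construct(self, x, y):
        return self.matmul(x * self.z, y)

    def bprop(self, x, y, out, dout):
        # self.z is a free variable of the bprop graph, so taking the gradient of
        # BadNet hits the MS_LOG(EXCEPTION) above:
        # "The user defined 'bprop' function in scope ... does not support using Parameter."
        return x * self.z, y + y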

View File

@ -53,8 +53,9 @@ KPrim g_k_prims;
namespace {
constexpr char kBpropMindIRSuffix[] = "_bprop.mindir";
constexpr char kBpropMindIRDir[] = "/../bprop_mindir/";
constexpr char serializable_bprop_ops[] = "serializable_bprop_ops";
constexpr char bprop_mindir_module[] = "mindspore.ops.bprop_mindir";
constexpr char kSerializableBpropOps[] = "serializable_bprop_ops";
constexpr char kBpropMindirModule[] = "mindspore.ops.bprop_mindir";
constexpr char kLiftedUserDataKey[] = "lifted_from_fv";
#ifndef _WIN32
std::string GetBpropDir() {
@ -86,8 +87,8 @@ mindspore::HashSet<std::string> GetSerializableBpropList() {
if (!BpropMindirDirExists()) {
return serializable_bprop_list;
}
py::module mod = py::module::import(bprop_mindir_module);
py::object serializable_bprop_ops_attr = mod.attr(serializable_bprop_ops);
py::module mod = py::module::import(kBpropMindirModule);
py::object serializable_bprop_ops_attr = mod.attr(kSerializableBpropOps);
if (!py::isinstance<py::list>(serializable_bprop_ops_attr)) {
MS_LOG(WARNING) << "Can not get the serializable bprop ops list from python, it is not a python list.";
return serializable_bprop_list;
@ -614,8 +615,9 @@ FuncGraphPtr KPrim::KPrimitive(const CNodePtr &cnode, const ValueNodePtr &value_
}
AnfNodePtr KPrim::BuildOutput(const FuncGraphPtr &bprop_fg, const FuncGraphPtr &current_primal_fg) {
// current_primal_fg may have extra parameters like u_monad, io_monad
std::vector<AnfNodePtr> extra_args;
// The primal fg may have extra parameters from lifted fv or u_monad and io_monad.
std::vector<AnfNodePtr> extra_lifted_args;
std::vector<AnfNodePtr> extra_monad_args;
// The caller has already checked that size() - 2 is greater than 0.
auto bprop_fg_param_size = bprop_fg->parameters().size() - 2;
if (current_primal_fg != nullptr && bprop_fg_param_size < current_primal_fg->parameters().size()) {
@ -623,6 +625,18 @@ AnfNodePtr KPrim::BuildOutput(const FuncGraphPtr &bprop_fg, const FuncGraphPtr &
MS_LOG(DEBUG) << "Current Primal FuncGraph may have extra parameters(U or IO monad) which bprop don't define, so "
"Insert it. Extra parameters size: "
<< current_primal_fg_param_size - bprop_fg_param_size;
// The lifted parameters are put in front: {lifted parameters, origin parameters, u/io monad}.
for (size_t i = 0; i < current_primal_fg_param_size; ++i) {
auto primal_parameter = dyn_cast<Parameter>(current_primal_fg->parameters()[i]);
MS_EXCEPTION_IF_NULL(primal_parameter);
auto lifted = primal_parameter->user_data<bool>(kLiftedUserDataKey);
if (lifted == nullptr || !*lifted) {
break;
}
extra_lifted_args.push_back(
bprop_fg->NewCNodeInOrder({NewValueNode(prim::GetPythonOps("zeros_like")), primal_parameter}));
++bprop_fg_param_size;
}
for (auto i = bprop_fg_param_size; i < current_primal_fg_param_size; ++i) {
const auto &primal_node = current_primal_fg->parameters()[i];
AnfNodePtr extra_node;
@ -639,7 +653,7 @@ AnfNodePtr KPrim::BuildOutput(const FuncGraphPtr &bprop_fg, const FuncGraphPtr &
"as the 'out' and 'dout'.\n"
<< trace::GetDebugInfo(bprop_fg->debug_info());
}
extra_args.push_back(extra_node);
extra_monad_args.push_back(extra_node);
MS_LOG(DEBUG) << "Insert to bprop_fg for node: " << primal_node->DebugString();
}
}
@ -652,9 +666,13 @@ AnfNodePtr KPrim::BuildOutput(const FuncGraphPtr &bprop_fg, const FuncGraphPtr &
std::vector<AnfNodePtr> args;
args.push_back(NewValueNode(prim::kPrimMakeTuple));
args.push_back(NewEnviron(bprop_fg));
// The lifted parameters are put in front.
if (!extra_lifted_args.empty()) {
(void)args.insert(args.end(), extra_lifted_args.cbegin(), extra_lifted_args.cend());
}
(void)args.insert(args.end(), inputs.begin() + 1, inputs.end());
if (!extra_args.empty()) {
args.insert(args.end(), extra_args.cbegin(), extra_args.cend());
if (!extra_monad_args.empty()) {
(void)args.insert(args.end(), extra_monad_args.cbegin(), extra_monad_args.cend());
}
return NewCNode(args, bprop_fg);
}
@ -664,9 +682,14 @@ AnfNodePtr KPrim::BuildOutput(const FuncGraphPtr &bprop_fg, const FuncGraphPtr &
constexpr char python_ops[] = "_tuple_add";
auto tuple_env = NewCNode({NewValueNode(prim::kPrimMakeTuple), NewEnviron(bprop_fg)}, bprop_fg);
auto tuple_add_ops = NewValueNode(prim::GetPythonOps(python_ops, model_name));
if (!extra_args.empty()) {
extra_args.insert(extra_args.begin(), NewValueNode(prim::kPrimMakeTuple));
auto extra_tuple = NewCNode(extra_args, bprop_fg);
if (!extra_lifted_args.empty()) {
(void)extra_lifted_args.insert(extra_lifted_args.begin(), NewValueNode(prim::kPrimMakeTuple));
auto extra_tuple = NewCNode(extra_lifted_args, bprop_fg);
tuple_env = NewCNode({tuple_add_ops, tuple_env, extra_tuple}, bprop_fg);
}
if (!extra_monad_args.empty()) {
(void)extra_monad_args.insert(extra_monad_args.begin(), NewValueNode(prim::kPrimMakeTuple));
auto extra_tuple = NewCNode(extra_monad_args, bprop_fg);
auto old_output_extra = NewCNode({tuple_add_ops, bprop_fg->output(), extra_tuple}, bprop_fg);
return NewCNode({tuple_add_ops, tuple_env, old_output_extra}, bprop_fg);
}
@ -676,6 +699,7 @@ AnfNodePtr KPrim::BuildOutput(const FuncGraphPtr &bprop_fg, const FuncGraphPtr &
static void TransformNormalArgs(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg, const FuncGraphPtr &outer,
std::vector<AnfNodePtr> *transf_args) {
MS_EXCEPTION_IF_NULL(mng);
// bprop_fg has been checked in the caller.
// Transform all parameters except the last two: out and dout.
const size_t last_parameter_sizes = 2;
@ -694,7 +718,6 @@ static void TransformNormalArgs(const FuncGraphManagerPtr &mng, const FuncGraphP
void KPrim::TransformArgsForPrimitive(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg,
const PrimitivePtr &primitive, const FuncGraphPtr &outer,
std::vector<AnfNodePtr> *transf_args) {
MS_EXCEPTION_IF_NULL(mng);
TransformNormalArgs(mng, bprop_fg, outer, transf_args);
// Fprop_fg for Primitive with side effect should append extra U or IO monad parameter.
auto effect_info = GetPrimEffectInfo(primitive);
@ -714,12 +737,24 @@ template <typename T>
void KPrim::TransformArgsForFuncGraph(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg,
const T &current_primal_fg, const FuncGraphPtr &outer,
std::vector<AnfNodePtr> *transf_args) {
MS_EXCEPTION_IF_NULL(mng);
TransformNormalArgs(mng, bprop_fg, outer, transf_args);
constexpr size_t need_filter_size = 2;
auto bprop_fg_param_size = bprop_fg->parameters().size() - need_filter_size;
// current_primal_fg may have extra parameters after AutoMonad
const auto &current_primal_fg_params = current_primal_fg->parameters();
// The lifted parameters are put in front: {lifted parameters, origin parameters, u/io monad}.
for (size_t i = 0; i < current_primal_fg_params.size(); ++i) {
auto primal_parameter = dyn_cast<Parameter>(current_primal_fg_params[i]);
MS_EXCEPTION_IF_NULL(primal_parameter);
auto lifted = primal_parameter->template user_data<bool>(kLiftedUserDataKey);
if (lifted == nullptr || !*lifted) {
break;
}
TraceGuard trace_guard(std::make_shared<TraceGradFprop>(primal_parameter->debug_info()));
auto transf_p = outer->add_parameter();
transf_args->push_back(transf_p);
++bprop_fg_param_size;
}
TransformNormalArgs(mng, bprop_fg, outer, transf_args);
// Current primal fg may have extra parameters after AutoMonad
if (bprop_fg_param_size < current_primal_fg_params.size()) {
for (auto i = bprop_fg_param_size; i < current_primal_fg_params.size(); ++i) {
auto p = current_primal_fg_params[i];
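
A conceptual sketch in plain Python/NumPy (not the MindSpore source) of the gradient-tuple layout BuildOutput now assembles; the environ entry that is prepended by NewEnviron is omitted and all names are illustrative:

import numpy as np

def build_bprop_output(user_grads, lifted_params, monad_params):
    # Mirrors the primal parameter order {lifted parameters, original parameters, u/io monad}:
    # zeros_like gradients for lifted parameters are prepended (extra_lifted_args) and
    # zeros_like gradients for monad parameters are appended (extra_monad_args).
    lifted_grads = [np.zeros_like(p) for p in lifted_params]
    monad_grads = [np.zeros_like(p) for p in monad_params]
    return (*lifted_grads, *user_grads, *monad_grads)

# With one lifted Parameter z and user-defined gradients (dx, dy):
# build_bprop_output((dx, dy), [z], []) == (np.zeros_like(z), dx, dy)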

View File

@ -2703,11 +2703,20 @@ void GradExecutor::DoGradForCustomBprop(const py::object &cell, const py::object
return;
}
MS_LOG(DEBUG) << "Do grad for custom bprop";
size_t par_number = py::tuple(python_adapter::CallPyObjMethod(cell, "get_parameters")).size();
if (par_number > 0) {
MS_LOG(EXCEPTION) << "When user defines the net bprop, the 'Parameter' data type is not supported in the net.";
}
py::function bprop_func = py::getattr(cell, parse::CUSTOM_BPROP_NAME);
py::object code_obj = py::getattr(bprop_func, "__code__");
// Even when co_names is empty, we still get a tuple (an empty one).
auto co_names = py::getattr(code_obj, "co_names").cast<py::tuple>();
for (auto name : co_names) {
if (!py::hasattr(cell, name)) {
continue;
}
auto var = py::getattr(cell, name);
if (py::hasattr(var, "__parameter__") && py::isinstance<tensor::MetaTensor>(var)) {
MS_LOG(EXCEPTION) << "The user defined 'bprop' function does not support using Parameter.";
}
}
auto bprop_func_cellid = GetId(bprop_func);
bprop_cell_list_.emplace_back(bprop_func_cellid);
auto fake_prim = std::make_shared<PrimitivePy>(prim::kPrimHookBackward->name());
@ -2721,7 +2730,6 @@ void GradExecutor::DoGradForCustomBprop(const py::object &cell, const py::object
(void)fake_prim->AddAttr("cell_id", MakeValue(cell_id));
(void)fake_prim->AddAttr(parse::CUSTOM_BPROP_NAME, MakeValue(true));
py::object code_obj = py::getattr(bprop_func, "__code__");
py::object co_name = py::getattr(code_obj, "co_name");
if (std::string(py::str(co_name)) == "staging_specialize") {
MS_LOG(EXCEPTION) << "Decorating bprop with '@ms_function' is not supported.";

View File

@ -300,7 +300,9 @@ void Cloner::GenParameters(const FuncGraphPtr &func_graph) {
}
MS_LOG(DEBUG) << "Gen param: " << free_var_node->ToString() << " for func_graph: " << func_graph->ToString();
auto &fg_params = repl_func_graph_params_[func_graph];
(void)fg_params.emplace_back(AddParameter(func_graph, utils::cast<AnfNodePtr>(free_var)));
auto fv_parameter = AddParameter(func_graph, utils::cast<AnfNodePtr>(free_var));
fv_parameter->set_user_data<bool>("lifted_from_fv", std::make_shared<bool>(true));
(void)fg_params.emplace_back(fv_parameter);
}
}
@ -313,6 +315,11 @@ void Cloner::CloneParameter(const ParameterPtr &param, const AnfNodePtr &node) {
param->set_default_param(old_param->default_param());
}
param->set_name(old_param->name());
constexpr char lifted_user_data_key[] = "lifted_from_fv";
auto lifted = param->user_data<bool>(lifted_user_data_key);
if (lifted != nullptr && *lifted) {
param->set_user_data<bool>(lifted_user_data_key, std::make_shared<bool>(true));
}
}
}
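
A small plain-Python sketch (illustrative stand-ins, not the real C++ AnfNode API) of the tagging this cloner change adds: a parameter generated for a lifted free variable gets a 'lifted_from_fv' marker, and cloning preserves the marker so KPrim can later tell lifted parameters apart from the cell's original inputs:

LIFTED_KEY = "lifted_from_fv"

class FakeParameter:
    """Stand-in for a graph Parameter carrying a user_data map."""
    def __init__(self, name):
        self.name = name
        self.user_data = {}

def gen_lifted_parameter(fg_params, name):
    p = FakeParameter(name)
    p.user_data[LIFTED_KEY] = True       # GenParameters: tag parameters created from free variables
    fg_params.append(p)
    return p

def clone_parameter(old):
    new = FakeParameter(old.name)
    if old.user_data.get(LIFTED_KEY):    # CloneParameter: carry the tag over to the clone
        new.user_data[LIFTED_KEY] = True
    return new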

View File

@ -26,9 +26,6 @@ from mindspore.common.tensor import Tensor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
grad_all = C.GradOperation(get_all=True)
@ -40,6 +37,7 @@ class MulAdd(nn.Cell):
# In this test case, the user-defined bprop is deliberately defined incorrectly to distinguish it from the AD result
return 2 * dout, 2 * y
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@ -59,6 +57,7 @@ class InlineMulADD(nn.Cell):
def construct(self, x, y):
return self.mul_add(x, y) + x + self.param * y
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@ -82,6 +81,7 @@ class WithParameter(nn.Cell):
# In this test case, the user-defined bprop is deliberately defined incorrectly to distinguish it from the AD result
return self.param1 * self.param2 * dout, 2 * y
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@ -95,6 +95,7 @@ class WithNoBprop(nn.Cell):
def construct(self, x, y):
return 2 * x + y
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@ -104,6 +105,7 @@ def test_with_no_bprop():
y = Tensor(2, dtype=ms.int32)
assert grad_all(with_no_bprop)(x, y) == (2, 1)
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@ -142,6 +144,7 @@ def test_grad_in_bprop_1():
assert (grads[0].asnumpy() == np.ones([2, 2]).astype(np.float32)).all()
assert (grads[1].asnumpy() == np.zeros([2, 2]).astype(np.float32)).all()
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@ -183,6 +186,7 @@ def test_grad_in_bprop_2():
assert (grads[0].asnumpy() == np.ones([2, 2]).astype(np.float32)).all()
assert (grads[1].asnumpy() == np.array([[2, 2], [2, 2]]).astype(np.float32)).all()
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@ -236,6 +240,7 @@ class OneInputBprop(nn.Cell):
def bprop(self, x, out, dout):
return (5 * x,)
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@ -263,6 +268,7 @@ class InlineBpropTwoInput(nn.Cell):
grads = grad_all(self.f)(x, y)
return grads[0] * 2, grads[1] * 2
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@ -324,6 +330,7 @@ class InlineMutilTwoInputParameterCell(nn.Cell):
output = self.f1(x, y) + self.f2(x, y) + self.f3(x, y) + self.f4(x, y)
return output
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@ -347,20 +354,24 @@ class MulAddWithParam(nn.Cell):
def construct(self, x):
return self.mul_add(self.param, x)
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_refkey_bprop():
grad_by_list = C.GradOperation(get_all=True, get_by_list=True)
class GradWrap(nn.Cell):
def __init__(self, network):
super(GradWrap, self).__init__()
self.network = network
self.weights = ParameterTuple(filter(lambda x: x.requires_grad, network.get_parameters()))
def construct(self, x):
weights = self.weights
grads = grad_by_list(self.network, weights)(x)
return grads
network = GradWrap(MulAddWithParam())
input_data = Tensor(np.array([2, 2], np.float32))
grads = network(input_data)
@ -375,6 +386,7 @@ class MulAddWithWrongOutputNum(nn.Cell):
def bprop(self, x, y, out, dout):
return (2 * dout,)
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@ -392,6 +404,7 @@ class MulAddWithWrongOutputType(nn.Cell):
def bprop(self, x, y, out, dout):
return 2 * dout, 2
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@ -413,6 +426,7 @@ class MulAddWithWrongOutputShape(nn.Cell):
def bprop(self, x, y, out, dout):
return 2, self.ones
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@ -421,3 +435,328 @@ def test_grad_mul_add_with_wrong_output_shape():
mul_add = MulAddWithWrongOutputShape()
with pytest.raises(TypeError):
grad_all(mul_add)(1, Tensor(np.ones([2, 2])))
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_forward_with_parameter():
"""
Feature: Custom cell bprop
Description: Get the gradients of the inputs when the forward net uses a Parameter.
Expectation: Get the correct gradients.
"""
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.matmul = P.MatMul()
self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
def construct(self, x, y):
x = x * self.z
out = self.matmul(x, y)
return out
def bprop(self, x, y, out, dout):
dx = x + x
dy = y + y
return dx, dy
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
def construct(self, x, y):
grad_f = grad_all(self.net)
return grad_f(x, y)
x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
out = GradNet(Net())(x, y)
expect_dx = np.array([[1.0, 1.2, 0.8],
[2.4, 2.6, 2.2]]).astype(np.float32)
expect_dy = np.array([[0.02, 0.6, 2.2],
[0.2, 0.4, 2.6],
[4.2, 2.4, 6.6]]).astype(np.float32)
assert np.allclose(out[0].asnumpy(), expect_dx)
assert np.allclose(out[1].asnumpy(), expect_dy)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_forward_with_parameter_in_sub_cell():
"""
Feature: Custom cell bprop
Description: Get the gradients of the inputs when the forward net uses a Parameter in a sub-cell.
Expectation: Get the correct gradients.
"""
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.net = Net1()
def construct(self, x, y):
return self.net(x, y)
class Net1(nn.Cell):
def __init__(self):
super(Net1, self).__init__()
self.matmul = P.MatMul()
self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
def construct(self, x, y):
x = x * self.z
out = self.matmul(x, y)
return out
def bprop(self, x, y, out, dout):
dx = x + x
dy = y + y
return dx, dy
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
def construct(self, x, y):
grad_f = grad_all(self.net)
return grad_f(x, y)
x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
out = GradNet(Net())(x, y)
expect_dx = np.array([[1.0, 1.2, 0.8],
[2.4, 2.6, 2.2]]).astype(np.float32)
expect_dy = np.array([[0.02, 0.6, 2.2],
[0.2, 0.4, 2.6],
[4.2, 2.4, 6.6]]).astype(np.float32)
assert np.allclose(out[0].asnumpy(), expect_dx)
assert np.allclose(out[1].asnumpy(), expect_dy)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_forward_with_parameter_in_sub_cell_get_by_list():
"""
Feature: Custom cell bprop
Description: Get the gradients of the inputs and Parameters when the forward net uses a Parameter in a sub-cell.
Expectation: Get the correct gradients.
"""
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.net = Net1()
def construct(self, x, y):
return self.net(x, y)
class Net1(nn.Cell):
def __init__(self):
super(Net1, self).__init__()
self.matmul = P.MatMul()
self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
def construct(self, x, y):
x = x * self.z
out = self.matmul(x, y)
return out
def bprop(self, x, y, out, dout):
dx = x + x
dy = y + y
return dx, dy
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
self.params = ParameterTuple(net.trainable_params())
self.grad_op = C.GradOperation(get_by_list=True, get_all=True)
def construct(self, x, y):
grad_f = self.grad_op(self.net, self.params)
return grad_f(x, y)
x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
out = GradNet(Net())(x, y)
expect_dx = np.array([[1.0, 1.2, 0.8],
[2.4, 2.6, 2.2]]).astype(np.float32)
expect_dy = np.array([[0.02, 0.6, 2.2],
[0.2, 0.4, 2.6],
[4.2, 2.4, 6.6]]).astype(np.float32)
expect_dz = np.array([0.0]).astype(np.float32)
assert np.allclose(out[0][0].asnumpy(), expect_dx)
assert np.allclose(out[0][1].asnumpy(), expect_dy)
assert np.allclose(out[1][0].asnumpy(), expect_dz)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_pynative_forward_with_parameter():
"""
Feature: Custom cell bprop
Description: Get the gradients of the inputs when the forward net uses a Parameter.
Expectation: Get the correct gradients.
"""
context.set_context(mode=context.PYNATIVE_MODE)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.matmul = P.MatMul()
self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
def construct(self, x, y):
x = x * self.z
out = self.matmul(x, y)
return out
def bprop(self, x, y, out, dout):
dx = x + x
dy = y + y
return dx, dy
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
def construct(self, x, y):
grad_f = grad_all(self.net)
return grad_f(x, y)
x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
out = GradNet(Net())(x, y)
expect_dx = np.array([[1.0, 1.2, 0.8],
[2.4, 2.6, 2.2]]).astype(np.float32)
expect_dy = np.array([[0.02, 0.6, 2.2],
[0.2, 0.4, 2.6],
[4.2, 2.4, 6.6]]).astype(np.float32)
assert np.allclose(out[0].asnumpy(), expect_dx)
assert np.allclose(out[1].asnumpy(), expect_dy)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_pynative_forward_with_parameter_in_sub_cell():
"""
Feature: Custom cell bprop
Description: Get the gradients of the inputs when the forward net uses a Parameter in a sub-cell.
Expectation: Get the correct gradients.
"""
context.set_context(mode=context.PYNATIVE_MODE)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.net = Net1()
def construct(self, x, y):
return self.net(x, y)
class Net1(nn.Cell):
def __init__(self):
super(Net1, self).__init__()
self.matmul = P.MatMul()
self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
def construct(self, x, y):
x = x * self.z
out = self.matmul(x, y)
return out
def bprop(self, x, y, out, dout):
dx = x + x
dy = y + y
return dx, dy
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
def construct(self, x, y):
grad_f = grad_all(self.net)
return grad_f(x, y)
x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
out = GradNet(Net())(x, y)
expect_dx = np.array([[1.0, 1.2, 0.8],
[2.4, 2.6, 2.2]]).astype(np.float32)
expect_dy = np.array([[0.02, 0.6, 2.2],
[0.2, 0.4, 2.6],
[4.2, 2.4, 6.6]]).astype(np.float32)
assert np.allclose(out[0].asnumpy(), expect_dx)
assert np.allclose(out[1].asnumpy(), expect_dy)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_pynative_forward_with_parameter_in_sub_cell_get_by_list():
"""
Feature: Custom cell bprop
Description: Get the gradients of the inputs and Parameters when the forward net uses a Parameter in a sub-cell.
Expectation: Get the correct gradients.
"""
context.set_context(mode=context.PYNATIVE_MODE)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.net = Net1()
def construct(self, x, y):
return self.net(x, y)
class Net1(nn.Cell):
def __init__(self):
super(Net1, self).__init__()
self.matmul = P.MatMul()
self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
def construct(self, x, y):
x = x * self.z
out = self.matmul(x, y)
return out
def bprop(self, x, y, out, dout):
dx = x + x
dy = y + y
return dx, dy
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
self.params = ParameterTuple(net.trainable_params())
self.grad_op = C.GradOperation(get_by_list=True, get_all=True)
def construct(self, x, y):
grad_f = self.grad_op(self.net, self.params)
return grad_f(x, y)
x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
out = GradNet(Net())(x, y)
expect_dx = np.array([[1.0, 1.2, 0.8],
[2.4, 2.6, 2.2]]).astype(np.float32)
expect_dy = np.array([[0.02, 0.6, 2.2],
[0.2, 0.4, 2.6],
[4.2, 2.4, 6.6]]).astype(np.float32)
expect_dz = np.array([0.0]).astype(np.float32)
assert np.allclose(out[0][0].asnumpy(), expect_dx)
assert np.allclose(out[0][1].asnumpy(), expect_dy)
assert np.allclose(out[1][0].asnumpy(), expect_dz)

View File

@ -16,6 +16,7 @@ import numpy as np
import mindspore.nn as nn
import mindspore.ops as ops
import mindspore.common.dtype as mstype
from mindspore import context
from mindspore import Tensor
from mindspore.ops import operations as P
@ -25,6 +26,7 @@ from mindspore.common.parameter import Parameter, ParameterTuple
grad_all = C.GradOperation(get_all=True)
grad_by_list = C.GradOperation(get_by_list=True)
class CropAndResizeNet(nn.Cell):
def __init__(self, crop_size):
super(CropAndResizeNet, self).__init__()
@ -102,6 +104,7 @@ class BPropOperatatorNet(nn.Cell):
x = self.bn(x)
return x
def test_user_defined_bprop_with_u():
net = BPropOperatatorNet(mul_size=(128, 96))
grad_net = TestUserDefinedBpropGradNet(net)
@ -153,6 +156,7 @@ def test_second_grad_with_j_primitive():
# A CNode being used as FV is MapMorphism after MapMorphism of call-site CNode;
def test_ad_fv_cnode_order():
context.set_context(mode=context.GRAPH_MODE)
class Net(nn.Cell):
# cnode xay is not being MapMorphism when cnode second_level() is being MapMorphism and
# BackPropagateFv as MapMorphism is started from output node and from left to right order.
@ -164,6 +168,7 @@ def test_ad_fv_cnode_order():
return xay
return second_level() + xay
return first_level()
input_x = Tensor(np.array([1.0], dtype=np.float32))
@ -178,6 +183,7 @@ def test_ad_fv_cnode_order():
# True and False branch of switch have different number of parameters.
def test_if_branch_with_different_params():
context.set_context(mode=context.GRAPH_MODE)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
@ -215,6 +221,7 @@ def test_if_branch_with_different_params():
# because weight1 in Net may use old_parameter other than replicated one.
def test_limit_lift_fv_scope():
context.set_context(mode=context.GRAPH_MODE)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
@ -276,6 +283,7 @@ def test_same_primal_used_by_multi_j():
def test_same_primal_used_by_multi_j_with_monad1():
context.set_context(mode=context.GRAPH_MODE)
class AdamNet(nn.Cell):
def __init__(self, var, m, v):
super(AdamNet, self).__init__()
@ -318,6 +326,7 @@ def test_same_primal_used_by_multi_j_with_monad1():
def test_same_primal_used_by_multi_j_with_monad2():
context.set_context(mode=context.GRAPH_MODE)
class AdamNet(nn.Cell):
def __init__(self, var, m, v):
super(AdamNet, self).__init__()
@ -364,6 +373,7 @@ def test_grad_args_type_error1():
def __init__(self):
super(Net, self).__init__()
self.matmul = P.MatMul()
def construct(self, x, y):
out = self.matmul(x, y)
return out
@ -373,6 +383,7 @@ def test_grad_args_type_error1():
super(GradNetWrtX, self).__init__()
self.net = net
self.grad_op = ops.GradOperation(get_all=2)
def construct(self, x, y):
gradient_function = self.grad_op(self.net)
return gradient_function(x, y)
@ -390,6 +401,7 @@ def test_grad_args_type_error2():
def __init__(self):
super(Net, self).__init__()
self.matmul = P.MatMul()
def construct(self, x, y):
out = self.matmul(x, y)
return out
@ -399,6 +411,7 @@ def test_grad_args_type_error2():
super(GradNetWrtX, self).__init__()
self.net = net
self.grad_op = ops.GradOperation(get_by_list=2)
def construct(self, x, y):
gradient_function = self.grad_op(self.net)
return gradient_function(x, y)
@ -416,6 +429,7 @@ def test_grad_args_type_error3():
def __init__(self):
super(Net, self).__init__()
self.matmul = P.MatMul()
def construct(self, x, y):
out = self.matmul(x, y)
return out
@ -425,6 +439,7 @@ def test_grad_args_type_error3():
super(GradNetWrtX, self).__init__()
self.net = net
self.grad_op = ops.GradOperation(sens_param=2)
def construct(self, x, y):
gradient_function = self.grad_op(self.net)
return gradient_function(x, y)
@ -442,6 +457,7 @@ def test_grad_net_is_none():
def __init__(self):
super(Net, self).__init__()
self.add = P.Add()
def construct(self, x, y):
out = self.add(x, y)
return out
@ -451,6 +467,7 @@ def test_grad_net_is_none():
super(GradNetWrtX, self).__init__()
self.net = P.Add()
self.grad_op = ops.GradOperation()
def construct(self, x, y):
gradient_function = self.grad_op(None)
return gradient_function(x, y)
@ -468,6 +485,7 @@ def test_grad_missing_net():
def __init__(self):
super(Net, self).__init__()
self.add = P.Add()
def construct(self, x, y):
out = self.add(x, y)
return out
@ -477,6 +495,7 @@ def test_grad_missing_net():
super(GradNetWrtX, self).__init__()
self.net = net
self.grad_op = ops.GradOperation()
def construct(self, x, y):
gradient_function = self.grad_op()
return gradient_function(x, y)
@ -516,7 +535,7 @@ def test_user_defined_bprop_inputs_size_error():
try:
grad_net(x, y)
except Exception as e:
assert "The function 'bprop' of Primitive or Cell requires at least 2 params 'out' and 'dout', but got only"\
assert "The function 'bprop' of Primitive or Cell requires at least 2 params 'out' and 'dout', but got only" \
in str(e)
@ -591,6 +610,7 @@ def test_grad_hook():
super(Net, self).__init__()
self.add = P.Add()
self.hook = P.HookBackward(var_hook_function)
def construct(self, x, y):
x = self.hook(x)
out = self.add(x, y)
@ -601,6 +621,7 @@ def test_grad_hook():
super(GradNetWrtX, self).__init__()
self.net = net
self.grad_op = ops.GradOperation()
def construct(self, x, y):
gradient_function = self.grad_op(self.net)
return gradient_function(x, y)
@ -612,3 +633,183 @@ def test_grad_hook():
except Exception as e:
assert "The Primitive 'HookBackward' is not supported in graph mode, which is only supported in pynative " \
"mode." in str(e)
def test_custom_cell_bprop_with_parameter():
"""
Feature: Custom cell bprop
Description: Get the gradients of the inputs when the custom cell bprop uses a Parameter.
Expectation: Raise an error
"""
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.matmul = P.MatMul()
self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
def construct(self, x, y):
x = x * self.z
out = self.matmul(x, y)
return out
def bprop(self, x, y, out, dout):
dx = x * self.z
dy = y + y
return dx, dy
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
def construct(self, x, y):
grad_f = grad_all(self.net)
return grad_f(x, y)
x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
try:
GradNet(Net())(x, y)
except Exception as e:
assert "The user defined 'bprop' function in scope" in str(e)
assert "does not support using Parameter" in str(e)
def test_custom_cell_bprop_with_parameter_in_sub_cell():
"""
Feature: Custom cell bprop
Description: Get the gradients of the inputs when the custom cell bprop uses a Parameter in a sub-cell.
Expectation: Raise an error
"""
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.net = Net1()
def construct(self, x, y):
return self.net(x, y)
class Net1(nn.Cell):
def __init__(self):
super(Net1, self).__init__()
self.matmul = P.MatMul()
self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
def construct(self, x, y):
x = x * self.z
out = self.matmul(x, y)
return out
def bprop(self, x, y, out, dout):
dx = x * self.z
dy = y + y
return dx, dy
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
def construct(self, x, y):
grad_f = grad_all(self.net)
return grad_f(x, y)
x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
try:
GradNet(Net())(x, y)
except Exception as e:
assert "The user defined 'bprop' function in scope" in str(e)
assert "does not support using Parameter" in str(e)
def test_pynative_custom_cell_bprop_with_parameter():
"""
Feature: Custom cell bprop
Description: Get the gradients of the inputs when the custom cell bprop uses a Parameter.
Expectation: Raise an error
"""
context.set_context(mode=context.PYNATIVE_MODE)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.matmul = P.MatMul()
self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
def construct(self, x, y):
x = x * self.z
out = self.matmul(x, y)
return out
def bprop(self, x, y, out, dout):
dx = x * self.z
dy = y + y
return dx, dy
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
def construct(self, x, y):
grad_f = grad_all(self.net)
return grad_f(x, y)
x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
try:
GradNet(Net())(x, y)
except Exception as e:
assert "The user defined 'bprop' function does not support using Parameter" in str(e)
def test_pynative_custom_cell_bprop_with_parameter_in_sub_cell():
"""
Feature: Custom cell bprop
Description: Get the gradients of the inputs when the custom cell bprop uses a Parameter in a sub-cell.
Expectation: Raise an error
"""
context.set_context(mode=context.PYNATIVE_MODE)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.net = Net1()
def construct(self, x, y):
return self.net(x, y)
class Net1(nn.Cell):
def __init__(self):
super(Net1, self).__init__()
self.matmul = P.MatMul()
self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
def construct(self, x, y):
x = x * self.z
out = self.matmul(x, y)
return out
def bprop(self, x, y, out, dout):
dx = x * self.z
dy = y + y
return dx, dy
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
def construct(self, x, y):
grad_f = grad_all(self.net)
return grad_f(x, y)
x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
try:
GradNet(Net())(x, y)
except Exception as e:
assert "The user defined 'bprop' function does not support using Parameter" in str(e)

View File

@ -160,7 +160,7 @@ def test_user_define_bprop_check_parameter():
return ret
def bprop(self, x, out, dout):
return dout + x
return dout + self.par
class GradNet(nn.Cell):
def __init__(self, net):
@ -177,7 +177,7 @@ def test_user_define_bprop_check_parameter():
grad_net = GradNet(net)
with pytest.raises(RuntimeError) as ex:
ret = grad_net(x, sens)
assert "When user defines the net bprop, the 'Parameter' data type is not supported in the net." in str(ex.value)
assert "The user defined 'bprop' function does not support using Parameter." in str(ex.value)
def test_user_define_bprop_check_number():