forked from mindspore-Ecosystem/mindspore
!33497 Support custom cell bprop when using Parameter in primal graph
Merge pull request !33497 from YuJianfeng/custom_bprop
commit 406eaf37b1
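In user terms, this change lets a Cell that defines a custom `bprop` read a `Parameter` in its forward `construct` (the primal graph); only reading a `Parameter` inside the `bprop` body itself remains unsupported. Below is a minimal sketch of the newly supported pattern, condensed from the `test_forward_with_parameter` case added later in this diff; the class name, shapes, and values are taken from or modeled on those tests and are illustrative only.

```python
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import context
from mindspore import Tensor
from mindspore.common.parameter import Parameter
from mindspore.ops import composite as C
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE)
grad_all = C.GradOperation(get_all=True)


class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.matmul = P.MatMul()
        # The Parameter is used in the primal graph (construct); this PR makes that legal
        # for a cell that also defines a custom bprop.
        self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

    def construct(self, x, y):
        x = x * self.z
        return self.matmul(x, y)

    def bprop(self, x, y, out, dout):
        # The custom bprop only touches the inputs; reading self.z here would still raise.
        return x + x, y + y


x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
dx, dy = grad_all(Net())(x, y)  # gradients come from the user-defined bprop: dx == 2*x, dy == 2*y
```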
@@ -539,9 +539,10 @@ FuncGraphPtr DFunctor::KUserDefined(const FuncGraphPtr &primal) {
  FuncGraphPtr bprop_graph = bprop->second.func_graph();
  resources_->manager()->AddFuncGraph(bprop_graph);

  if (!bprop_graph->free_variables_nodes().empty() || !primal->free_variables_nodes().empty()) {
    MS_LOG(EXCEPTION) << "The Cell with user defined 'bprop' function in scope " << primal->output()->scope()->name()
                      << " does not support Parameter data type.\n"
  (void)parse::ResolveFuncGraph(bprop_graph, resources_);
  if (!bprop_graph->free_variables_nodes().empty()) {
    MS_LOG(EXCEPTION) << "The user defined 'bprop' function in scope " << primal->output()->scope()->name()
                      << " does not support using Parameter.\n"
                      << trace::GetDebugInfo(bprop_graph->debug_info());
  }
  bprop_graph->set_flag(mindspore::kFuncGraphFlagBackPropEntry, true);
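The hunk above narrows the graph-mode check: free variables of the primal graph (such as a Parameter read in `construct`) no longer raise, and only free variables captured by the bprop graph itself still do. A minimal sketch of the case that is still rejected, condensed from the `test_custom_cell_bprop_with_parameter` case added later in this diff (the class name is illustrative):

```python
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.common.parameter import Parameter
from mindspore.ops import composite as C
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE)
grad_all = C.GradOperation(get_all=True)


class BadBpropNet(nn.Cell):
    def __init__(self):
        super(BadBpropNet, self).__init__()
        self.matmul = P.MatMul()
        self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

    def construct(self, x, y):
        return self.matmul(x * self.z, y)  # Parameter in the primal graph: now accepted

    def bprop(self, x, y, out, dout):
        return x * self.z, y + y  # Parameter read inside bprop: still triggers the exception above

# grad_all(BadBpropNet())(x, y) would raise:
# "The user defined 'bprop' function in scope ... does not support using Parameter."
```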
@@ -53,8 +53,9 @@ KPrim g_k_prims;

namespace {
constexpr char kBpropMindIRSuffix[] = "_bprop.mindir";
constexpr char kBpropMindIRDir[] = "/../bprop_mindir/";
constexpr char serializable_bprop_ops[] = "serializable_bprop_ops";
constexpr char bprop_mindir_module[] = "mindspore.ops.bprop_mindir";
constexpr char kSerializableBpropOps[] = "serializable_bprop_ops";
constexpr char kBpropMindirModule[] = "mindspore.ops.bprop_mindir";
constexpr char kLiftedUserDataKey[] = "lifted_from_fv";

#ifndef _WIN32
std::string GetBpropDir() {

@@ -86,8 +87,8 @@ mindspore::HashSet<std::string> GetSerializableBpropList() {
  if (!BpropMindirDirExists()) {
    return serializable_bprop_list;
  }
  py::module mod = py::module::import(bprop_mindir_module);
  py::object serializable_bprop_ops_attr = mod.attr(serializable_bprop_ops);
  py::module mod = py::module::import(kBpropMindirModule);
  py::object serializable_bprop_ops_attr = mod.attr(kSerializableBpropOps);
  if (!py::isinstance<py::list>(serializable_bprop_ops_attr)) {
    MS_LOG(WARNING) << "Can not get the the serializable bprop ops list from python, it is not a python list.";
    return serializable_bprop_list;

@@ -614,8 +615,9 @@ FuncGraphPtr KPrim::KPrimitive(const CNodePtr &cnode, const ValueNodePtr &value_
}

AnfNodePtr KPrim::BuildOutput(const FuncGraphPtr &bprop_fg, const FuncGraphPtr &current_primal_fg) {
  // current_primal_fg may have extra parameters like u_monad, io_monad
  std::vector<AnfNodePtr> extra_args;
  // The primal fg may have extra parameters from lifted fv or u_monad and io_monad.
  std::vector<AnfNodePtr> extra_lifted_args;
  std::vector<AnfNodePtr> extra_monad_args;
  // caller had checked size() - 2 is greater than 0.
  auto bprop_fg_param_size = bprop_fg->parameters().size() - 2;
  if (current_primal_fg != nullptr && bprop_fg_param_size < current_primal_fg->parameters().size()) {

@@ -623,6 +625,18 @@ AnfNodePtr KPrim::BuildOutput(const FuncGraphPtr &bprop_fg, const FuncGraphPtr &
    MS_LOG(DEBUG) << "Current Primal FuncGraph may have extra parameters(U or IO monad) which bprop don't define, so "
                     "Insert it. Extra parameters size: "
                  << current_primal_fg_param_size - bprop_fg_param_size;
    // The lifted parameters are put in front: {lifted parameters, origin parameters, u/io monad}.
    for (size_t i = 0; i < current_primal_fg_param_size; ++i) {
      auto primal_parameter = dyn_cast<Parameter>(current_primal_fg->parameters()[i]);
      MS_EXCEPTION_IF_NULL(primal_parameter);
      auto lifted = primal_parameter->user_data<bool>(kLiftedUserDataKey);
      if (lifted == nullptr || !*lifted) {
        break;
      }
      extra_lifted_args.push_back(
        bprop_fg->NewCNodeInOrder({NewValueNode(prim::GetPythonOps("zeros_like")), primal_parameter}));
      ++bprop_fg_param_size;
    }
    for (auto i = bprop_fg_param_size; i < current_primal_fg_param_size; ++i) {
      const auto &primal_node = current_primal_fg->parameters()[i];
      AnfNodePtr extra_node;

@@ -639,7 +653,7 @@ AnfNodePtr KPrim::BuildOutput(const FuncGraphPtr &bprop_fg, const FuncGraphPtr &
                             "as the 'out' and 'dout'.\n"
                          << trace::GetDebugInfo(bprop_fg->debug_info());
      }
      extra_args.push_back(extra_node);
      extra_monad_args.push_back(extra_node);
      MS_LOG(DEBUG) << "Insert to bprop_fg for node: " << primal_node->DebugString();
    }
  }

@@ -652,9 +666,13 @@ AnfNodePtr KPrim::BuildOutput(const FuncGraphPtr &bprop_fg, const FuncGraphPtr &
    std::vector<AnfNodePtr> args;
    args.push_back(NewValueNode(prim::kPrimMakeTuple));
    args.push_back(NewEnviron(bprop_fg));
    // The lifted parameters are put in front.
    if (!extra_lifted_args.empty()) {
      (void)args.insert(args.end(), extra_lifted_args.cbegin(), extra_lifted_args.cend());
    }
    (void)args.insert(args.end(), inputs.begin() + 1, inputs.end());
    if (!extra_args.empty()) {
      args.insert(args.end(), extra_args.cbegin(), extra_args.cend());
    if (!extra_monad_args.empty()) {
      (void)args.insert(args.end(), extra_monad_args.cbegin(), extra_monad_args.cend());
    }
    return NewCNode(args, bprop_fg);
  }

@@ -664,9 +682,14 @@ AnfNodePtr KPrim::BuildOutput(const FuncGraphPtr &bprop_fg, const FuncGraphPtr &
  constexpr char python_ops[] = "_tuple_add";
  auto tuple_env = NewCNode({NewValueNode(prim::kPrimMakeTuple), NewEnviron(bprop_fg)}, bprop_fg);
  auto tuple_add_ops = NewValueNode(prim::GetPythonOps(python_ops, model_name));
  if (!extra_args.empty()) {
    extra_args.insert(extra_args.begin(), NewValueNode(prim::kPrimMakeTuple));
    auto extra_tuple = NewCNode(extra_args, bprop_fg);
  if (!extra_lifted_args.empty()) {
    (void)extra_lifted_args.insert(extra_lifted_args.begin(), NewValueNode(prim::kPrimMakeTuple));
    auto extra_tuple = NewCNode(extra_lifted_args, bprop_fg);
    tuple_env = NewCNode({tuple_add_ops, tuple_env, extra_tuple}, bprop_fg);
  }
  if (!extra_monad_args.empty()) {
    (void)extra_monad_args.insert(extra_monad_args.begin(), NewValueNode(prim::kPrimMakeTuple));
    auto extra_tuple = NewCNode(extra_monad_args, bprop_fg);
    auto old_output_extra = NewCNode({tuple_add_ops, bprop_fg->output(), extra_tuple}, bprop_fg);
    return NewCNode({tuple_add_ops, tuple_env, old_output_extra}, bprop_fg);
  }

@@ -676,6 +699,7 @@ AnfNodePtr KPrim::BuildOutput(const FuncGraphPtr &bprop_fg, const FuncGraphPtr &

static void TransformNormalArgs(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg, const FuncGraphPtr &outer,
                                std::vector<AnfNodePtr> *transf_args) {
  MS_EXCEPTION_IF_NULL(mng);
  // bprop_fg has been checked in caller
  // transform except the last 2 parameters: out, dout.
  const size_t last_parameter_sizes = 2;

@@ -694,7 +718,6 @@ static void TransformNormalArgs(const FuncGraphManagerPtr &mng, const FuncGraphP
void KPrim::TransformArgsForPrimitive(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg,
                                      const PrimitivePtr &primitive, const FuncGraphPtr &outer,
                                      std::vector<AnfNodePtr> *transf_args) {
  MS_EXCEPTION_IF_NULL(mng);
  TransformNormalArgs(mng, bprop_fg, outer, transf_args);
  // Fprop_fg for Primitive with side effect should append extra U or IO monad parameter.
  auto effect_info = GetPrimEffectInfo(primitive);

@@ -714,12 +737,24 @@ template <typename T>
void KPrim::TransformArgsForFuncGraph(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg,
                                      const T &current_primal_fg, const FuncGraphPtr &outer,
                                      std::vector<AnfNodePtr> *transf_args) {
  MS_EXCEPTION_IF_NULL(mng);
  TransformNormalArgs(mng, bprop_fg, outer, transf_args);
  constexpr size_t need_filter_size = 2;
  auto bprop_fg_param_size = bprop_fg->parameters().size() - need_filter_size;
  // current_primal_fg may have extra parameters after AutoMonad
  const auto &current_primal_fg_params = current_primal_fg->parameters();
  // The lifted parameters are put in front: {lifted parameters, origin parameters, u/io monad}.
  for (size_t i = 0; i < current_primal_fg_params.size(); ++i) {
    auto primal_parameter = dyn_cast<Parameter>(current_primal_fg_params[i]);
    MS_EXCEPTION_IF_NULL(primal_parameter);
    auto lifted = primal_parameter->template user_data<bool>(kLiftedUserDataKey);
    if (lifted == nullptr || !*lifted) {
      break;
    }
    TraceGuard trace_guard(std::make_shared<TraceGradFprop>(primal_parameter->debug_info()));
    auto transf_p = outer->add_parameter();
    transf_args->push_back(transf_p);
    ++bprop_fg_param_size;
  }
  TransformNormalArgs(mng, bprop_fg, outer, transf_args);
  // Current primal fg may have extra parameters after AutoMonad
  if (bprop_fg_param_size < current_primal_fg_params.size()) {
    for (auto i = bprop_fg_param_size; i < current_primal_fg_params.size(); ++i) {
      auto p = current_primal_fg_params[i];
@@ -2703,11 +2703,20 @@ void GradExecutor::DoGradForCustomBprop(const py::object &cell, const py::object
    return;
  }
  MS_LOG(DEBUG) << "Do grad for custom bprop";
  size_t par_number = py::tuple(python_adapter::CallPyObjMethod(cell, "get_parameters")).size();
  if (par_number > 0) {
    MS_LOG(EXCEPTION) << "When user defines the net bprop, the 'Parameter' data type is not supported in the net.";
  }
  py::function bprop_func = py::getattr(cell, parse::CUSTOM_BPROP_NAME);
  py::object code_obj = py::getattr(bprop_func, "__code__");
  // When the co_names is empty, we will still get a tuple which is empty.
  auto co_names = py::getattr(code_obj, "co_names").cast<py::tuple>();
  for (auto name : co_names) {
    if (!py::hasattr(cell, name)) {
      continue;
    }
    auto var = py::getattr(cell, name);
    if (py::hasattr(var, "__parameter__") && py::isinstance<tensor::MetaTensor>(var)) {
      MS_LOG(EXCEPTION) << "The user defined 'bprop' function does not support using Parameter.";
    }
  }

  auto bprop_func_cellid = GetId(bprop_func);
  bprop_cell_list_.emplace_back(bprop_func_cellid);
  auto fake_prim = std::make_shared<PrimitivePy>(prim::kPrimHookBackward->name());

@@ -2721,7 +2730,6 @@ void GradExecutor::DoGradForCustomBprop(const py::object &cell, const py::object
  (void)fake_prim->AddAttr("cell_id", MakeValue(cell_id));
  (void)fake_prim->AddAttr(parse::CUSTOM_BPROP_NAME, MakeValue(true));

  py::object code_obj = py::getattr(bprop_func, "__code__");
  py::object co_name = py::getattr(code_obj, "co_name");
  if (std::string(py::str(co_name)) == "staging_specialize") {
    MS_LOG(EXCEPTION) << "Decorating bprop with '@ms_function' is not supported.";
@@ -300,7 +300,9 @@ void Cloner::GenParameters(const FuncGraphPtr &func_graph) {
    }
    MS_LOG(DEBUG) << "Gen param: " << free_var_node->ToString() << " for func_graph: " << func_graph->ToString();
    auto &fg_params = repl_func_graph_params_[func_graph];
    (void)fg_params.emplace_back(AddParameter(func_graph, utils::cast<AnfNodePtr>(free_var)));
    auto fv_parameter = AddParameter(func_graph, utils::cast<AnfNodePtr>(free_var));
    fv_parameter->set_user_data<bool>("lifted_from_fv", std::make_shared<bool>(true));
    (void)fg_params.emplace_back(fv_parameter);
  }
}

@@ -313,6 +315,11 @@ void Cloner::CloneParameter(const ParameterPtr &param, const AnfNodePtr &node) {
      param->set_default_param(old_param->default_param());
    }
    param->set_name(old_param->name());
    constexpr char lifted_user_data_key[] = "lifted_from_fv";
    auto lifted = param->user_data<bool>(lifted_user_data_key);
    if (lifted != nullptr && *lifted) {
      param->set_user_data<bool>(lifted_user_data_key, std::make_shared<bool>(true));
    }
  }
}
@@ -26,9 +26,6 @@ from mindspore.common.tensor import Tensor
from mindspore.ops import composite as C
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")


grad_all = C.GradOperation(get_all=True)


@@ -40,6 +37,7 @@ class MulAdd(nn.Cell):
        # In this test case, The user defined bprop is wrong defined purposely to distinguish from ad result
        return 2 * dout, 2 * y


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard

@@ -59,6 +57,7 @@ class InlineMulADD(nn.Cell):
    def construct(self, x, y):
        return self.mul_add(x, y) + x + self.param * y


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard

@@ -82,6 +81,7 @@ class WithParameter(nn.Cell):
        # In this test case, The user defined bprop is wrong defined purposely to distinguish from ad result
        return self.param1 * self.param2 * dout, 2 * y


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard

@@ -95,6 +95,7 @@ class WithNoBprop(nn.Cell):
    def construct(self, x, y):
        return 2 * x + y


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard

@@ -104,6 +105,7 @@ def test_with_no_bprop():
    y = Tensor(2, dtype=ms.int32)
    assert grad_all(with_no_bprop)(x, y) == (2, 1)


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard

@@ -142,6 +144,7 @@ def test_grad_in_bprop_1():
    assert (grads[0].asnumpy() == np.ones([2, 2]).astype(np.float32)).all()
    assert (grads[1].asnumpy() == np.zeros([2, 2]).astype(np.float32)).all()


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard

@@ -183,6 +186,7 @@ def test_grad_in_bprop_2():
    assert (grads[0].asnumpy() == np.ones([2, 2]).astype(np.float32)).all()
    assert (grads[1].asnumpy() == np.array([[2, 2], [2, 2]]).astype(np.float32)).all()


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard

@@ -236,6 +240,7 @@ class OneInputBprop(nn.Cell):
    def bprop(self, x, out, dout):
        return (5 * x,)


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard

@@ -263,6 +268,7 @@ class InlineBpropTwoInput(nn.Cell):
        grads = grad_all(self.f)(x, y)
        return grads[0] * 2, grads[1] * 2


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard

@@ -324,6 +330,7 @@ class InlineMutilTwoInputParameterCell(nn.Cell):
        output = self.f1(x, y) + self.f2(x, y) + self.f3(x, y) + self.f4(x, y)
        return output


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard

@@ -347,20 +354,24 @@ class MulAddWithParam(nn.Cell):
    def construct(self, x):
        return self.mul_add(self.param, x)


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_refkey_bprop():
    grad_by_list = C.GradOperation(get_all=True, get_by_list=True)

    class GradWrap(nn.Cell):
        def __init__(self, network):
            super(GradWrap, self).__init__()
            self.network = network
            self.weights = ParameterTuple(filter(lambda x: x.requires_grad, network.get_parameters()))

        def construct(self, x):
            weights = self.weights
            grads = grad_by_list(self.network, weights)(x)
            return grads

    network = GradWrap(MulAddWithParam())
    input_data = Tensor(np.array([2, 2], np.float32))
    grads = network(input_data)

@@ -375,6 +386,7 @@ class MulAddWithWrongOutputNum(nn.Cell):
    def bprop(self, x, y, out, dout):
        return (2 * dout,)


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard

@@ -392,6 +404,7 @@ class MulAddWithWrongOutputType(nn.Cell):
    def bprop(self, x, y, out, dout):
        return 2 * dout, 2


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard

@@ -413,6 +426,7 @@ class MulAddWithWrongOutputShape(nn.Cell):
    def bprop(self, x, y, out, dout):
        return 2, self.ones


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard

@@ -421,3 +435,328 @@ def test_grad_mul_add_with_wrong_output_shape():
    mul_add = MulAddWithWrongOutputShape()
    with pytest.raises(TypeError):
        grad_all(mul_add)(1, Tensor(np.ones([2, 2])))


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_forward_with_parameter():
    """
    Feature: Custom cell bprop
    Description: Get the gradients of inputs when the forward net using Parameter.
    Expectation: Get the correct gradients.
    """

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.matmul = P.MatMul()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

        def construct(self, x, y):
            x = x * self.z
            out = self.matmul(x, y)
            return out

        def bprop(self, x, y, out, dout):
            dx = x + x
            dy = y + y
            return dx, dy

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net

        def construct(self, x, y):
            grad_f = grad_all(self.net)
            return grad_f(x, y)

    x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
    y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
    out = GradNet(Net())(x, y)
    expect_dx = np.array([[1.0, 1.2, 0.8],
                          [2.4, 2.6, 2.2]]).astype(np.float32)
    expect_dy = np.array([[0.02, 0.6, 2.2],
                          [0.2, 0.4, 2.6],
                          [4.2, 2.4, 6.6]]).astype(np.float32)
    assert np.allclose(out[0].asnumpy(), expect_dx)
    assert np.allclose(out[1].asnumpy(), expect_dy)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_forward_with_parameter_in_sub_cell():
    """
    Feature: Custom cell bprop
    Description: Get the gradients of inputs when the forward net using Parameter in the sub-cell.
    Expectation: Get the correct gradients.
    """

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.net = Net1()

        def construct(self, x, y):
            return self.net(x, y)

    class Net1(nn.Cell):
        def __init__(self):
            super(Net1, self).__init__()
            self.matmul = P.MatMul()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

        def construct(self, x, y):
            x = x * self.z
            out = self.matmul(x, y)
            return out

        def bprop(self, x, y, out, dout):
            dx = x + x
            dy = y + y
            return dx, dy

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net

        def construct(self, x, y):
            grad_f = grad_all(self.net)
            return grad_f(x, y)

    x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
    y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
    out = GradNet(Net())(x, y)
    expect_dx = np.array([[1.0, 1.2, 0.8],
                          [2.4, 2.6, 2.2]]).astype(np.float32)
    expect_dy = np.array([[0.02, 0.6, 2.2],
                          [0.2, 0.4, 2.6],
                          [4.2, 2.4, 6.6]]).astype(np.float32)
    assert np.allclose(out[0].asnumpy(), expect_dx)
    assert np.allclose(out[1].asnumpy(), expect_dy)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_forward_with_parameter_in_sub_cell_get_by_list():
    """
    Feature: Custom cell bprop
    Description: Get the gradients of inputs and Parameters when the forward net using Parameter in the sub-cell.
    Expectation: Get the correct gradients.
    """

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.net = Net1()

        def construct(self, x, y):
            return self.net(x, y)

    class Net1(nn.Cell):
        def __init__(self):
            super(Net1, self).__init__()
            self.matmul = P.MatMul()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

        def construct(self, x, y):
            x = x * self.z
            out = self.matmul(x, y)
            return out

        def bprop(self, x, y, out, dout):
            dx = x + x
            dy = y + y
            return dx, dy

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.params = ParameterTuple(net.trainable_params())
            self.grad_op = C.GradOperation(get_by_list=True, get_all=True)

        def construct(self, x, y):
            grad_f = self.grad_op(self.net, self.params)
            return grad_f(x, y)

    x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
    y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
    out = GradNet(Net())(x, y)
    expect_dx = np.array([[1.0, 1.2, 0.8],
                          [2.4, 2.6, 2.2]]).astype(np.float32)
    expect_dy = np.array([[0.02, 0.6, 2.2],
                          [0.2, 0.4, 2.6],
                          [4.2, 2.4, 6.6]]).astype(np.float32)
    expect_dz = np.array([0.0]).astype(np.float32)
    assert np.allclose(out[0][0].asnumpy(), expect_dx)
    assert np.allclose(out[0][1].asnumpy(), expect_dy)
    assert np.allclose(out[1][0].asnumpy(), expect_dz)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_pynative_forward_with_parameter():
    """
    Feature: Custom cell bprop
    Description: Get the gradients of inputs when the forward net using Parameter.
    Expectation: Get the correct gradients.
    """
    context.set_context(mode=context.PYNATIVE_MODE)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.matmul = P.MatMul()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

        def construct(self, x, y):
            x = x * self.z
            out = self.matmul(x, y)
            return out

        def bprop(self, x, y, out, dout):
            dx = x + x
            dy = y + y
            return dx, dy

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net

        def construct(self, x, y):
            grad_f = grad_all(self.net)
            return grad_f(x, y)

    x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
    y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
    out = GradNet(Net())(x, y)
    expect_dx = np.array([[1.0, 1.2, 0.8],
                          [2.4, 2.6, 2.2]]).astype(np.float32)
    expect_dy = np.array([[0.02, 0.6, 2.2],
                          [0.2, 0.4, 2.6],
                          [4.2, 2.4, 6.6]]).astype(np.float32)
    assert np.allclose(out[0].asnumpy(), expect_dx)
    assert np.allclose(out[1].asnumpy(), expect_dy)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_pynative_forward_with_parameter_in_sub_cell():
    """
    Feature: Custom cell bprop
    Description: Get the gradients of inputs when the forward net using Parameter in the sub-cell.
    Expectation: Get the correct gradients.
    """
    context.set_context(mode=context.PYNATIVE_MODE)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.net = Net1()

        def construct(self, x, y):
            return self.net(x, y)

    class Net1(nn.Cell):
        def __init__(self):
            super(Net1, self).__init__()
            self.matmul = P.MatMul()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

        def construct(self, x, y):
            x = x * self.z
            out = self.matmul(x, y)
            return out

        def bprop(self, x, y, out, dout):
            dx = x + x
            dy = y + y
            return dx, dy

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net

        def construct(self, x, y):
            grad_f = grad_all(self.net)
            return grad_f(x, y)

    x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
    y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
    out = GradNet(Net())(x, y)
    expect_dx = np.array([[1.0, 1.2, 0.8],
                          [2.4, 2.6, 2.2]]).astype(np.float32)
    expect_dy = np.array([[0.02, 0.6, 2.2],
                          [0.2, 0.4, 2.6],
                          [4.2, 2.4, 6.6]]).astype(np.float32)
    assert np.allclose(out[0].asnumpy(), expect_dx)
    assert np.allclose(out[1].asnumpy(), expect_dy)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_pynative_forward_with_parameter_in_sub_cell_get_by_list():
    """
    Feature: Custom cell bprop
    Description: Get the gradients of inputs and Parameters when the forward net using Parameter in the sub-cell.
    Expectation: Get the correct gradients.
    """
    context.set_context(mode=context.PYNATIVE_MODE)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.net = Net1()

        def construct(self, x, y):
            return self.net(x, y)

    class Net1(nn.Cell):
        def __init__(self):
            super(Net1, self).__init__()
            self.matmul = P.MatMul()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

        def construct(self, x, y):
            x = x * self.z
            out = self.matmul(x, y)
            return out

        def bprop(self, x, y, out, dout):
            dx = x + x
            dy = y + y
            return dx, dy

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.params = ParameterTuple(net.trainable_params())
            self.grad_op = C.GradOperation(get_by_list=True, get_all=True)

        def construct(self, x, y):
            grad_f = self.grad_op(self.net, self.params)
            return grad_f(x, y)

    x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
    y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
    out = GradNet(Net())(x, y)
    expect_dx = np.array([[1.0, 1.2, 0.8],
                          [2.4, 2.6, 2.2]]).astype(np.float32)
    expect_dy = np.array([[0.02, 0.6, 2.2],
                          [0.2, 0.4, 2.6],
                          [4.2, 2.4, 6.6]]).astype(np.float32)
    expect_dz = np.array([0.0]).astype(np.float32)
    assert np.allclose(out[0][0].asnumpy(), expect_dx)
    assert np.allclose(out[0][1].asnumpy(), expect_dy)
    assert np.allclose(out[1][0].asnumpy(), expect_dz)
@@ -16,6 +16,7 @@ import numpy as np

import mindspore.nn as nn
import mindspore.ops as ops
import mindspore.common.dtype as mstype
from mindspore import context
from mindspore import Tensor
from mindspore.ops import operations as P

@@ -25,6 +26,7 @@ from mindspore.common.parameter import Parameter, ParameterTuple
grad_all = C.GradOperation(get_all=True)
grad_by_list = C.GradOperation(get_by_list=True)


class CropAndResizeNet(nn.Cell):
    def __init__(self, crop_size):
        super(CropAndResizeNet, self).__init__()

@@ -102,6 +104,7 @@ class BPropOperatatorNet(nn.Cell):
        x = self.bn(x)
        return x


def test_user_defined_bprop_with_u():
    net = BPropOperatatorNet(mul_size=(128, 96))
    grad_net = TestUserDefinedBpropGradNet(net)

@@ -153,6 +156,7 @@ def test_second_grad_with_j_primitive():
# A CNode being used as FV is MapMorphism after MapMorphism of call-site CNode;
def test_ad_fv_cnode_order():
    context.set_context(mode=context.GRAPH_MODE)

    class Net(nn.Cell):
        # cnode xay is not being MapMorphism when cnode second_level() is being MapMorphism and
        # BackPropagateFv as MapMorphism is started from output node and from left to right order.

@@ -164,6 +168,7 @@ def test_ad_fv_cnode_order():
                    return xay

                return second_level() + xay

            return first_level()

    input_x = Tensor(np.array([1.0], dtype=np.float32))

@@ -178,6 +183,7 @@ def test_ad_fv_cnode_order():
# True and False branch of switch have different number of parameters.
def test_if_branch_with_different_params():
    context.set_context(mode=context.GRAPH_MODE)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

@@ -215,6 +221,7 @@ def test_if_branch_with_different_params():
# because weight1 in Net may use old_parameter other than replicated one.
def test_limit_lift_fv_scope():
    context.set_context(mode=context.GRAPH_MODE)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

@@ -276,6 +283,7 @@ def test_same_primal_used_by_multi_j():

def test_same_primal_used_by_multi_j_with_monad1():
    context.set_context(mode=context.GRAPH_MODE)

    class AdamNet(nn.Cell):
        def __init__(self, var, m, v):
            super(AdamNet, self).__init__()

@@ -318,6 +326,7 @@ def test_same_primal_used_by_multi_j_with_monad1():

def test_same_primal_used_by_multi_j_with_monad2():
    context.set_context(mode=context.GRAPH_MODE)

    class AdamNet(nn.Cell):
        def __init__(self, var, m, v):
            super(AdamNet, self).__init__()

@@ -364,6 +373,7 @@ def test_grad_args_type_error1():
        def __init__(self):
            super(Net, self).__init__()
            self.matmul = P.MatMul()

        def construct(self, x, y):
            out = self.matmul(x, y)
            return out

@@ -373,6 +383,7 @@ def test_grad_args_type_error1():
            super(GradNetWrtX, self).__init__()
            self.net = net
            self.grad_op = ops.GradOperation(get_all=2)

        def construct(self, x, y):
            gradient_function = self.grad_op(self.net)
            return gradient_function(x, y)

@@ -390,6 +401,7 @@ def test_grad_args_type_error2():
        def __init__(self):
            super(Net, self).__init__()
            self.matmul = P.MatMul()

        def construct(self, x, y):
            out = self.matmul(x, y)
            return out

@@ -399,6 +411,7 @@ def test_grad_args_type_error2():
            super(GradNetWrtX, self).__init__()
            self.net = net
            self.grad_op = ops.GradOperation(get_by_list=2)

        def construct(self, x, y):
            gradient_function = self.grad_op(self.net)
            return gradient_function(x, y)

@@ -416,6 +429,7 @@ def test_grad_args_type_error3():
        def __init__(self):
            super(Net, self).__init__()
            self.matmul = P.MatMul()

        def construct(self, x, y):
            out = self.matmul(x, y)
            return out

@@ -425,6 +439,7 @@ def test_grad_args_type_error3():
            super(GradNetWrtX, self).__init__()
            self.net = net
            self.grad_op = ops.GradOperation(sens_param=2)

        def construct(self, x, y):
            gradient_function = self.grad_op(self.net)
            return gradient_function(x, y)

@@ -442,6 +457,7 @@ def test_grad_net_is_none():
        def __init__(self):
            super(Net, self).__init__()
            self.add = P.Add()

        def construct(self, x, y):
            out = self.add(x, y)
            return out

@@ -451,6 +467,7 @@ def test_grad_net_is_none():
            super(GradNetWrtX, self).__init__()
            self.net = P.Add()
            self.grad_op = ops.GradOperation()

        def construct(self, x, y):
            gradient_function = self.grad_op(None)
            return gradient_function(x, y)

@@ -468,6 +485,7 @@ def test_grad_missing_net():
        def __init__(self):
            super(Net, self).__init__()
            self.add = P.Add()

        def construct(self, x, y):
            out = self.add(x, y)
            return out

@@ -477,6 +495,7 @@ def test_grad_missing_net():
            super(GradNetWrtX, self).__init__()
            self.net = net
            self.grad_op = ops.GradOperation()

        def construct(self, x, y):
            gradient_function = self.grad_op()
            return gradient_function(x, y)

@@ -516,7 +535,7 @@ def test_user_defined_bprop_inputs_size_error():
    try:
        grad_net(x, y)
    except Exception as e:
        assert "The function 'bprop' of Primitive or Cell requires at least 2 params 'out' and 'dout', but got only"\
        assert "The function 'bprop' of Primitive or Cell requires at least 2 params 'out' and 'dout', but got only" \
               in str(e)


@@ -591,6 +610,7 @@ def test_grad_hook():
            super(Net, self).__init__()
            self.add = P.Add()
            self.hook = P.HookBackward(var_hook_function)

        def construct(self, x, y):
            x = self.hook(x)
            out = self.add(x, y)

@@ -601,6 +621,7 @@ def test_grad_hook():
            super(GradNetWrtX, self).__init__()
            self.net = net
            self.grad_op = ops.GradOperation()

        def construct(self, x, y):
            gradient_function = self.grad_op(self.net)
            return gradient_function(x, y)

@@ -612,3 +633,183 @@ def test_grad_hook():
    except Exception as e:
        assert "The Primitive 'HookBackward' is not supported in graph mode, which is only supported in pynative " \
               "mode." in str(e)


def test_custom_cell_bprop_with_parameter():
    """
    Feature: Custom cell bprop
    Description: Get the gradients of inputs when the custom cell bprop use Parameter.
    Expectation: Raise an error
    """

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.matmul = P.MatMul()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

        def construct(self, x, y):
            x = x * self.z
            out = self.matmul(x, y)
            return out

        def bprop(self, x, y, out, dout):
            dx = x * self.z
            dy = y + y
            return dx, dy

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net

        def construct(self, x, y):
            grad_f = grad_all(self.net)
            return grad_f(x, y)

    x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
    y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
    try:
        GradNet(Net())(x, y)
    except Exception as e:
        assert "The user defined 'bprop' function in scope" in str(e)
        assert "does not support using Parameter" in str(e)


def test_custom_cell_bprop_with_parameter_in_sub_cell():
    """
    Feature: Custom cell bprop
    Description: Get the gradients of inputs when the custom cell bprop use Parameter in sub-cell.
    Expectation: Raise an error
    """

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.net = Net1()

        def construct(self, x, y):
            return self.net(x, y)

    class Net1(nn.Cell):
        def __init__(self):
            super(Net1, self).__init__()
            self.matmul = P.MatMul()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

        def construct(self, x, y):
            x = x * self.z
            out = self.matmul(x, y)
            return out

        def bprop(self, x, y, out, dout):
            dx = x * self.z
            dy = y + y
            return dx, dy

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net

        def construct(self, x, y):
            grad_f = grad_all(self.net)
            return grad_f(x, y)

    x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
    y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
    try:
        GradNet(Net())(x, y)
    except Exception as e:
        assert "The user defined 'bprop' function in scope" in str(e)
        assert "does not support using Parameter" in str(e)


def test_pynative_custom_cell_bprop_with_parameter():
    """
    Feature: Custom cell bprop
    Description: Get the gradients of inputs when the custom cell bprop use Parameter.
    Expectation: Raise an error
    """
    context.set_context(mode=context.PYNATIVE_MODE)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.matmul = P.MatMul()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

        def construct(self, x, y):
            x = x * self.z
            out = self.matmul(x, y)
            return out

        def bprop(self, x, y, out, dout):
            dx = x * self.z
            dy = y + y
            return dx, dy

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net

        def construct(self, x, y):
            grad_f = grad_all(self.net)
            return grad_f(x, y)

    x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
    y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
    try:
        GradNet(Net())(x, y)
    except Exception as e:
        assert "The user defined 'bprop' function does not support using Parameter" in str(e)


def test_pynative_custom_cell_bprop_with_parameter_in_sub_cell():
    """
    Feature: Custom cell bprop
    Description: Get the gradients of inputs when the custom cell bprop use Parameter in sub-cell.
    Expectation: Raise an error
    """
    context.set_context(mode=context.PYNATIVE_MODE)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.net = Net1()

        def construct(self, x, y):
            return self.net(x, y)

    class Net1(nn.Cell):
        def __init__(self):
            super(Net1, self).__init__()
            self.matmul = P.MatMul()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

        def construct(self, x, y):
            x = x * self.z
            out = self.matmul(x, y)
            return out

        def bprop(self, x, y, out, dout):
            dx = x * self.z
            dy = y + y
            return dx, dy

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net

        def construct(self, x, y):
            grad_f = grad_all(self.net)
            return grad_f(x, y)

    x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
    y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
    try:
        GradNet(Net())(x, y)
    except Exception as e:
        assert "The user defined 'bprop' function does not support using Parameter" in str(e)
@@ -160,7 +160,7 @@ def test_user_define_bprop_check_parameter():
            return ret

        def bprop(self, x, out, dout):
            return dout + x
            return dout + self.par

    class GradNet(nn.Cell):
        def __init__(self, net):

@@ -177,7 +177,7 @@ def test_user_define_bprop_check_parameter():
    grad_net = GradNet(net)
    with pytest.raises(RuntimeError) as ex:
        ret = grad_net(x, sens)
    assert "When user defines the net bprop, the 'Parameter' data type is not supported in the net." in str(ex.value)
    assert "The user defined 'bprop' function does not support using Parameter." in str(ex.value)


def test_user_define_bprop_check_number():