forked from mindspore-Ecosystem/mindspore
!4689 add check for user define bprop in Pynative mode
Merge pull request !4689 from zhangbuxue/add_check_for_user_define_bprop_in_Pynative_mode
commit 24700afa2d
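For orientation, a minimal sketch of the pattern this commit validates, distilled from the tests added at the end of this diff (shapes and values here are illustrative, not from the commit): a Cell that defines its own bprop in PyNative mode, with the new check enabled via check_bprop.

import numpy as np
from mindspore import Tensor, nn, context
from mindspore.ops import composite as C

context.set_context(mode=context.PYNATIVE_MODE, check_bprop=True)

class Net(nn.Cell):
    def construct(self, x):
        return x * 2

    def bprop(self, x, out, dout):
        # One gradient per construct() input; with check_bprop=True each must
        # match its input's shape and dtype, or the new check raises ValueError.
        return (dout * 2,)

class GradNet(nn.Cell):
    def __init__(self, net):
        super(GradNet, self).__init__()
        self.net = net

    def construct(self, x, sens):
        return C.grad_all_with_sens(self.net)(x, sens)

x = Tensor(np.ones((2, 3), np.float32))
sens = Tensor(np.ones((2, 3), np.float32))
grads = GradNet(Net())(x, sens)  # passes the new check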
@@ -143,7 +143,7 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve
     }
   }
-  if (max_type_id == kNumberTypeUInt8 && has_int8 == true) {
+  if (max_type_id == kNumberTypeUInt8 && has_int8) {
     max_type_id = kNumberTypeInt16;
   }
   // if bool is the max type, see if there is scalar input
@@ -446,7 +446,7 @@ bool ExecutorPy::CompileInner(const py::object &obj, const py::tuple &args, cons
   GetGeBackendPolicy();
 #endif
   ExecutorInfoPtr executor_info = std::make_shared<ExecutorInfo>();
-  std::string phase_s = py::cast<std::string>(phase);
+  auto phase_s = py::cast<std::string>(phase);
   MS_LOG(INFO) << "ExecutorPy compile phase:" << phase_s << "!";
   ResourcePtr resource = std::make_shared<Resource>(obj);
@@ -541,6 +541,9 @@ bool ExecutorPy::Compile(const py::object &obj, const py::tuple &args, const py:
   } catch (const py::index_error &ex) {
     ReleaseResource(phase);
     throw py::index_error(ex);
+  } catch (const py::key_error &ex) {
+    ReleaseResource(phase);
+    throw py::key_error(ex);
   } catch (const py::attribute_error &ex) {
     ReleaseResource(phase);
     throw py::attribute_error(ex);
@@ -175,6 +175,7 @@ std::map<SignatureEnumDType, TypeId> GetDstType(const py::tuple &py_args,
   TypeId max_type = TypeId::kTypeUnknown;
   bool has_float = false;
   bool has_int = false;
+  bool has_int8 = false;
   for (size_t index : indexes) {
     if (!has_float && py::isinstance<py::float_>(py_args[index])) {
       has_float = true;
@@ -191,6 +192,9 @@ std::map<SignatureEnumDType, TypeId> GetDstType(const py::tuple &py_args,
     if (type_priority == prim::type_map.end()) {
       continue;
     }
+    if (arg_type_id == kNumberTypeInt8) {
+      has_int8 = true;
+    }
     if (type_priority->second > priority) {
       max_type = type_priority->first;
       priority = type_priority->second;
@@ -205,6 +209,9 @@ std::map<SignatureEnumDType, TypeId> GetDstType(const py::tuple &py_args,
         max_type = TypeId::kNumberTypeFloat32;
       }
     }
+    if (max_type == TypeId::kNumberTypeUInt8 && has_int8) {
+      max_type = TypeId::kNumberTypeInt16;
+    }
     (void)dst_type.insert(std::make_pair(type, max_type));
   }
   return dst_type;
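The net effect of the GetDstType hunks above, mirrored by the new tests later in this diff: the PyNative implicit-conversion pass now tracks int8 inputs, so a mixed int8/uint8 operation is widened to int16 (uint8 would win on priority but cannot represent int8's negative range), matching the rule GetMaxTypeId already applies. A minimal sketch:

import numpy as np
from mindspore import Tensor
from mindspore import dtype as mstype

x = Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int8))
y = Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint8))
assert (x + y).dtype == mstype.int16  # widened so both operand ranges fit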
@@ -39,6 +39,9 @@ class PyExceptionInitializer {
     if (exception_type == TypeError) {
       throw py::type_error(str);
     }
+    if (exception_type == KeyError) {
+      throw py::key_error(str);
+    }
     if (exception_type == AttributeError) {
       throw py::attribute_error(str);
     }
@@ -24,6 +24,7 @@
 #include "utils/convert_utils_base.h"
 #include "utils/primitive_utils.h"
 #include "utils/base_ref_extends.h"
+#include "utils/ms_context.h"
 #include "pybind_api/api_register.h"
 #include "pybind_api/export_flags.h"
 #include "pybind_api/ir/base_ref_py.h"
@@ -77,9 +78,47 @@ py::function PrimitivePy::GetBpropFunction() {
   }
 }
 
+py::tuple check_bprop_out(const py::object &grads_obj, const py::tuple &py_args) {
+  py::tuple grads;
+  if (!py::isinstance<py::tuple>(grads_obj)) {
+    grads = py::make_tuple(grads_obj);
+  } else {
+    grads = py::cast<py::tuple>(grads_obj);
+  }
+  if (grads.size() != py_args.size() - 2) {
+    MS_EXCEPTION(ValueError) << "For user define net bprop, the gradients number: " << grads.size()
+                             << " is not equal to the args number: " << py_args.size() - 2 << ".";
+  }
+  if (MsContext::GetInstance()->check_bprop_flag()) {
+    for (size_t i = 0; i < grads.size(); i++) {
+      if (py::isinstance<tensor::Tensor>(py_args[i])) {
+        if (!py::isinstance<tensor::Tensor>(grads[i])) {
+          MS_EXCEPTION(ValueError) << "For user define net bprop, the gradient of the " << i
+                                   << "th arg should be Tensor, but got "
+                                   << py::cast<std::string>(grads[i].attr("__class__").attr("__name__"))
+                                   << ", and the value is " << py::cast<py::str>(grads[i]) << ".";
+        }
+
+        py::tuple grad_shape = grads[i].attr("shape");
+        py::object grad_dtype = grads[i].attr("dtype");
+        py::tuple arg_shape = py_args[i].attr("shape");
+        py::object arg_dtype = py_args[i].attr("dtype");
+        if (!grad_shape.equal(arg_shape) || grad_dtype != arg_dtype) {
+          MS_EXCEPTION(ValueError) << "For user define net bprop, the gradient of the " << i
+                                   << "th arg should have the same shape and dtype as the " << i << "th arg, but the "
+                                   << i << "th arg shape: " << py::cast<py::str>(arg_shape)
+                                   << " and dtype: " << py::cast<py::str>(arg_dtype)
+                                   << ", the gradient shape: " << py::cast<py::str>(grad_shape)
+                                   << " and dtype: " << py::cast<py::str>(grad_dtype) << ".";
+        }
+      }
+    }
+  }
+  return grads;
+}
+
 BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const {
   py::tuple py_args = ConvertDatatoPyTuple(args);
-  py::object obj;
   bool is_bprop = this->HasAttr(kBpropAttrName);
   if (is_bprop) {
     SyncData(py_args);
@@ -90,11 +129,13 @@ BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const {
                        parse::PYTHON_MOD_CONVERT_TO_MS_TENSOR, py_args[i])
                  : py_args[i];
     }
-    obj = hook_(*convert_args);
-    return std::make_shared<PyObjectRef>(obj);
+    py::object grads_obj = hook_(*convert_args);
+    py::tuple grads = check_bprop_out(grads_obj, py_args);
+    return std::make_shared<PyObjectRef>(grads);
   }
   SyncData(py_args[2]);
   bool is_cell = this->HasAttr(kCellHookAttrName);
+  py::object obj;
   if (is_cell) {
     auto cell_id = GetValue<std::string>(this->GetAttr(kCellIDAttrName));
     auto iter = hook_grad_.find(cell_id);
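With the two hunks above, the bprop hook's return value is routed through check_bprop_out before being wrapped for the backend. A sketch of the failure mode the count check catches (per check_bprop_out, it sits outside the check_bprop_flag() guard, so it applies even with the flag off); this mirrors test_user_define_bprop_check_number below:

import numpy as np
import pytest
from mindspore import Tensor, nn, context
from mindspore.ops import composite as C

context.set_context(mode=context.PYNATIVE_MODE)

class Net(nn.Cell):
    def construct(self, x, y):
        return x * 2 + y

    def bprop(self, x, y, out, dout):
        return (dout,)  # only one gradient for two inputs

class GradNet(nn.Cell):
    def __init__(self, net):
        super(GradNet, self).__init__()
        self.net = net

    def construct(self, x, y, sens):
        return C.grad_all_with_sens(self.net)(x, y, sens)

x = y = Tensor(np.ones((2, 3), np.float32))
sens = Tensor(np.ones((2, 3), np.float32))
with pytest.raises(ValueError):  # "the gradients number: 1 is not equal to the args number: 2"
    GradNet(Net())(x, y, sens)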
@@ -440,16 +440,10 @@ bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple
   // inputs (a.k.a args in current function) size less than parameters'.
   if (output->isa<Parameter>()) {
     MS_LOG(INFO) << "Graph's output is a parameter. If all params are inputs, no need to execute.";
-    if (args.empty()) {
-      MS_LOG(EXCEPTION) << "Inputs size is 0, let graph to be executed.";
-    }
     // Find the right parameter as ret_val.
     auto func_graph = output->func_graph();
     MS_EXCEPTION_IF_NULL(func_graph);
     auto params = func_graph->parameters();
-    if (params.empty()) {
-      MS_EXCEPTION(UnknownError) << "Graph's parameters size is 0";
-    }
     if ((args.size() + func_graph->hyper_param_count()) != params.size()) {
       MS_LOG(EXCEPTION) << "Input size " << args.size() << " add Parameter count " << func_graph->hyper_param_count()
                         << " not equal to graph input size " << params.size() << ", let graph to be executed.";
@@ -55,7 +55,7 @@ AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr
     if (!keyPtr->isa<StringImm>()) {
       MS_LOG(EXCEPTION) << op_name << " evaluator keys should be string, but got " << keyPtr->ToString();
     }
-    std::string key_string = GetValue<std::string>(keyPtr);
+    auto key_string = GetValue<std::string>(keyPtr);
     key_value.emplace_back(key_string, value_list[index]);
   }
   return std::make_shared<AbstractDictionary>(key_value);
@@ -72,7 +72,7 @@ AbstractBasePtr InferImplMakeKwarg(const AnalysisEnginePtr &, const PrimitivePtr
   if (!keyPtr->isa<StringImm>()) {
     MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << keyPtr->ToString();
   }
-  std::string key_string = GetValue<std::string>(keyPtr);
+  auto key_string = GetValue<std::string>(keyPtr);
   return std::make_shared<AbstractKeywordArg>(key_string, args_spec_list[1]);
 }
 
@@ -88,7 +88,7 @@ AbstractBasePtr InferImplExtractKwarg(const AnalysisEnginePtr &, const Primitive
   if (!key_value->isa<StringImm>()) {
     MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString();
   }
-  std::string key_input = GetValue<std::string>(key_value);
+  auto key_input = GetValue<std::string>(key_value);
   std::string key_actual = kwarg->get_key();
   if (key_actual != key_input) {
     MS_LOG(EXCEPTION) << op_name << " evaluator input key should be same as AbstractKeywordArg' key, but input is "
@@ -216,7 +216,7 @@ AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitiveP
   auto it = std::find_if(dict_elems.begin(), dict_elems.end(),
                          [key_str](const AbstractAttribute &item) { return item.first == key_str; });
   if (it == dict_elems.end()) {
-    MS_LOG(EXCEPTION) << "The key " << key_str << " does not exist in the dict:" << args_spec_list[0]->ToString();
+    MS_EXCEPTION(KeyError) << "The key " << key_str << " does not exist in the dict:" << args_spec_list[0]->ToString();
   }
   return it->second;
 }
@@ -233,7 +233,7 @@ AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitiveP
   if (!key_value->isa<StringImm>()) {
     MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString();
   }
-  std::string key_str = GetValue<std::string>(key_value);
+  auto key_str = GetValue<std::string>(key_value);
   std::vector<AbstractAttribute> dict_elems = dict->elements();
   auto it = std::find_if(dict_elems.begin(), dict_elems.end(),
                          [key_str](const AbstractAttribute &item) { return item.first == key_str; });
@@ -147,6 +147,7 @@ static std::string ExceptionTypeToString(ExceptionType type) {
     _TO_STRING(IndexError),
     _TO_STRING(ValueError),
     _TO_STRING(TypeError),
+    _TO_STRING(KeyError),
     _TO_STRING(AttributeError),
   };
   // clang-format on
@@ -236,7 +237,7 @@ void LogWriter::operator^(const LogStream &stream) const {
   std::ostringstream oss;
   oss << location_.file_ << ":" << location_.line_ << " " << location_.func_ << "] ";
   if (exception_type_ != NoExceptionType && exception_type_ != IndexError && exception_type_ != TypeError &&
-      exception_type_ != ValueError && exception_type_ != AttributeError) {
+      exception_type_ != ValueError && exception_type_ != KeyError && exception_type_ != AttributeError) {
     oss << ExceptionTypeToString(exception_type_) << " ";
   }
   oss << msg.str();
@@ -60,6 +60,7 @@ enum ExceptionType {
   IndexError,
   ValueError,
   TypeError,
+  KeyError,
   AttributeError,
 };
 
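Taken together, the KeyError hunks (the enum value, the _TO_STRING entry, the LogWriter filter, the py::key_error mapping in PyExceptionInitializer, and the catch/rethrow in ExecutorPy::Compile) plumb MS_EXCEPTION(KeyError) from C++ through to a real Python KeyError, so the InferImplDictGetItem change above becomes user-visible. A hedged sketch of the intended effect; this dict example is hypothetical, not taken from the commit:

import pytest
from mindspore import nn, context

context.set_context(mode=context.GRAPH_MODE)

class DictGetNet(nn.Cell):  # hypothetical cell exercising dict getitem
    def construct(self):
        d = {"a": 1}
        return d["b"]  # key does not exist in the dict

with pytest.raises(KeyError):  # previously surfaced as a generic runtime error
    DictGetNet()()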
@@ -88,6 +88,16 @@ def test_float_tensor_and_bool_tensors_add():
     y = Tensor(np.array([[True, True, True], [False, False, False]], dtype=np.bool_))
     ret_actual = x + y
     ret_expect = Tensor(np.array([[1.1, 1.2, 1.3], [0.4, 0.5, 0.6]], dtype=np.float32))
     assert ret_actual.dtype == ret_expect.dtype
     assert (ret_actual.asnumpy() == ret_expect.asnumpy()).all()
 
 
+def test_int8_tensor_and_uint8_tensors_add():
+    x = Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int8))
+    y = Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint8))
+    ret_actual = x + y
+    ret_expect = Tensor(np.array([[2, 4, 6], [8, 10, 12]], dtype=np.int16))
+    assert ret_actual.dtype == ret_expect.dtype
+    assert (ret_actual.asnumpy() == ret_expect.asnumpy()).all()
+
+
@@ -165,7 +175,6 @@ def test_float_tensor_and_int_tensors_sub_grad():
     net = Net()
     grad_net = GradNet(net)
     ret = grad_net(x, y, sens)
-    print(ret)
     assert ret[0].dtype == x.dtype
     assert ret[1].dtype == y.dtype
     assert (ret[0].asnumpy() == sens.asnumpy()).all()
@@ -194,7 +203,6 @@ def test_float16_tensor_and_float32_tensors_sub_grad():
     net = Net()
     grad_net = GradNet(net)
     ret = grad_net(x, y, sens)
-    print(ret)
     assert ret[0].dtype == x.dtype
     assert ret[1].dtype == y.dtype
     assert (ret[0].asnumpy() == sens.asnumpy()).all()
@@ -224,3 +232,31 @@ def test_float_tensor_and_int_add_grad():
     ret = grad_net(x, sens)
     assert ret[0].dtype == x.dtype
     assert (ret[0].asnumpy() == sens.asnumpy()).all()
+
+
+def test_int8_tensor_and_uint8_tensors_add_grad():
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+
+        def construct(self, x, y):
+            return x + y
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+
+        def construct(self, x, y, sens):
+            return C.grad_all_with_sens(self.net)(x, y, sens)
+
+    x = Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int8))
+    y = Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint8))
+    sens = Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int16))
+    net = Net()
+    grad_net = GradNet(net)
+    ret = grad_net(x, y, sens)
+    assert ret[0].dtype == x.dtype
+    assert ret[1].dtype == y.dtype
+    assert (ret[0].asnumpy() == sens.asnumpy()).all()
+    assert (ret[1].asnumpy() == sens.asnumpy()).all()
@@ -0,0 +1,211 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+""" test implicit conversion """
+import numpy as np
+import pytest
+
+from mindspore import Tensor, nn, context, Parameter
+from mindspore import dtype as mstype
+from mindspore.ops import composite as C
+
+
+def test_user_define_bprop_check_ok():
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+            self.grad = Tensor(np.array([[1.1, 2.2, 3.3], [2.0, 3.0, 4.0]], dtype=np.float32))
+
+        def construct(self, x):
+            ret = x * 2
+            return ret
+
+        def bprop(self, x, out, dout):
+            return (self.grad * 3,)
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+
+        def construct(self, x, sens):
+            return C.grad_all_with_sens(self.net)(x, sens)
+
+    x = Tensor(np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))
+    sens = Tensor(np.array([[1.0, 2.0, 0.0], [0.0, 3.0, 4.0]], dtype=np.float32))
+    context.set_context(mode=context.PYNATIVE_MODE, check_bprop=True)
+    net = Net()
+    grad_net = GradNet(net)
+    ret = grad_net(x, sens)
+    assert ret[0].shape == (2, 3)
+    assert ret[0].dtype == mstype.float32
+    assert (ret[0].asnumpy() == np.array([[1.1, 2.2, 3.3], [2.0, 3.0, 4.0]], np.float32) * 3).all()
+
+
+def test_user_define_bprop_no_check_dtype():
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+            self.grad = Tensor(np.array([[1.1, 2.2, 3.3], [2.0, 3.0, 4.0]], dtype=np.float16))
+
+        def construct(self, x):
+            ret = x * 2
+            return ret
+
+        def bprop(self, x, out, dout):
+            return (self.grad * 3,)
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+
+        def construct(self, x, sens):
+            return C.grad_all_with_sens(self.net)(x, sens)
+
+    x = Tensor(np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))
+    sens = Tensor(np.array([[1.0, 2.0, 0.0], [0.0, 3.0, 4.0]], dtype=np.float32))
+    context.set_context(mode=context.PYNATIVE_MODE, check_bprop=False)
+    net = Net()
+    grad_net = GradNet(net)
+    ret = grad_net(x, sens)
+    assert ret[0].shape == (2, 3)
+    assert ret[0].dtype == mstype.float16
+    assert (ret[0].asnumpy() == np.array([[1.1, 2.2, 3.3], [2.0, 3.0, 4.0]], np.float16) * 3).all()
+
+
+def test_user_define_bprop_check_shape():
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+            self.grad = Tensor(np.array([[1.1, 2.2], [2.0, 3.0]], dtype=np.float32))
+
+        def construct(self, x):
+            ret = x * 2
+            return ret
+
+        def bprop(self, x, out, dout):
+            return (self.grad * 3,)
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+
+        def construct(self, x, sens):
+            return C.grad_all_with_sens(self.net)(x, sens)
+
+    x = Tensor(np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))
+    sens = Tensor(np.array([[1.0, 2.0, 0.0], [0.0, 3.0, 4.0]], dtype=np.float32))
+    context.set_context(mode=context.PYNATIVE_MODE, check_bprop=True)
+    net = Net()
+    grad_net = GradNet(net)
+    with pytest.raises(ValueError) as ex:
+        ret = grad_net(x, sens)
+    assert "the gradient of the 0th arg should have the same shape and dtype as the 0th arg" in str(ex.value)
+
+
+def test_user_define_bprop_check_dtype():
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+            self.grad = Tensor(np.array([[1.1, 2.2, 3.3], [2.0, 3.0, 4.0]], dtype=np.float16))
+
+        def construct(self, x):
+            ret = x * 2
+            return ret
+
+        def bprop(self, x, out, dout):
+            return (self.grad * 3,)
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+
+        def construct(self, x, sens):
+            return C.grad_all_with_sens(self.net)(x, sens)
+
+    x = Tensor(np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))
+    sens = Tensor(np.array([[1.0, 2.0, 0.0], [0.0, 3.0, 4.0]], dtype=np.float32))
+    context.set_context(mode=context.PYNATIVE_MODE, check_bprop=True)
+    net = Net()
+    grad_net = GradNet(net)
+    with pytest.raises(ValueError) as ex:
+        ret = grad_net(x, sens)
+    assert "the gradient of the 0th arg should have the same shape and dtype as the 0th arg" in str(ex.value)
+
+
+def test_user_define_bprop_check_parameter():
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+            self.par = Parameter(Tensor(np.array([[1.1, 2.2, 3.3], [2.0, 3.0, 4.0]], dtype=np.float32)), name="par")
+            self.grad = Tensor(np.array([[1.1, 2.2, 3.3], [2.0, 3.0, 4.0]], dtype=np.float16))
+
+        def construct(self, x):
+            ret = x * 2 + self.par
+            return ret
+
+        def bprop(self, x, out, dout):
+            return dout + x
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+
+        def construct(self, x, sens):
+            return C.grad_all_with_sens(self.net)(x, sens)
+
+    x = Tensor(np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))
+    sens = Tensor(np.array([[1.0, 2.0, 0.0], [0.0, 3.0, 4.0]], dtype=np.float32))
+    context.set_context(mode=context.PYNATIVE_MODE, check_bprop=True)
+    net = Net()
+    grad_net = GradNet(net)
+    with pytest.raises(RuntimeError) as ex:
+        ret = grad_net(x, sens)
+    assert "in scope Default does not support Parameter data type." in str(ex.value)
+
+
+def test_user_define_bprop_check_number():
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+            self.grad = Tensor(np.array([[1.1, 2.2, 3.3], [2.0, 3.0, 4.0]], dtype=np.float32))
+
+        def construct(self, x, y):
+            ret = x * 2 + y
+            return ret
+
+        def bprop(self, x, y, out, dout):
+            return (dout,)
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+
+        def construct(self, x, y, sens):
+            return C.grad_all_with_sens(self.net)(x, y, sens)
+
+    x = Tensor(np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))
+    y = Tensor(np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))
+    sens = Tensor(np.array([[1.0, 2.0, 0.0], [0.0, 3.0, 4.0]], dtype=np.float32))
+    context.set_context(mode=context.PYNATIVE_MODE, check_bprop=True)
+    net = Net()
+    grad_net = GradNet(net)
+    with pytest.raises(ValueError) as ex:
+        ret = grad_net(x, y, sens)
+    assert "For user define net bprop, the gradients number: 1 is not equal to the args number: 2." in str(ex.value)
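A usage note distilled from the tests above and from check_bprop_out: for a user-defined bprop, the gradient-count check always runs, while the per-gradient Tensor/shape/dtype validation is gated by the context flag:

from mindspore import context

# Strict: also validate each gradient's type, shape and dtype.
context.set_context(mode=context.PYNATIVE_MODE, check_bprop=True)
# Lenient: only the gradient-count check applies.
context.set_context(mode=context.PYNATIVE_MODE, check_bprop=False)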