forked from mindspore-Ecosystem/mindspore
validate bprop rules
This commit is contained in:
parent
e00f17369d
commit
9e633b6c12
|
@ -695,6 +695,7 @@ REGISTER_PYBIND_DEFINE(
|
||||||
(void)py::class_<String, Type, std::shared_ptr<String>>(m_sub, "String").def(py::init());
|
(void)py::class_<String, Type, std::shared_ptr<String>>(m_sub, "String").def(py::init());
|
||||||
(void)py::class_<RefKeyType, Type, std::shared_ptr<RefKeyType>>(m_sub, "RefKeyType").def(py::init());
|
(void)py::class_<RefKeyType, Type, std::shared_ptr<RefKeyType>>(m_sub, "RefKeyType").def(py::init());
|
||||||
(void)py::class_<RefType, Type, std::shared_ptr<RefType>>(m_sub, "RefType").def(py::init());
|
(void)py::class_<RefType, Type, std::shared_ptr<RefType>>(m_sub, "RefType").def(py::init());
|
||||||
|
(void)py::class_<TypeAnything, Type, std::shared_ptr<TypeAnything>>(m_sub, "TypeAnything").def(py::init());
|
||||||
}));
|
}));
|
||||||
|
|
||||||
const TypePtr kTypeExternal = std::make_shared<External>();
|
const TypePtr kTypeExternal = std::make_shared<External>();
|
||||||
|
|
|
@ -213,6 +213,7 @@ const PrimitivePtr kPrimGetRefOrigin = std::make_shared<Primitive>("get_ref_orig
|
||||||
const PrimitivePtr kPrimInsertGradientOf = std::make_shared<Primitive>("InsertGradientOf");
|
const PrimitivePtr kPrimInsertGradientOf = std::make_shared<Primitive>("InsertGradientOf");
|
||||||
const PrimitivePtr kPrimPrintShapeType = std::make_shared<Primitive>("PrintShapeType");
|
const PrimitivePtr kPrimPrintShapeType = std::make_shared<Primitive>("PrintShapeType");
|
||||||
const PrimitivePtr kPrimSameTypeShape = std::make_shared<Primitive>("SameTypeShape");
|
const PrimitivePtr kPrimSameTypeShape = std::make_shared<Primitive>("SameTypeShape");
|
||||||
|
const PrimitivePtr kPrimCheckBprop = std::make_shared<Primitive>("CheckBprop");
|
||||||
const PrimitivePtr kPrimPrint = std::make_shared<Primitive>("Print");
|
const PrimitivePtr kPrimPrint = std::make_shared<Primitive>("Print");
|
||||||
|
|
||||||
const PrimitivePtr kPrimMakeRef = std::make_shared<Primitive>("make_ref");
|
const PrimitivePtr kPrimMakeRef = std::make_shared<Primitive>("make_ref");
|
||||||
|
|
|
@ -220,6 +220,7 @@ extern const PrimitivePtr kPrimInsertGradientOf;
|
||||||
extern const PrimitivePtr kPrimPrintShapeType;
|
extern const PrimitivePtr kPrimPrintShapeType;
|
||||||
extern const PrimitivePtr kPrimPrint;
|
extern const PrimitivePtr kPrimPrint;
|
||||||
extern const PrimitivePtr kPrimSameTypeShape;
|
extern const PrimitivePtr kPrimSameTypeShape;
|
||||||
|
extern const PrimitivePtr kPrimCheckBprop;
|
||||||
extern const PrimitivePtr kPrimDepend;
|
extern const PrimitivePtr kPrimDepend;
|
||||||
extern const PrimitivePtr kPrimStateSetItem;
|
extern const PrimitivePtr kPrimStateSetItem;
|
||||||
extern const PrimitivePtr kPrimScalarSummary;
|
extern const PrimitivePtr kPrimScalarSummary;
|
||||||
|
|
|
@ -309,14 +309,6 @@ FuncGraphPtr DFunctor::KUserDefined(const FuncGraphPtr &primal) {
|
||||||
auto bprop = primal->transforms().find("bprop");
|
auto bprop = primal->transforms().find("bprop");
|
||||||
if (bprop != primal->transforms().end()) {
|
if (bprop != primal->transforms().end()) {
|
||||||
FuncGraphPtr bprop_graph = bprop->second.func_graph();
|
FuncGraphPtr bprop_graph = bprop->second.func_graph();
|
||||||
const size_t param_diff = 1;
|
|
||||||
if (bprop_graph->output()->isa<CNode>() &&
|
|
||||||
bprop_graph->output()->cast<CNodePtr>()->size() + param_diff != bprop_graph->parameters().size()) {
|
|
||||||
// It does not matter with the final tangents, just a tip for debugging
|
|
||||||
MS_LOG(DEBUG) << "User defined Cell bprop " << primal->ToString() << " in scope "
|
|
||||||
<< primal->output()->scope()->name()
|
|
||||||
<< " output must be a tuple and output number should be the same with inputs.";
|
|
||||||
}
|
|
||||||
resources_->manager()->AddFuncGraph(bprop_graph);
|
resources_->manager()->AddFuncGraph(bprop_graph);
|
||||||
|
|
||||||
if (bprop_graph->free_variables_nodes().size() != 0 || primal->free_variables_nodes().size() != 0) {
|
if (bprop_graph->free_variables_nodes().size() != 0 || primal->free_variables_nodes().size() != 0) {
|
||||||
|
|
|
@ -127,7 +127,7 @@ class KPrim {
|
||||||
AnfNodePtr BuildOutput(const FuncGraphPtr &bprop_fg);
|
AnfNodePtr BuildOutput(const FuncGraphPtr &bprop_fg);
|
||||||
void TransformArgs(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg, const FuncGraphPtr &outer,
|
void TransformArgs(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg, const FuncGraphPtr &outer,
|
||||||
std::vector<AnfNodePtr> *const transf_args);
|
std::vector<AnfNodePtr> *const transf_args);
|
||||||
void AddCheckTypeShapeOp(const FuncGraphPtr &bprop_fg);
|
void CheckBprop(const FuncGraphPtr &bprop_fg, const string &prim_to_check);
|
||||||
|
|
||||||
Registry bprop_registry_;
|
Registry bprop_registry_;
|
||||||
std::unordered_map<PrimitivePtr, MetaFuncGraphPtr> bprop_registry_meta_;
|
std::unordered_map<PrimitivePtr, MetaFuncGraphPtr> bprop_registry_meta_;
|
||||||
|
@ -137,10 +137,7 @@ template <typename T>
|
||||||
FuncGraphPtr KPrim::BpropToK(const T &primal, const FuncGraphPtr &bprop_fg) {
|
FuncGraphPtr KPrim::BpropToK(const T &primal, const FuncGraphPtr &bprop_fg) {
|
||||||
MS_EXCEPTION_IF_NULL(primal);
|
MS_EXCEPTION_IF_NULL(primal);
|
||||||
MS_EXCEPTION_IF_NULL(bprop_fg);
|
MS_EXCEPTION_IF_NULL(bprop_fg);
|
||||||
|
CheckBprop(bprop_fg, primal->ToString());
|
||||||
if (IsPrimitiveCNode(bprop_fg->output(), prim::kPrimMakeTuple)) {
|
|
||||||
AddCheckTypeShapeOp(bprop_fg);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto debug_info = std::make_shared<GraphDebugInfo>();
|
auto debug_info = std::make_shared<GraphDebugInfo>();
|
||||||
debug_info->set_name(primal->ToString());
|
debug_info->set_name(primal->ToString());
|
||||||
|
|
|
@ -50,9 +50,13 @@ FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim) {
|
||||||
grad_op_child_scope_prefix + prim->name());
|
grad_op_child_scope_prefix + prim->name());
|
||||||
ScopeGuard scope_guard(scope);
|
ScopeGuard scope_guard(scope);
|
||||||
py::function fn = prim->GetBpropFunction();
|
py::function fn = prim->GetBpropFunction();
|
||||||
|
if (fn == nullptr || py::isinstance<py::none>(fn)) {
|
||||||
|
MS_LOG(DEBUG) << "Fail to find bprop function for " << prim->name() << ".";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
FuncGraphPtr func_graph = parse::ParsePythonCode(fn);
|
FuncGraphPtr func_graph = parse::ParsePythonCode(fn);
|
||||||
if (func_graph == nullptr) {
|
if (func_graph == nullptr) {
|
||||||
MS_LOG(WARNING) << "Fail to find bprop function for " << prim->name() << ".";
|
MS_LOG(ERROR) << "Fail to parse bprop function for " << prim->name() << ".";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
return func_graph;
|
return func_graph;
|
||||||
|
@ -153,31 +157,23 @@ void KPrim::TransformArgs(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bp
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void KPrim::AddCheckTypeShapeOp(const FuncGraphPtr &bprop_fg) {
|
void KPrim::CheckBprop(const FuncGraphPtr &bprop_fg, const string &prim_to_check) {
|
||||||
// bprop_fg has been checked in caller
|
// bprop_fg has been checked in caller
|
||||||
auto same_type_shape = prim::GetPythonOps("same_type_shape", "mindspore.ops.functional")->cast<PrimitivePtr>();
|
auto check_bprop = prim::GetPythonOps("check_bprop", "mindspore.ops.functional")->cast<PrimitivePtr>();
|
||||||
MS_EXCEPTION_IF_NULL(same_type_shape);
|
MS_EXCEPTION_IF_NULL(check_bprop);
|
||||||
|
check_bprop->set_attr("prim_to_check", std::make_shared<StringImm>(prim_to_check));
|
||||||
|
|
||||||
std::vector<AnfNodePtr> bout_input;
|
std::vector<AnfNodePtr> inputs;
|
||||||
bout_input.push_back(NewValueNode(prim::kPrimMakeTuple));
|
inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
|
||||||
|
inputs.insert(inputs.begin() + 1, bprop_fg->parameters().begin(), bprop_fg->parameters().end() - 2);
|
||||||
|
AnfNodePtr params = bprop_fg->NewCNode(inputs);
|
||||||
|
|
||||||
auto fg_out = bprop_fg->output();
|
inputs.clear();
|
||||||
MS_EXCEPTION_IF_NULL(fg_out);
|
inputs.push_back(NewValueNode(check_bprop));
|
||||||
auto cnode = fg_out->cast<CNodePtr>();
|
inputs.push_back(bprop_fg->output());
|
||||||
MS_EXCEPTION_IF_NULL(cnode);
|
inputs.push_back(params);
|
||||||
|
AnfNodePtr bprop_out = bprop_fg->NewCNode(inputs);
|
||||||
auto &inputs = cnode->inputs();
|
bprop_fg->set_output(bprop_out);
|
||||||
auto params = bprop_fg->parameters();
|
|
||||||
std::vector<AnfNodePtr> sub_input;
|
|
||||||
for (size_t i = 1; i < inputs.size(); ++i) {
|
|
||||||
sub_input.clear();
|
|
||||||
sub_input.push_back(NewValueNode(same_type_shape));
|
|
||||||
sub_input.push_back(inputs[i]);
|
|
||||||
sub_input.push_back(params[i - 1]);
|
|
||||||
bout_input.push_back(bprop_fg->NewCNode(sub_input));
|
|
||||||
}
|
|
||||||
AnfNodePtr cbout = bprop_fg->NewCNode(bout_input);
|
|
||||||
bprop_fg->set_output(cbout);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
FuncGraphPtr KPrim::KUserDefinedCellBprop(const FuncGraphPtr bprop_fg) {
|
FuncGraphPtr KPrim::KUserDefinedCellBprop(const FuncGraphPtr bprop_fg) {
|
||||||
|
|
|
@ -67,6 +67,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
|
||||||
{prim::kPrimReduceMean, prim::kPrimReduceAll, prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin});
|
{prim::kPrimReduceMean, prim::kPrimReduceAll, prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin});
|
||||||
partial_eliminate_ = MakeSubstitution(PartialEliminater(), "partial_eliminate", IsCNodeDup);
|
partial_eliminate_ = MakeSubstitution(PartialEliminater(), "partial_eliminate", IsCNodeDup);
|
||||||
same_eliminate_ = MakeSubstitution(SameEliminater(), "same_eliminate", prim::kPrimSameTypeShape);
|
same_eliminate_ = MakeSubstitution(SameEliminater(), "same_eliminate", prim::kPrimSameTypeShape);
|
||||||
|
check_bprop_eliminate_ = MakeSubstitution(CheckBpropEliminater(), "check_bprop_eliminate", prim::kPrimCheckBprop);
|
||||||
reset_defer_inline_ = MakeSubstitution(ResetDeferInline(), "reset_defer_inline", IsValueNode<FuncGraph>);
|
reset_defer_inline_ = MakeSubstitution(ResetDeferInline(), "reset_defer_inline", IsValueNode<FuncGraph>);
|
||||||
|
|
||||||
// Env Item Eliminate
|
// Env Item Eliminate
|
||||||
|
|
|
@ -45,6 +45,7 @@ class OptimizeIRPassLib {
|
||||||
SubstitutionPtr reduce_eliminate_;
|
SubstitutionPtr reduce_eliminate_;
|
||||||
SubstitutionPtr partial_eliminate_;
|
SubstitutionPtr partial_eliminate_;
|
||||||
SubstitutionPtr same_eliminate_;
|
SubstitutionPtr same_eliminate_;
|
||||||
|
SubstitutionPtr check_bprop_eliminate_;
|
||||||
SubstitutionPtr reset_defer_inline_;
|
SubstitutionPtr reset_defer_inline_;
|
||||||
|
|
||||||
// Env Item Eliminate
|
// Env Item Eliminate
|
||||||
|
|
|
@ -109,6 +109,25 @@ class SameEliminater : public AnfVisitor {
|
||||||
AnfNodePtr x_{nullptr};
|
AnfNodePtr x_{nullptr};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// {prim::kPrimCheckBprop, X, Y} -> X
|
||||||
|
class CheckBpropEliminater : public AnfVisitor {
|
||||||
|
public:
|
||||||
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||||
|
x_ = nullptr;
|
||||||
|
AnfVisitor::Match(prim::kPrimCheckBprop, {IsNode, IsNode})(node);
|
||||||
|
return x_;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Visit(const AnfNodePtr &node) override {
|
||||||
|
if (x_ == nullptr) {
|
||||||
|
x_ = node;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
AnfNodePtr x_{nullptr};
|
||||||
|
};
|
||||||
|
|
||||||
// Reset defer_inline flag
|
// Reset defer_inline flag
|
||||||
class ResetDeferInline : public AnfVisitor {
|
class ResetDeferInline : public AnfVisitor {
|
||||||
public:
|
public:
|
||||||
|
|
|
@ -108,6 +108,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
|
||||||
});
|
});
|
||||||
opt::OptPassConfig a_3 = opt::OptPassConfig({
|
opt::OptPassConfig a_3 = opt::OptPassConfig({
|
||||||
irpass.same_eliminate_,
|
irpass.same_eliminate_,
|
||||||
|
irpass.check_bprop_eliminate_,
|
||||||
irpass.replace_applicator_,
|
irpass.replace_applicator_,
|
||||||
});
|
});
|
||||||
opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_});
|
opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_});
|
||||||
|
|
|
@ -295,6 +295,9 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
|
||||||
dic["shape"] = shape;
|
dic["shape"] = shape;
|
||||||
dic["dtype"] = arg_slice->BuildType();
|
dic["dtype"] = arg_slice->BuildType();
|
||||||
dic["value"] = BuildValue(arg_slice->BuildValue());
|
dic["value"] = BuildValue(arg_slice->BuildValue());
|
||||||
|
} else if (abs_base->isa<AbstractRef>()) {
|
||||||
|
auto value = abs_base->cast<AbstractRefPtr>()->ref();
|
||||||
|
dic = ConvertAbstractToPython(value);
|
||||||
} else if (abs_base->isa<AbstractTuple>()) {
|
} else if (abs_base->isa<AbstractTuple>()) {
|
||||||
auto arg_tuple = dyn_cast<AbstractTuple>(abs_base);
|
auto arg_tuple = dyn_cast<AbstractTuple>(abs_base);
|
||||||
size_t len = arg_tuple->size();
|
size_t len = arg_tuple->size();
|
||||||
|
@ -327,6 +330,10 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
|
||||||
dic["shape"] = py::none();
|
dic["shape"] = py::none();
|
||||||
dic["dtype"] = py::none();
|
dic["dtype"] = py::none();
|
||||||
dic["value"] = py::none();
|
dic["value"] = py::none();
|
||||||
|
} else if (abs_base->isa<AbstractFunction>()) {
|
||||||
|
dic["shape"] = py::none();
|
||||||
|
dic["dtype"] = abs_base->BuildType();
|
||||||
|
dic["value"] = py::none();
|
||||||
} else {
|
} else {
|
||||||
auto value = abs_base->BuildValue();
|
auto value = abs_base->BuildValue();
|
||||||
if ((*value == *kAnyValue)) {
|
if ((*value == *kAnyValue)) {
|
||||||
|
|
|
@ -85,13 +85,16 @@ list_ = typing.List()
|
||||||
tuple_ = typing.Tuple()
|
tuple_ = typing.Tuple()
|
||||||
tensor = typing.TensorType()
|
tensor = typing.TensorType()
|
||||||
function = typing.Function()
|
function = typing.Function()
|
||||||
|
function_type = typing.Function
|
||||||
symbolic_key = typing.SymbolicKeyType()
|
symbolic_key = typing.SymbolicKeyType()
|
||||||
env_type = typing.EnvType()
|
env_type = typing.EnvType()
|
||||||
|
env_type_type = typing.EnvType
|
||||||
type_type = typing.TypeType()
|
type_type = typing.TypeType()
|
||||||
type_none = typing.TypeNone()
|
type_none = typing.TypeNone()
|
||||||
string = typing.String()
|
string = typing.String()
|
||||||
type_refkey = typing.RefKeyType()
|
type_refkey = typing.RefKeyType()
|
||||||
tensor_type = typing.TensorType
|
tensor_type = typing.TensorType
|
||||||
|
anything_type = typing.TypeAnything
|
||||||
|
|
||||||
number_type = (int8,
|
number_type = (int8,
|
||||||
int16,
|
int16,
|
||||||
|
|
|
@ -211,11 +211,11 @@ def get_bprop_slice(self):
|
||||||
|
|
||||||
def bprop(x, begin, size, out, dout):
|
def bprop(x, begin, size, out, dout):
|
||||||
dx = P.Pad(_slice_grad_pad(begin, size, shape_op(x)))(dout)
|
dx = P.Pad(_slice_grad_pad(begin, size, shape_op(x)))(dout)
|
||||||
return (dx,)
|
return (dx, zeros_like(begin), zeros_like(size))
|
||||||
|
|
||||||
def bprop_gpu(x, begin, size, out, dout):
|
def bprop_gpu(x, begin, size, out, dout):
|
||||||
dx = dx = G.SliceGrad()(dout, x, begin, size)
|
dx = dx = G.SliceGrad()(dout, x, begin, size)
|
||||||
return (dx,)
|
return (dx, zeros_like(begin), zeros_like(size))
|
||||||
|
|
||||||
if context.get_context('device_target') == "GPU":
|
if context.get_context('device_target') == "GPU":
|
||||||
return bprop_gpu
|
return bprop_gpu
|
||||||
|
@ -262,7 +262,7 @@ def get_bprop_gather_v2(self):
|
||||||
# Example: out_shape:(3,2,3) axis 2 -> (1,2,0)
|
# Example: out_shape:(3,2,3) axis 2 -> (1,2,0)
|
||||||
perm_2 = _generate_inverse_index(x_shp, axis)
|
perm_2 = _generate_inverse_index(x_shp, axis)
|
||||||
params_grad = transpose(params_grad, perm_2)
|
params_grad = transpose(params_grad, perm_2)
|
||||||
return params_grad, zeros_like(indices)
|
return params_grad, zeros_like(indices), zeros_like(axis)
|
||||||
return bprop
|
return bprop
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -505,7 +505,7 @@ def get_bprop_reducemax(self):
|
||||||
|
|
||||||
def bprop(x, axis, out, dout):
|
def bprop(x, axis, out, dout):
|
||||||
dx = _min_or_max_grad(x, axis, out, dout)
|
dx = _min_or_max_grad(x, axis, out, dout)
|
||||||
return (dx,)
|
return (dx, zeros_like(axis))
|
||||||
return bprop
|
return bprop
|
||||||
|
|
||||||
|
|
||||||
|
@ -528,7 +528,7 @@ def get_bprop_reducemin(self):
|
||||||
|
|
||||||
def bprop(x, axis, out, dout):
|
def bprop(x, axis, out, dout):
|
||||||
dx = _min_or_max_grad(x, axis, out, dout)
|
dx = _min_or_max_grad(x, axis, out, dout)
|
||||||
return (dx,)
|
return (dx, zeros_like(axis))
|
||||||
return bprop
|
return bprop
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -436,7 +436,7 @@ def get_bprop_onehot(self):
|
||||||
"""Grad definition for `OneHot` operation."""
|
"""Grad definition for `OneHot` operation."""
|
||||||
|
|
||||||
def bprop(indices, depth, on_value, off_value, out, dout):
|
def bprop(indices, depth, on_value, off_value, out, dout):
|
||||||
return zeros_like(indices), zeros_like(depth)
|
return zeros_like(indices), zeros_like(depth), zeros_like(on_value), zeros_like(off_value)
|
||||||
return bprop
|
return bprop
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -31,6 +31,10 @@ def _zeros_like_scala(x):
|
||||||
"""Returns 0 which has the same dtype as x where x is a scalar."""
|
"""Returns 0 which has the same dtype as x where x is a scalar."""
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
@zeros_like_leaf.register("Bool")
|
||||||
|
def _zeros_like_bool(x):
|
||||||
|
"""Returns False if x is a bool."""
|
||||||
|
return False
|
||||||
|
|
||||||
newenv = base.EnvInstance_()
|
newenv = base.EnvInstance_()
|
||||||
|
|
||||||
|
|
|
@ -56,6 +56,7 @@ tensor_pow = P.Pow()
|
||||||
tensor_mod = P.FloorMod()
|
tensor_mod = P.FloorMod()
|
||||||
strided_slice = P.StridedSlice()
|
strided_slice = P.StridedSlice()
|
||||||
same_type_shape = P.SameTypeShape()
|
same_type_shape = P.SameTypeShape()
|
||||||
|
check_bprop = P.CheckBprop()
|
||||||
equal = P.Equal()
|
equal = P.Equal()
|
||||||
not_equal = P.NotEqual()
|
not_equal = P.NotEqual()
|
||||||
assign_sub = P.AssignSub()
|
assign_sub = P.AssignSub()
|
||||||
|
|
|
@ -67,7 +67,7 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
|
||||||
SparseSoftmaxCrossEntropyWithLogits, Tanh,
|
SparseSoftmaxCrossEntropyWithLogits, Tanh,
|
||||||
TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl,
|
TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl,
|
||||||
ApplyRMSProp, ApplyCenteredRMSProp)
|
ApplyRMSProp, ApplyCenteredRMSProp)
|
||||||
from .other_ops import Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, CheckValid, MakeRefKey
|
from .other_ops import Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, CheckValid, MakeRefKey, CheckBprop
|
||||||
from . import _quant_ops
|
from . import _quant_ops
|
||||||
from ._quant_ops import *
|
from ._quant_ops import *
|
||||||
|
|
||||||
|
@ -179,6 +179,7 @@ __all__ = [
|
||||||
'GeSwitch',
|
'GeSwitch',
|
||||||
'Merge',
|
'Merge',
|
||||||
'SameTypeShape',
|
'SameTypeShape',
|
||||||
|
'CheckBprop',
|
||||||
'CheckValid',
|
'CheckValid',
|
||||||
'BoundingBoxEncode',
|
'BoundingBoxEncode',
|
||||||
'BoundingBoxDecode',
|
'BoundingBoxDecode',
|
||||||
|
|
|
@ -269,3 +269,66 @@ class MakeRefKey(Primitive):
|
||||||
|
|
||||||
def __call__(self):
|
def __call__(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class CheckBprop(PrimitiveWithInfer):
|
||||||
|
"""
|
||||||
|
Checks whether data type and shape of corresponding element from tuple x and y are the same.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If not the same.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **input_x** (tuple[Tensor]) - The input_x contains the outputs of bprop to be checked.
|
||||||
|
- **input_y** (tuple[Tensor]) - The input_y contains the inputs of bprop to check against.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
(tuple[Tensor]), the input_x,
|
||||||
|
if data type and shape of corresponding elements from `input_x` and `input_y` are the same.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> input_x = (Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32),)
|
||||||
|
>>> input_y = (Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32),)
|
||||||
|
>>> out = P.CheckBprop()(input_x, input_y)
|
||||||
|
"""
|
||||||
|
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self):
|
||||||
|
"""init CheckBprop"""
|
||||||
|
|
||||||
|
def infer_shape(self, xshapes, yshapes):
|
||||||
|
tips = f'Bprop of {self.prim_to_check}'
|
||||||
|
if len(xshapes) < len(yshapes):
|
||||||
|
raise TypeError(f"{tips}, the size of output should be {len(yshapes)},"
|
||||||
|
f" but got {len(xshapes)}.")
|
||||||
|
checking_range = len(yshapes)
|
||||||
|
for i in range(checking_range):
|
||||||
|
xshape = xshapes[i]
|
||||||
|
yshape = yshapes[i]
|
||||||
|
if not xshape or not yshape:
|
||||||
|
continue
|
||||||
|
if xshape != yshape:
|
||||||
|
raise TypeError(f"{tips}, the shape of {i}th output should be {yshape},"
|
||||||
|
f" but got {xshape}.")
|
||||||
|
return xshapes
|
||||||
|
|
||||||
|
def infer_dtype(self, xdtypes, ydtypes):
|
||||||
|
tips = f'Bprop of {self.prim_to_check}'
|
||||||
|
if len(xdtypes) < len(ydtypes):
|
||||||
|
raise TypeError(f"{tips}, the size of output should be {len(ydtypes)},"
|
||||||
|
f" but got {len(xdtypes)}.")
|
||||||
|
checking_range = len(ydtypes)
|
||||||
|
for i in range(checking_range):
|
||||||
|
xdtype = xdtypes[i]
|
||||||
|
ydtype = ydtypes[i]
|
||||||
|
if isinstance(xdtype, mstype.anything_type) or isinstance(ydtype, mstype.anything_type):
|
||||||
|
continue
|
||||||
|
if isinstance(ydtype, mstype.function_type):
|
||||||
|
if not isinstance(xdtype, mstype.env_type_type):
|
||||||
|
raise TypeError(f"{tips}, the dtype of {i}th output should be {mstype.env_type_type},"
|
||||||
|
f" but got {xdtype}.")
|
||||||
|
continue
|
||||||
|
if xdtype != ydtype:
|
||||||
|
raise TypeError(f"{tips}, the dtype of {i}th output should be {ydtype},"
|
||||||
|
f" but got {xdtype}.")
|
||||||
|
return xdtypes
|
||||||
|
|
|
@ -317,7 +317,7 @@ test_case_cell_ops = [
|
||||||
initializer_range=0.02,
|
initializer_range=0.02,
|
||||||
dropout_prob=0.1),
|
dropout_prob=0.1),
|
||||||
'desc_inputs': [[1, 768], [1, 768]],
|
'desc_inputs': [[1, 768], [1, 768]],
|
||||||
'desc_bprop': [[1, 128, 768]]}), # maybe not right
|
'desc_bprop': [[1, 768]]}),
|
||||||
('BertTransformer_2', {
|
('BertTransformer_2', {
|
||||||
'block': bert_trans(),
|
'block': bert_trans(),
|
||||||
'desc_inputs': [[1, 128, 768], [1, 128, 128]]}),
|
'desc_inputs': [[1, 128, 768], [1, 128, 128]]}),
|
||||||
|
@ -331,7 +331,7 @@ test_case_cell_ops = [
|
||||||
'desc_inputs': [Tensor(np.random.rand(128).astype(np.int32)),
|
'desc_inputs': [Tensor(np.random.rand(128).astype(np.int32)),
|
||||||
Tensor(np.random.rand(128).astype(np.int32)), [128]],
|
Tensor(np.random.rand(128).astype(np.int32)), [128]],
|
||||||
'desc_bprop': [[1, 128, 768], [1, 128, 768], [1, 128, 768]],
|
'desc_bprop': [[1, 128, 768], [1, 128, 768], [1, 128, 768]],
|
||||||
'num_output': 3}), # maybe not right
|
'num_output': 3}),
|
||||||
|
|
||||||
('BertModel_1', {
|
('BertModel_1', {
|
||||||
'block': BertModel(config=BertConfig(batch_size=1,
|
'block': BertModel(config=BertConfig(batch_size=1,
|
||||||
|
@ -342,7 +342,7 @@ test_case_cell_ops = [
|
||||||
'desc_inputs': [Tensor(np.random.rand(128).astype(np.int32)),
|
'desc_inputs': [Tensor(np.random.rand(128).astype(np.int32)),
|
||||||
Tensor(np.random.rand(128).astype(np.int32)), [128]],
|
Tensor(np.random.rand(128).astype(np.int32)), [128]],
|
||||||
'desc_bprop': [[1, 128, 768], [1, 128, 768], [1, 128, 768]],
|
'desc_bprop': [[1, 128, 768], [1, 128, 768], [1, 128, 768]],
|
||||||
'num_output': 3}), # maybe not right
|
'num_output': 3}),
|
||||||
|
|
||||||
('BertModel_2', {
|
('BertModel_2', {
|
||||||
'block': BertModel(config=BertConfig(batch_size=1,
|
'block': BertModel(config=BertConfig(batch_size=1,
|
||||||
|
@ -354,7 +354,7 @@ test_case_cell_ops = [
|
||||||
'desc_inputs': [Tensor(np.random.rand(128).astype(np.int32)),
|
'desc_inputs': [Tensor(np.random.rand(128).astype(np.int32)),
|
||||||
Tensor(np.random.rand(128).astype(np.int32)), [128]],
|
Tensor(np.random.rand(128).astype(np.int32)), [128]],
|
||||||
'desc_bprop': [[1, 128, 768], [1, 128, 768], [1, 128, 768]],
|
'desc_bprop': [[1, 128, 768], [1, 128, 768], [1, 128, 768]],
|
||||||
'num_output': 3}), # maybe not right
|
'num_output': 3}),
|
||||||
|
|
||||||
('BertPretrainingLoss', {
|
('BertPretrainingLoss', {
|
||||||
'block': BertPretrainingLoss(config=BertConfig(batch_size=1)),
|
'block': BertPretrainingLoss(config=BertConfig(batch_size=1)),
|
||||||
|
|
|
@ -175,7 +175,7 @@ class GetParamGrad(nn.Cell):
|
||||||
|
|
||||||
def test_grad_conv_prelu():
|
def test_grad_conv_prelu():
|
||||||
shapes = [[64, 64, 112, 112]]
|
shapes = [[64, 64, 112, 112]]
|
||||||
outshape = [[64, 64, 56, 56]]
|
outshape = [[64, 64, 112, 112]]
|
||||||
net = IRBlockZ(inplanes=64, planes=64).add_flags_recursive(fp16=True)
|
net = IRBlockZ(inplanes=64, planes=64).add_flags_recursive(fp16=True)
|
||||||
inputs = [convert(shp, dtype=np.float16) for shp in shapes]
|
inputs = [convert(shp, dtype=np.float16) for shp in shapes]
|
||||||
sens_shape = outshape[0]
|
sens_shape = outshape[0]
|
||||||
|
|
|
@ -585,7 +585,7 @@ test_case_nn_ops = [
|
||||||
('ReLUV2', {
|
('ReLUV2', {
|
||||||
'block': P.ReLUV2(),
|
'block': P.ReLUV2(),
|
||||||
'desc_inputs': [[1, 3, 4, 4]],
|
'desc_inputs': [[1, 3, 4, 4]],
|
||||||
'desc_bprop': [[1, 3, 4, 4], [1, 3, 4, 4]]}),
|
'desc_bprop': [[1, 3, 4, 4], ([1, 1, 4, 4, 2], {'dtype': np.uint8})]}),
|
||||||
('ReLUGrad', {
|
('ReLUGrad', {
|
||||||
'block': G.ReluGrad(),
|
'block': G.ReluGrad(),
|
||||||
'desc_inputs': [[1, 3, 4, 4], [1, 3, 4, 4]],
|
'desc_inputs': [[1, 3, 4, 4], [1, 3, 4, 4]],
|
||||||
|
@ -626,7 +626,7 @@ test_case_nn_ops = [
|
||||||
('MaxPoolWithArgmax', {
|
('MaxPoolWithArgmax', {
|
||||||
'block': P.MaxPoolWithArgmax(ksize=2, strides=2),
|
'block': P.MaxPoolWithArgmax(ksize=2, strides=2),
|
||||||
'desc_inputs': [[128, 32, 32, 64]],
|
'desc_inputs': [[128, 32, 32, 64]],
|
||||||
'desc_bprop': [[128, 32, 8, 16], [128, 32, 8, 16]]}),
|
'desc_bprop': [[128, 32, 16, 32], ([128, 32, 4, 33], {'dtype': np.uint16})]}),
|
||||||
('SoftmaxCrossEntropyWithLogits', {
|
('SoftmaxCrossEntropyWithLogits', {
|
||||||
'block': P.SoftmaxCrossEntropyWithLogits(),
|
'block': P.SoftmaxCrossEntropyWithLogits(),
|
||||||
'desc_inputs': [[1, 10], [1, 10]],
|
'desc_inputs': [[1, 10], [1, 10]],
|
||||||
|
@ -639,7 +639,7 @@ test_case_nn_ops = [
|
||||||
('LogSoftmax', {
|
('LogSoftmax', {
|
||||||
'block': P.LogSoftmax(),
|
'block': P.LogSoftmax(),
|
||||||
'desc_inputs': [[64, 2]],
|
'desc_inputs': [[64, 2]],
|
||||||
'desc_bprop': [[160, 30522]]}),
|
'desc_bprop': [[64, 2]]}),
|
||||||
('LogSoftmaxGrad', {
|
('LogSoftmaxGrad', {
|
||||||
'block': G.LogSoftmaxGrad(),
|
'block': G.LogSoftmaxGrad(),
|
||||||
'desc_inputs': [[16, 1234], [16, 1234]],
|
'desc_inputs': [[16, 1234], [16, 1234]],
|
||||||
|
@ -648,7 +648,7 @@ test_case_nn_ops = [
|
||||||
('LayerNorm', {
|
('LayerNorm', {
|
||||||
'block': P.LayerNorm(),
|
'block': P.LayerNorm(),
|
||||||
'desc_inputs': [[2, 16], [16], [16]],
|
'desc_inputs': [[2, 16], [16], [16]],
|
||||||
'desc_bprop': [[2, 16], [2, 16], [2, 16]]}),
|
'desc_bprop': [[2, 16], [2, 1], [2, 1]]}),
|
||||||
('LayerNormGrad', {
|
('LayerNormGrad', {
|
||||||
'block': G.LayerNormGrad(),
|
'block': G.LayerNormGrad(),
|
||||||
'desc_inputs': [[2, 16], [2, 16], [2, 16], [2, 16], [16]],
|
'desc_inputs': [[2, 16], [2, 16], [2, 16], [2, 16], [16]],
|
||||||
|
@ -845,7 +845,7 @@ test_case_nn_ops = [
|
||||||
'block': P.OneHot(),
|
'block': P.OneHot(),
|
||||||
'desc_const': [3, Tensor(1.0, mstype.float32), Tensor(0.0, mstype.float32)],
|
'desc_const': [3, Tensor(1.0, mstype.float32), Tensor(0.0, mstype.float32)],
|
||||||
'desc_inputs': [Tensor(np.array([64]).astype(np.int32))],
|
'desc_inputs': [Tensor(np.array([64]).astype(np.int32))],
|
||||||
'desc_bprop': [[64, 2]]}),
|
'desc_bprop': [[1, 3]]}),
|
||||||
('ReduceProd_0', {
|
('ReduceProd_0', {
|
||||||
'block': P.ReduceProd(),
|
'block': P.ReduceProd(),
|
||||||
'desc_const': [0],
|
'desc_const': [0],
|
||||||
|
@ -950,7 +950,7 @@ test_case_array_ops = [
|
||||||
'block': P.Cast(),
|
'block': P.Cast(),
|
||||||
'desc_const': [mstype.int32],
|
'desc_const': [mstype.int32],
|
||||||
'desc_inputs': [[2, 3, 4, 5]],
|
'desc_inputs': [[2, 3, 4, 5]],
|
||||||
'desc_bprop': [Tensor(np.ones((2, 3, 3, 5)).astype(np.int32))]}),
|
'desc_bprop': [Tensor(np.ones((2, 3, 4, 5)).astype(np.int32))]}),
|
||||||
('ExpandDims', {
|
('ExpandDims', {
|
||||||
'block': P.ExpandDims(),
|
'block': P.ExpandDims(),
|
||||||
'desc_const': [0],
|
'desc_const': [0],
|
||||||
|
@ -1002,12 +1002,12 @@ test_case_array_ops = [
|
||||||
'desc_inputs': [
|
'desc_inputs': [
|
||||||
(Tensor(np.array([[0, 1], [2, 1]]).astype(np.int32)),
|
(Tensor(np.array([[0, 1], [2, 1]]).astype(np.int32)),
|
||||||
Tensor(np.array([[0, 1], [2, 1]]).astype(np.int32)))],
|
Tensor(np.array([[0, 1], [2, 1]]).astype(np.int32)))],
|
||||||
'desc_bprop': [[4, 2]]}),
|
'desc_bprop': [([4, 2], {'dtype': np.int32})]}),
|
||||||
('ConcatV2_1', {
|
('ConcatV2_1', {
|
||||||
'block': P.Concat(axis=2),
|
'block': P.Concat(axis=2),
|
||||||
'desc_inputs': [(Tensor(np.array([[[0, 1, 2]], [[2, 1, 2]]]).astype(np.int32)),
|
'desc_inputs': [(Tensor(np.array([[[0, 1, 2]], [[2, 1, 2]]]).astype(np.int32)),
|
||||||
Tensor(np.array([[[0, 1]], [[2, 1]]]).astype(np.int32)))],
|
Tensor(np.array([[[0, 1]], [[2, 1]]]).astype(np.int32)))],
|
||||||
'desc_bprop': [[2, 1, 5]]}),
|
'desc_bprop': [([2, 1, 5], {'dtype': np.int32})]}),
|
||||||
('ConcatV2_2', {
|
('ConcatV2_2', {
|
||||||
'block': NetForConcat(),
|
'block': NetForConcat(),
|
||||||
'desc_inputs': [[2, 2]],
|
'desc_inputs': [[2, 2]],
|
||||||
|
@ -1042,7 +1042,7 @@ test_case_array_ops = [
|
||||||
('Pack_2', {
|
('Pack_2', {
|
||||||
'block': NetForPackInput(P.Pack()),
|
'block': NetForPackInput(P.Pack()),
|
||||||
'desc_inputs':[[2, 2]],
|
'desc_inputs':[[2, 2]],
|
||||||
'desc_bprop':[[2, 2, 2]],
|
'desc_bprop':[[1, 2, 2]],
|
||||||
}),
|
}),
|
||||||
('Pack_3', {
|
('Pack_3', {
|
||||||
'block': NetForPackInput(P.Pack()),
|
'block': NetForPackInput(P.Pack()),
|
||||||
|
@ -1077,7 +1077,7 @@ test_case_array_ops = [
|
||||||
('SpaceToBatch_2', {
|
('SpaceToBatch_2', {
|
||||||
'block': P.SpaceToBatch(2, [[1, 1], [0, 4]]),
|
'block': P.SpaceToBatch(2, [[1, 1], [0, 4]]),
|
||||||
'desc_inputs': [[1, 3, 2, 2]],
|
'desc_inputs': [[1, 3, 2, 2]],
|
||||||
'desc_bprop': [[4, 3, 2, 4]],
|
'desc_bprop': [[4, 3, 2, 3]],
|
||||||
}),
|
}),
|
||||||
('BatchToSpace_1', {
|
('BatchToSpace_1', {
|
||||||
'block': P.BatchToSpace(2, [[0, 0], [0, 0]]),
|
'block': P.BatchToSpace(2, [[0, 0], [0, 0]]),
|
||||||
|
@ -1124,7 +1124,7 @@ test_case_other_ops = [
|
||||||
'desc_const': [(3, 3)],
|
'desc_const': [(3, 3)],
|
||||||
'desc_inputs': (Tensor(np.ones((2, 2), np.int32)),
|
'desc_inputs': (Tensor(np.ones((2, 2), np.int32)),
|
||||||
Tensor(np.ones((2,), np.int32))),
|
Tensor(np.ones((2,), np.int32))),
|
||||||
'desc_bprop': [[3, 3]]}),
|
'desc_bprop': [([3, 3], {'dtype': np.int32})]}),
|
||||||
('SmoothL1Loss', {
|
('SmoothL1Loss', {
|
||||||
'block': P.SmoothL1Loss(),
|
'block': P.SmoothL1Loss(),
|
||||||
'desc_inputs': [[256, 4], [256, 4]],
|
'desc_inputs': [[256, 4], [256, 4]],
|
||||||
|
|
|
@ -229,12 +229,6 @@ class TwoInputBprop(nn.Cell):
|
||||||
def bprop(self, x, y, out, dout):
|
def bprop(self, x, y, out, dout):
|
||||||
return 5 * x, 8 * y
|
return 5 * x, 8 * y
|
||||||
|
|
||||||
class TwoInput(nn.Cell):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
self.op = P.Mul()
|
|
||||||
def construct(self, x, y):
|
|
||||||
return self.op(x, y)
|
|
||||||
|
|
||||||
class TwoInputWithParameter(nn.Cell):
|
class TwoInputWithParameter(nn.Cell):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -301,8 +295,37 @@ class MulAddWithWrongOutputNum(nn.Cell):
|
||||||
def construct(self, x, y):
|
def construct(self, x, y):
|
||||||
return 2 * x + y
|
return 2 * x + y
|
||||||
def bprop(self, x, y, out, dout):
|
def bprop(self, x, y, out, dout):
|
||||||
return 2 * dout, 2 * y, out
|
return 2 * dout,
|
||||||
|
|
||||||
def test_grad_mul_add_with_wrong_output_num():
|
def test_grad_mul_add_with_wrong_output_num():
|
||||||
mul_add = MulAddWithWrongOutputNum()
|
mul_add = MulAddWithWrongOutputNum()
|
||||||
|
with pytest.raises(TypeError):
|
||||||
C.grad_all(mul_add)(1, 2)
|
C.grad_all(mul_add)(1, 2)
|
||||||
|
|
||||||
|
class MulAddWithWrongOutputType(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(MulAddWithWrongOutputType, self).__init__()
|
||||||
|
def construct(self, x, y):
|
||||||
|
return 2 * x + y
|
||||||
|
def bprop(self, x, y, out, dout):
|
||||||
|
return 2 * dout, 2
|
||||||
|
|
||||||
|
def test_grad_mul_add_with_wrong_output_type():
|
||||||
|
mul_add = MulAddWithWrongOutputType()
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
C.grad_all(mul_add)(1, Tensor(np.ones([2, 2])))
|
||||||
|
|
||||||
|
|
||||||
|
class MulAddWithWrongOutputShape(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(MulAddWithWrongOutputShape, self).__init__()
|
||||||
|
self.ones = Tensor(np.ones([2,]))
|
||||||
|
def construct(self, x, y):
|
||||||
|
return 2 * x + y
|
||||||
|
def bprop(self, x, y, out, dout):
|
||||||
|
return 2, self.ones
|
||||||
|
|
||||||
|
def test_grad_mul_add_with_wrong_output_shape():
|
||||||
|
mul_add = MulAddWithWrongOutputShape()
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
C.grad_all(mul_add)(1, Tensor(np.ones([2, 2])))
|
||||||
|
|
|
@ -32,6 +32,8 @@ from ....mindspore_test_framework.utils.check_gradient import (
|
||||||
OperationGradChecker, check_gradient, ScalarGradChecker)
|
OperationGradChecker, check_gradient, ScalarGradChecker)
|
||||||
from ....mindspore_test_framework.utils.bprop_util import bprop
|
from ....mindspore_test_framework.utils.bprop_util import bprop
|
||||||
import mindspore.context as context
|
import mindspore.context as context
|
||||||
|
from mindspore.ops._grad.grad_base import bprop_getters
|
||||||
|
from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer
|
||||||
|
|
||||||
|
|
||||||
def setup_module(module):
|
def setup_module(module):
|
||||||
|
@ -721,3 +723,94 @@ def test_grad_if_defer_inline():
|
||||||
inp = Tensor(np.ones([128, 96]).astype(np.float32))
|
inp = Tensor(np.ones([128, 96]).astype(np.float32))
|
||||||
grads = C.grad_all(network)(inp)
|
grads = C.grad_all(network)(inp)
|
||||||
assert grads == (Tensor(np.full([128, 96], 0.6, dtype=np.float32)),)
|
assert grads == (Tensor(np.full([128, 96], 0.6, dtype=np.float32)),)
|
||||||
|
|
||||||
|
def test_bprop_with_wrong_output_num():
|
||||||
|
class BpropWithWrongOutputNum(PrimitiveWithInfer):
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self):
|
||||||
|
super(BpropWithWrongOutputNum, self).__init__('BpropWithWrongOutputNum')
|
||||||
|
|
||||||
|
def __call__(self, x, y):
|
||||||
|
return x
|
||||||
|
|
||||||
|
def infer_shape(self, x_shape, yshape):
|
||||||
|
return x_shape
|
||||||
|
|
||||||
|
def infer_dtype(self, x_type, y_type):
|
||||||
|
return x_type
|
||||||
|
|
||||||
|
@bprop_getters.register(BpropWithWrongOutputNum)
|
||||||
|
def get_bprop_with_wrong_output_num(self):
|
||||||
|
"""Generate bprop for BpropWithWrongOutputNum"""
|
||||||
|
def bprop(x, y, out, dout):
|
||||||
|
return (dout,)
|
||||||
|
return bprop
|
||||||
|
|
||||||
|
class BpropWithWrongOutputNumCell(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(BpropWithWrongOutputNumCell, self).__init__()
|
||||||
|
def construct(self, x, y):
|
||||||
|
return BpropWithWrongOutputNum()(x, y)
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
C.grad_all(BpropWithWrongOutputNumCell())(1, 2)
|
||||||
|
|
||||||
|
def test_bprop_with_wrong_output_type():
|
||||||
|
class BpropWithWrongOutputType(PrimitiveWithInfer):
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self):
|
||||||
|
super(BpropWithWrongOutputType, self).__init__('BpropWithWrongOutputType')
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
def infer_shape(self, x_shape):
|
||||||
|
return x_shape
|
||||||
|
|
||||||
|
def infer_dtype(self, x_type):
|
||||||
|
return x_type
|
||||||
|
|
||||||
|
@bprop_getters.register(BpropWithWrongOutputType)
|
||||||
|
def get_bprop_with_wrong_output_type(self):
|
||||||
|
"""Generate bprop for BpropWithWrongOutputType"""
|
||||||
|
def bprop(x, out, dout):
|
||||||
|
return (1,)
|
||||||
|
return bprop
|
||||||
|
|
||||||
|
class BpropWithWrongOutputTypeCell(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(BpropWithWrongOutputTypeCell, self).__init__()
|
||||||
|
def construct(self, x):
|
||||||
|
return BpropWithWrongOutputType()(x)
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
C.grad_all(BpropWithWrongOutputTypeCell())(Tensor(np.ones([64, 10]).astype(np.int32)))
|
||||||
|
|
||||||
|
def test_bprop_with_wrong_output_shape():
|
||||||
|
class BpropWithWrongOutputShape(PrimitiveWithInfer):
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self):
|
||||||
|
super(BpropWithWrongOutputShape, self).__init__('BpropWithWrongOutputShape')
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
def infer_shape(self, x_shape):
|
||||||
|
return x_shape
|
||||||
|
|
||||||
|
def infer_dtype(self, x_type):
|
||||||
|
return x_type
|
||||||
|
|
||||||
|
@bprop_getters.register(BpropWithWrongOutputShape)
|
||||||
|
def get_bprop_with_wrong_output_shape(self):
|
||||||
|
"""Generate bprop for BpropWithWrongOutputShape"""
|
||||||
|
ones = Tensor(np.ones([2,]).astype(np.int32))
|
||||||
|
def bprop(x, out, dout):
|
||||||
|
return (ones,)
|
||||||
|
return bprop
|
||||||
|
|
||||||
|
class BpropWithWrongOutputShapeCell(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(BpropWithWrongOutputShapeCell, self).__init__()
|
||||||
|
def construct(self, x):
|
||||||
|
return BpropWithWrongOutputShape()(x)
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
C.grad_all(BpropWithWrongOutputShapeCell())(Tensor(np.ones([64, 10]).astype(np.int32)))
|
||||||
|
|
|
@ -79,7 +79,7 @@ def test_InsertGradientOf_2():
|
||||||
summary = P.ScalarSummary()
|
summary = P.ScalarSummary()
|
||||||
def debug_gradient(dx):
|
def debug_gradient(dx):
|
||||||
""" debug_gradient """
|
""" debug_gradient """
|
||||||
dx = summary("dx: ", dx)
|
summary("dx: ", dx)
|
||||||
return dx
|
return dx
|
||||||
|
|
||||||
debug = P.InsertGradientOf(debug_gradient)
|
debug = P.InsertGradientOf(debug_gradient)
|
||||||
|
|
Loading…
Reference in New Issue