Fix bug of switch layer join
This commit is contained in:
parent
a117b6dc14
commit
947e19b839
|
@ -283,9 +283,99 @@ bool ConvertOtherObj(py::object obj, ValuePtr *const data) {
|
|||
MS_LOG(ERROR) << "Resolve type is invalid " << ((std::string)py::str(obj));
|
||||
return false;
|
||||
}
|
||||
|
||||
bool ConvertIntegerWithType(const int &obj, ValuePtr *const data, TypePtr dtype = nullptr) {
|
||||
if (dtype == nullptr) {
|
||||
*data = std::make_shared<Int32Imm>(obj);
|
||||
return true;
|
||||
}
|
||||
|
||||
auto int_dypte = dyn_cast<Int>(dtype);
|
||||
if (int_dypte != nullptr) {
|
||||
switch (int_dypte->nbits()) {
|
||||
case 8:
|
||||
*data = std::make_shared<Int8Imm>(static_cast<int8_t>(obj));
|
||||
break;
|
||||
case 16:
|
||||
*data = std::make_shared<Int16Imm>(obj);
|
||||
break;
|
||||
case 32:
|
||||
*data = std::make_shared<Int32Imm>(obj);
|
||||
break;
|
||||
case 64:
|
||||
*data = std::make_shared<Int64Imm>(obj);
|
||||
break;
|
||||
default:
|
||||
*data = std::make_shared<Int32Imm>(obj);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
auto uint_dypte = dyn_cast<UInt>(dtype);
|
||||
if (int_dypte != nullptr) {
|
||||
switch (uint_dypte->nbits()) {
|
||||
case 8:
|
||||
*data = std::make_shared<UInt8Imm>(obj);
|
||||
break;
|
||||
case 16:
|
||||
*data = std::make_shared<UInt16Imm>(obj);
|
||||
break;
|
||||
case 32:
|
||||
*data = std::make_shared<UInt32Imm>(obj);
|
||||
break;
|
||||
case 64:
|
||||
*data = std::make_shared<UInt64Imm>(obj);
|
||||
break;
|
||||
default:
|
||||
*data = std::make_shared<UInt32Imm>(obj);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
auto float_dypte = dyn_cast<Float>(dtype);
|
||||
if (float_dypte != nullptr) {
|
||||
switch (float_dypte->nbits()) {
|
||||
case 32:
|
||||
*data = std::make_shared<FP32Imm>(obj);
|
||||
break;
|
||||
case 64:
|
||||
*data = std::make_shared<FP64Imm>(obj);
|
||||
break;
|
||||
default:
|
||||
*data = std::make_shared<FP32Imm>(obj);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool ConvertFloatWithType(const float &obj, ValuePtr *const data, TypePtr dtype = nullptr) {
|
||||
if (dtype == nullptr) {
|
||||
*data = std::make_shared<FP32Imm>(obj);
|
||||
return true;
|
||||
}
|
||||
|
||||
auto float_dypte = dyn_cast<Float>(dtype);
|
||||
if (float_dypte == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
switch (float_dypte->nbits()) {
|
||||
case 32:
|
||||
*data = std::make_shared<FP32Imm>(obj);
|
||||
break;
|
||||
case 64:
|
||||
*data = std::make_shared<FP64Imm>(obj);
|
||||
break;
|
||||
default:
|
||||
*data = std::make_shared<FP32Imm>(obj);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature) {
|
||||
bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature, TypePtr dtype) {
|
||||
// check parameter valid
|
||||
if (data == nullptr) {
|
||||
MS_LOG(ERROR) << "Data is null pointer";
|
||||
|
@ -299,9 +389,9 @@ bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature
|
|||
} else if (py::isinstance<py::bool_>(obj)) {
|
||||
converted = std::make_shared<BoolImm>(py::cast<bool>(obj));
|
||||
} else if (py::isinstance<py::int_>(obj)) {
|
||||
converted = std::make_shared<Int32Imm>(py::cast<int>(obj));
|
||||
ret = ConvertIntegerWithType(py::cast<int>(obj), &converted, dtype);
|
||||
} else if (py::isinstance<py::float_>(obj)) {
|
||||
converted = std::make_shared<FP32Imm>(py::cast<float>(obj));
|
||||
ret = ConvertFloatWithType(py::cast<float>(obj), &converted, dtype);
|
||||
} else if (py::isinstance<py::str>(obj)) {
|
||||
converted = std::make_shared<StringImm>(py::cast<std::string>(obj));
|
||||
} else if (py::isinstance<py::dict>(obj)) {
|
||||
|
|
|
@ -139,7 +139,7 @@ enum ClassInstanceTypeDef {
|
|||
};
|
||||
|
||||
// Convert python object to ValuePtr
|
||||
bool ConvertData(const py::object &obj, ValuePtr *data, bool use_signature = false);
|
||||
bool ConvertData(const py::object &obj, ValuePtr *data, bool use_signature = false, TypePtr dtype = nullptr);
|
||||
|
||||
// Convert python obj to graph
|
||||
FuncGraphPtr ConvertToFuncGraph(const py::object &obj,
|
||||
|
|
|
@ -407,9 +407,9 @@ py::tuple PreparePyInputs(const PrimitivePyPtr &prim_py, const AbstractBasePtrLi
|
|||
|
||||
AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dict &output) {
|
||||
// Convert to AbstractValue based on type and shape
|
||||
auto out_dtype = output["dtype"];
|
||||
if (output["value"].is_none()) {
|
||||
auto out_shape = output["shape"];
|
||||
auto out_dtype = output["dtype"];
|
||||
py::object min_shape = output.contains("min_shape") ? (py::object)output["min_shape"] : (py::object)py::none();
|
||||
py::object max_shape = output.contains("max_shape") ? (py::object)output["max_shape"] : (py::object)py::none();
|
||||
|
||||
|
@ -417,7 +417,8 @@ AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dic
|
|||
}
|
||||
// Convert pyobject to Value, then to AbstractValue
|
||||
ValuePtr converted_ret = nullptr;
|
||||
bool converted = parse::ConvertData(output["value"], &converted_ret);
|
||||
TypePtr dtype = py::isinstance<Type>(out_dtype) ? out_dtype.cast<TypePtr>() : nullptr;
|
||||
bool converted = parse::ConvertData(output["value"], &converted_ret, false, dtype);
|
||||
if (!converted) {
|
||||
MS_LOG(EXCEPTION) << "Convert data failed";
|
||||
}
|
||||
|
|
|
@ -45,14 +45,34 @@ py::object ValuePtrToPyData(const ValuePtr &value) {
|
|||
MS_LOG(EXCEPTION) << "value is null";
|
||||
}
|
||||
py::object ret;
|
||||
if (value->isa<Int32Imm>()) {
|
||||
MS_LOG(DEBUG) << "int";
|
||||
if (value->isa<Int8Imm>()) {
|
||||
MS_LOG(DEBUG) << "int8";
|
||||
py::int_ v = value->cast<Int8ImmPtr>()->value();
|
||||
ret = v;
|
||||
} else if (value->isa<Int16Imm>()) {
|
||||
MS_LOG(DEBUG) << "int16";
|
||||
py::int_ v = value->cast<Int16ImmPtr>()->value();
|
||||
ret = v;
|
||||
} else if (value->isa<Int32Imm>()) {
|
||||
MS_LOG(DEBUG) << "int32";
|
||||
py::int_ v = value->cast<Int32ImmPtr>()->value();
|
||||
ret = v;
|
||||
} else if (value->isa<Int64Imm>()) {
|
||||
MS_LOG(DEBUG) << "int64";
|
||||
py::int_ v = value->cast<Int64ImmPtr>()->value();
|
||||
ret = v;
|
||||
} else if (value->isa<UInt8Imm>()) {
|
||||
MS_LOG(DEBUG) << "uint8";
|
||||
py::int_ v = value->cast<UInt8ImmPtr>()->value();
|
||||
ret = v;
|
||||
} else if (value->isa<UInt16Imm>()) {
|
||||
MS_LOG(DEBUG) << "uint16";
|
||||
py::int_ v = value->cast<UInt16ImmPtr>()->value();
|
||||
ret = v;
|
||||
} else if (value->isa<UInt32Imm>()) {
|
||||
MS_LOG(DEBUG) << "uint32";
|
||||
py::int_ v = value->cast<UInt32ImmPtr>()->value();
|
||||
ret = v;
|
||||
} else if (value->isa<UInt64Imm>()) {
|
||||
MS_LOG(DEBUG) << "uint64";
|
||||
py::int_ v = value->cast<UInt64ImmPtr>()->value();
|
||||
|
|
|
@ -97,8 +97,12 @@ AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) {
|
|||
}
|
||||
auto value_self = GetValueTrack();
|
||||
MS_EXCEPTION_IF_NULL(value_self);
|
||||
ValuePtr res_value = ValueJoin(value_self, other->GetValueTrack());
|
||||
TypePtr res_type = TypeJoin(GetTypeTrack(), other->GetTypeTrack());
|
||||
if (res_type == kAnyType) {
|
||||
MS_EXCEPTION(TypeError) << "Type join failed, type1 = " << GetTypeTrack()->ToString()
|
||||
<< ", type2 = " << other->GetTypeTrack()->ToString();
|
||||
}
|
||||
ValuePtr res_value = ValueJoin(value_self, other->GetValueTrack());
|
||||
if (res_value == value_self) {
|
||||
return shared_from_base<AbstractBase>();
|
||||
}
|
||||
|
|
|
@ -50,9 +50,17 @@ ShapePtr ShapeJoin(const ShapePtr &shape1, const ShapePtr &shape2) {
|
|||
if (*shape1 == *shape2) {
|
||||
return shape1;
|
||||
}
|
||||
// lengths of two shapes are not same, join failed
|
||||
if (shape1->shape().size() != shape2->shape().size()) {
|
||||
MS_LOG(WARNING) << "Unsupported shape join. shape1 = " << shape1->ToString() << ", shape2 = " << shape2->ToString();
|
||||
return shape1;
|
||||
// special case: shape(1), shape() -> shape(1)
|
||||
if (shape1->shape().size() == 1 && shape1->shape()[0] == 1 && shape2->shape().size() == 0) {
|
||||
return shape1;
|
||||
}
|
||||
if (shape2->shape().size() == 1 && shape2->shape()[0] == 1 && shape1->shape().size() == 0) {
|
||||
return shape2;
|
||||
}
|
||||
MS_EXCEPTION(ValueError) << "Unsupported shape join. shape1 = " << shape1->ToString()
|
||||
<< ", shape2 = " << shape2->ToString();
|
||||
}
|
||||
std::vector<int> dims;
|
||||
bool has_dynamic_shape = false;
|
||||
|
|
|
@ -105,7 +105,7 @@ class Int8Imm : public IntergerImm {
|
|||
|
||||
std::string DumpText() const override {
|
||||
std::ostringstream oss;
|
||||
oss << "I8(" << v_ << ")";
|
||||
oss << "I8(" << int(v_) << ")";
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
|
@ -131,7 +131,7 @@ class Int16Imm : public IntergerImm {
|
|||
|
||||
std::string DumpText() const override {
|
||||
std::ostringstream oss;
|
||||
oss << "I16(" << v_ << ")";
|
||||
oss << "I16(" << int(v_) << ")";
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
|
@ -157,7 +157,7 @@ class Int32Imm : public IntergerImm {
|
|||
|
||||
std::string DumpText() const override {
|
||||
std::ostringstream oss;
|
||||
oss << "I32(" << v_ << ")";
|
||||
oss << "I32(" << int(v_) << ")";
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
|
@ -211,7 +211,7 @@ class UInt8Imm : public IntergerImm {
|
|||
|
||||
std::string DumpText() const override {
|
||||
std::ostringstream oss;
|
||||
oss << "U8(" << v_ << ")";
|
||||
oss << "U8(" << unsigned(v_) << ")";
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
|
@ -239,7 +239,7 @@ class UInt16Imm : public IntergerImm {
|
|||
|
||||
std::string DumpText() const override {
|
||||
std::ostringstream oss;
|
||||
oss << "U16(" << v_ << ")";
|
||||
oss << "U16(" << unsigned(v_) << ")";
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
|
@ -267,7 +267,7 @@ class UInt32Imm : public IntergerImm {
|
|||
|
||||
std::string DumpText() const override {
|
||||
std::ostringstream oss;
|
||||
oss << "U32(" << v_ << ")";
|
||||
oss << "U32(" << unsigned(v_) << ")";
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
|
|
|
@ -324,7 +324,7 @@ class ScalarGradChecker(_GradChecker):
|
|||
self.input_selector = [i for i in range(self.nin)]
|
||||
|
||||
def get_sens(self, i):
|
||||
return 1
|
||||
return 1.0
|
||||
|
||||
def check_against_numeric(self, out_index):
|
||||
args = list(self.args)
|
||||
|
|
|
@ -911,3 +911,73 @@ def test_recursive_call():
|
|||
with pytest.raises(RuntimeError):
|
||||
net(input_data)
|
||||
context.set_context(max_call_depth=old_max_call_depth)
|
||||
|
||||
|
||||
def test_switch_layer_shape_join_failed():
|
||||
class AddFuncNet(nn.Cell):
|
||||
def __init__(self, funcs, new_func):
|
||||
super(AddFuncNet, self).__init__()
|
||||
self.funcs = funcs
|
||||
self.new_func = new_func
|
||||
|
||||
def construct(self, i, inputs):
|
||||
final_funcs = self.funcs + (self.new_func,)
|
||||
x = final_funcs[i](inputs)
|
||||
return x
|
||||
|
||||
class ReLUTuple(nn.Cell):
|
||||
def __init__(self):
|
||||
super(ReLUTuple, self).__init__()
|
||||
self.op = nn.ReLU()
|
||||
|
||||
def construct(self, x):
|
||||
return self.op(x[0])
|
||||
|
||||
func1 = nn.Softmax()
|
||||
func2 = nn.ReLU()
|
||||
func3 = ReLUTuple()
|
||||
|
||||
funcs = (func1, func2)
|
||||
|
||||
|
||||
net = AddFuncNet(funcs, func3)
|
||||
|
||||
inp = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
|
||||
i = Tensor(1, mstype.int32)
|
||||
with pytest.raises(ValueError) as err:
|
||||
net(i, inp)
|
||||
|
||||
|
||||
def test_switch_layer_dtype_join_failed():
|
||||
class Cast(nn.Cell):
|
||||
def __init__(self, dtype):
|
||||
super(Cast, self).__init__()
|
||||
self.op = P.Cast()
|
||||
self.dtype = dtype
|
||||
|
||||
def construct(self, x):
|
||||
y = self.op(x, self.dtype)
|
||||
return y + y
|
||||
|
||||
class SwitchNegNet(nn.Cell):
|
||||
def __init__(self, funcs):
|
||||
super(SwitchNegNet, self).__init__()
|
||||
self.funcs = funcs
|
||||
self.op = P.Neg()
|
||||
|
||||
def construct(self, i, inputs):
|
||||
x = self.funcs[i](inputs)
|
||||
x = self.op(x)
|
||||
return x
|
||||
|
||||
|
||||
func1 = nn.ReLU()
|
||||
func2 = Cast(mstype.int32)
|
||||
funcs = (func1, func2)
|
||||
net = SwitchNegNet(funcs)
|
||||
|
||||
inp = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
|
||||
i = Tensor(0, mstype.int32)
|
||||
|
||||
with pytest.raises(TypeError) as err:
|
||||
net(i, inp)
|
||||
|
|
|
@ -33,6 +33,7 @@ from ....mindspore_test_framework.pipeline.forward.compile_forward \
|
|||
pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception)
|
||||
from ....mindspore_test_framework.pipeline.gradient.compile_gradient \
|
||||
import pipeline_for_compile_grad_ge_graph_for_case_by_case_config
|
||||
from ....ops_common import convert
|
||||
|
||||
class InputBackward(nn.Cell):
|
||||
def __init__(self, network):
|
||||
|
@ -1699,7 +1700,7 @@ test_case_nn_ops = [
|
|||
('ResizeBilinear', {
|
||||
'block': P.ResizeBilinear((5, 5)),
|
||||
'desc_inputs': [Tensor([[[[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]]], mstype.float16)],
|
||||
'desc_bprop': [Tensor([[[[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]]], mstype.float16)]}),
|
||||
'desc_bprop': [Tensor([[[[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]]], mstype.float32)]}),
|
||||
('ResizeBilinearGrad', {
|
||||
'block': G.ResizeBilinearGrad(),
|
||||
'desc_inputs': [Tensor([[[[1, 2, 3, 4, 5]]]], mstype.float32), Tensor([[[[1, 2, 3, 4, 5]]]], mstype.float32)],
|
||||
|
@ -1708,7 +1709,7 @@ test_case_nn_ops = [
|
|||
('ROIAlign', {
|
||||
'block': P.ROIAlign(7, 7, 0.03125, 2),
|
||||
'desc_inputs': [[2, 256, 192, 320], [1024, 5]],
|
||||
'desc_bprop': [[7, 7]]}),
|
||||
'desc_bprop': [[1024, 256, 7, 7]]}),
|
||||
('ROIAlignGrad', {
|
||||
'block': G.ROIAlignGrad((1, 1, 1, 1), 2, 2, 0.5, 2),
|
||||
'desc_inputs': [[1, 1, 2, 2], [1, 5]],
|
||||
|
@ -2311,7 +2312,7 @@ test_case_other_ops = [
|
|||
('IOU', {
|
||||
'block': P.IOU(),
|
||||
'desc_inputs': [Tensor(np.ones((256, 4), np.float16)), Tensor(np.ones((128, 4), np.float16))],
|
||||
'desc_bprop': [[128, 256]]}),
|
||||
'desc_bprop': [convert([128, 256], np.float16)]}),
|
||||
('Summary', {
|
||||
'block': SummaryNet(),
|
||||
'desc_inputs': [Tensor(np.array([1.1]).astype(np.float32)),
|
||||
|
|
|
@ -118,29 +118,29 @@ test_case_reid_ops = [
|
|||
'desc_inputs': [[256, 8]],
|
||||
'desc_bprop': [[256, 8]]}),
|
||||
('Pow', {
|
||||
'block': P.Pow(), # 输入有标量插件产生了段错误。
|
||||
'block': P.Pow(),
|
||||
'desc_const': [2.0],
|
||||
'desc_inputs': [[1, 512]],
|
||||
'desc_bprop': [[1, 512]]}),
|
||||
('LogicalNot', {
|
||||
'block': P.LogicalNot(),
|
||||
'desc_inputs': [convert([256], np.bool_)],
|
||||
'desc_bprop': [[256]]}), # 自定义算子 input bool没转换,gongchen提单。
|
||||
'desc_bprop': [convert([256], np.bool_)]}),
|
||||
('Equal', {
|
||||
'block': P.Equal(),
|
||||
'desc_inputs': [convert([256], np.float16), convert([256], np.float16)],
|
||||
'desc_bprop': [[256]]}),
|
||||
'desc_bprop': [convert([256], np.bool_)]}),
|
||||
('Greater', {
|
||||
'block': P.Greater(),
|
||||
'desc_inputs': [convert([256], np.float16), convert([256], np.float16)],
|
||||
'desc_bprop': [[256]]}),
|
||||
'desc_bprop': [convert([256], np.bool_)]}),
|
||||
('Dropout', {
|
||||
'block': nn.Dropout(),
|
||||
'desc_inputs': [[1, 512, 7, 7]],
|
||||
'desc_bprop': [[1, 512, 7, 7]]}), # 输入有标量插件产生了段错误。
|
||||
'desc_bprop': [[1, 512, 7, 7]]}),
|
||||
('MatMul', {
|
||||
'block': P.MatMul(),
|
||||
'desc_inputs': [[64, 512], [512, 64]], # fp16不行。很有问题。
|
||||
'desc_inputs': [[64, 512], [512, 64]],
|
||||
'desc_bprop': [[64, 64]]}),
|
||||
('Maximum', {
|
||||
'block': P.Maximum(),
|
||||
|
|
|
@ -77,8 +77,8 @@ class Bprop(Cell):
|
|||
self.grad = grad_op
|
||||
self.with_sens = False
|
||||
self.sens = sens
|
||||
if sens:
|
||||
self.sens = Tensor(sens, dtype=mstype.float32)
|
||||
if not sens is None:
|
||||
self.sens = sens if isinstance(sens, Tensor) else Tensor(sens, dtype=mstype.float32)
|
||||
self.with_sens = True
|
||||
|
||||
def construct(self, *inputs):
|
||||
|
@ -108,7 +108,7 @@ def test_all_var_args_grad_with_sens():
|
|||
|
||||
x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
|
||||
y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
|
||||
sens = Tensor(1.0, dtype=mstype.float32)
|
||||
sens = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
|
||||
net = VarNet(SecondNet())
|
||||
grad_net = GradNet(net)
|
||||
_ = grad_net(x, y, sens)
|
||||
|
@ -160,7 +160,7 @@ def test_grad_all_var_args_with_sens():
|
|||
|
||||
x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
|
||||
y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
|
||||
sens = Tensor(1.0, dtype=mstype.float32)
|
||||
sens = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
|
||||
net = VarNet(SecondNet())
|
||||
grad_net = GradNet(net)
|
||||
_ = grad_net(x, y, sens)
|
||||
|
@ -178,7 +178,7 @@ def test_grad_var_args_with_sens():
|
|||
|
||||
x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
|
||||
y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
|
||||
sens = Tensor(1.0, dtype=mstype.float32)
|
||||
sens = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
|
||||
net = VarNet(SecondNet())
|
||||
grad_net = GradNet(net)
|
||||
_ = grad_net(x, y, sens)
|
||||
|
@ -237,7 +237,7 @@ def test_var_args_grad():
|
|||
|
||||
x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
|
||||
y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
|
||||
sens = Tensor(1.0, dtype=mstype.float32)
|
||||
sens = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
|
||||
net = VarNet(SecondNet())
|
||||
grad_net = GradNet(net)
|
||||
_ = grad_net(x, y, sens)
|
||||
|
@ -285,14 +285,14 @@ def test_grad_within_if_else():
|
|||
self.net = net
|
||||
grad_op = C.GradOperation(
|
||||
name='grad', get_all=False, get_by_list=True, sens_param=True)
|
||||
self.grad = Bprop(self.net, True, self.weights, grad_op, 1.0)
|
||||
sens = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
|
||||
self.grad = Bprop(self.net, True, self.weights, grad_op, sens)
|
||||
|
||||
def construct(self, *inputs):
|
||||
return self.grad(*inputs)
|
||||
|
||||
x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
|
||||
y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
|
||||
_ = Tensor(1.0, dtype=mstype.float32)
|
||||
net = VarNet(SecondNet())
|
||||
grad_net = GradNet(net)
|
||||
out = grad_net(x, y)
|
||||
|
|
Loading…
Reference in New Issue