improve grad of first input

parent b6183f718f
commit 899d6114a4
@@ -400,8 +400,18 @@ FuncGraphPtr Tail::GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr &
     op = prim::kPrimListGetItem;
   }
 
+  if (tail_type_ == kGradFirst) {
+    if (sequeue->size() > 1 && (*sequeue)[1] != nullptr && (*sequeue)[1]->isa<abstract::AbstractUndetermined>()) {
+      ret->set_output(ret->NewCNode({NewValueNode(op), ptrTup, NewValueNode(SizeToLong(1))}));
+    } else {
+      ret->set_output(NewValueNode(std::make_shared<ValueTuple>(std::vector<ValuePtr>{})));
+    }
+
+    return ret;
+  }
+
   for (size_t i = 1; i < sequeue->size(); ++i) {
-    if (do_grad_) {
+    if (tail_type_ == kGradAll) {
       MS_EXCEPTION_IF_NULL((*sequeue)[i]);
       if ((*sequeue)[i]->isa<abstract::AbstractUndetermined>()) {
         elems.push_back(ret->NewCNode({NewValueNode(op), ptrTup, NewValueNode(SizeToLong(i))}));
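Note: as a rough plain-Python analogy (illustrative helper names, not MindSpore API), the three Tail modes transform a bprop output of the form (EnvInstance, grad_x0, grad_x1, ...) like this:

    # Sketch of Tail's three modes; is_grad_like stands in for the
    # C++ isa<abstract::AbstractUndetermined> check on tensor-like grads.
    def tail(outputs, mode, is_grad_like=lambda g: g is not None):
        env, *grads = outputs  # drop EnvInstance (grads wrt params)
        if mode == "kGradFirst":
            # The new branch above: grad of the first input if present and
            # tensor-like, otherwise an empty tuple.
            return grads[0] if grads and is_grad_like(grads[0]) else ()
        if mode == "kGradAll":
            # Keep only tensor-like grads, mirroring the loop's isa<> filter.
            return tuple(g for g in grads if is_grad_like(g))
        return tuple(grads)  # kNotGrad: plain tuple tail

    print(tail(("env", 1.0, None), "kGradFirst"))  # 1.0
    print(tail(("env", None, 2.0), "kGradFirst"))  # ()
    print(tail(("env", 1.0, None), "kGradAll"))    # (1.0,)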
@@ -581,8 +591,8 @@ void GradOperation::GradByParameter(const FuncGraphPtr &k_child, const AnfNodePt
 
   CNodePtr inputs_bprop = nullptr;
   if (get_all_) {
-    TailPtr tail = std::make_shared<Tail>("tail", true);
-    inputs_bprop = k_child->NewCNode({NewValueNode(tail), b_app});
+    TailPtr tail_grad_all = std::make_shared<Tail>("tail_grad_all", kGradAll);
+    inputs_bprop = k_child->NewCNode({NewValueNode(tail_grad_all), b_app});
   }
 
   // Gradients wrt inputs and parameters
@@ -602,11 +612,12 @@ void GradOperation::GradByParameter(const FuncGraphPtr &k_child, const AnfNodePt
     k_child->set_output(inputs_bprop);
     return;
   }
 
   // Gradients wrt first input.
-  // b_app returns (EnvInstance(grads wrt params), grads wrt input0, grads wrt input1, ...), so 1 is for first input
-  k_child->set_output(
-    k_child->NewCNode({NewValueNode(prim::kPrimTupleGetItem), b_app, NewValueNode(static_cast<int64_t>(1))}));
+  // b_app returns (EnvInstance(grads wrt params), grads wrt input0, grads wrt input1, ...),
+  // so obtain first input grad by setting tail_type of Tail to kGradFirst.
+  TailPtr tail_grad_first = std::make_shared<Tail>("tail_grad_first", kGradFirst);
+  k_child->set_output(k_child->NewCNode({NewValueNode(tail_grad_first), b_app}));
 }
 
 // Generate the graph.
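Note: at the Python front end these two Tail instances back GradOperation's get_all flag. A minimal usage sketch (assumes a working MindSpore install; the toy MulNet is illustrative):

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor, context
    from mindspore.ops import composite as C

    context.set_context(mode=context.GRAPH_MODE)

    class MulNet(nn.Cell):
        def construct(self, x, y):
            return x * y

    x = Tensor(np.ones((2, 2), np.float32))
    y = Tensor(np.full((2, 2), 3.0, np.float32))

    # get_all=True -> Tail("tail_grad_all", kGradAll): one grad per tensor input.
    grad_x, grad_y = C.GradOperation(get_all=True)(MulNet())(x, y)

    # get_all=False -> Tail("tail_grad_first", kGradFirst): grad of the first
    # input only, no longer fetched with a raw tuple_getitem at index 1.
    grad_x_only = C.GradOperation(get_all=False)(MulNet())(x, y)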
@@ -97,9 +97,11 @@ using HyperMapPyPtr = std::shared_ptr<HyperMapPy>;
 
 extern ValuePtr kCompositeHyperMap;
 
+enum TailType { kGradAll, kGradFirst, kNotGrad };
+
 class Tail : public MetaFuncGraph {
  public:
-  explicit Tail(const std::string &name, bool do_grad = false) : MetaFuncGraph(name), do_grad_(do_grad) {}
+  explicit Tail(const std::string &name, TailType tail_type = kNotGrad) : MetaFuncGraph(name), tail_type_(tail_type) {}
   ~Tail() override = default;
   MS_DECLARE_PARENT(Tail, MetaFuncGraph)
 
@@ -109,7 +111,7 @@ class Tail : public MetaFuncGraph {
   friend bool operator==(const Tail &lhs, const Tail &rhs) { return lhs.name_ == rhs.name_; }
 
  private:
-  bool do_grad_;
+  TailType tail_type_;
 };
 using TailPtr = std::shared_ptr<Tail>;
 
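Note: the enum replaces the old do_grad_ boolean so a single Tail can serve three callers. Informal mapping, inferred from the hunks above (a summary, not an API table):

    # TailType       front-end trigger                 result
    # kGradAll    <- GradOperation(get_all=True)       grads of all tensor inputs
    # kGradFirst  <- GradOperation(get_all=False)      grad of the first input only
    # kNotGrad    <- default (plain sequence tail)     elements [1:], element 0 dropped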
@@ -24,9 +24,9 @@ from mindspore.ops import composite as C
 context.set_context(mode=context.GRAPH_MODE)
 
 
-class Net(nn.Cell):
+class FirstInputTupleNet(nn.Cell):
     def __init__(self):
-        super(Net, self).__init__()
+        super(FirstInputTupleNet, self).__init__()
 
     def construct(self, tuple_a, tensor_x, list_b, tensor_y, scalar, dict_c, flag):
         if flag:
@@ -35,11 +35,11 @@ class Net(nn.Cell):
 
 
 class GradNet(nn.Cell):
-    def __init__(self, net):
+    def __init__(self, net, get_all):
         super(GradNet, self).__init__()
         self.forward_net = net
         self.sens = Tensor(np.ones((2, 2), np.float32) * 5)
-        self.grad_all = C.GradOperation(get_all=True)
+        self.grad_all = C.GradOperation(get_all=get_all)
 
     def construct(self, tuple_a, tensor_x, list_b, tensor_y, scalar, dict_c, flag):
         return self.grad_all(self.forward_net)(tuple_a, tensor_x, list_b, tensor_y, scalar, dict_c, flag)
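Note: with get_all now a constructor argument, one GradNet wrapper serves both gradient modes; the tests below instantiate it both ways, e.g. (grad_first_input_net is an illustrative name, not from the diff):

    grad_all_inputs_net = GradNet(forward_net, get_all=True)    # tuple of grads, one per tensor input
    grad_first_input_net = GradNet(forward_net, get_all=False)  # grad of the first input only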
@@ -64,8 +64,8 @@ flag_1 = False
 p = Parameter(x, name="weight")
 a = np.ones((2, 2))
 
-forward_net = Net()
-grad_net = GradNet(forward_net)
+forward_net = FirstInputTupleNet()
+grad_all_inputs_net = GradNet(forward_net, get_all=True)
 
 
 def test_outermost_net_inputs_including_non_tensor():
@@ -74,13 +74,31 @@ def test_outermost_net_inputs_including_non_tensor():
 
 
 def test_grad_net_inputs_including_non_tensor():
-    grad_net(arg_t0, z, arg_l0, w, sl, args_d0, flag_0)
-    grad_net(arg_t1, z, arg_l1, x, sl, args_d1, flag_1)
+    assert len(grad_all_inputs_net(arg_t0, z, arg_l0, w, sl, args_d0, flag_0)) == 2
+    assert len(grad_all_inputs_net(arg_t1, z, arg_l1, x, sl, args_d1, flag_1)) == 2
+
+
+def test_grad_first_input_net():
+    class FirstInputTensorNet(nn.Cell):
+        def __init__(self):
+            super(FirstInputTensorNet, self).__init__()
+
+        def construct(self, tensor_x, tuple_a, list_b, tensor_y, scalar, dict_c, flag):
+            if flag:
+                return tensor_x - tuple_a[2] + list_b[1][1]["x"] - tensor_y + scalar - dict_c["x"]
+            return tensor_x + tuple_a[2] - list_b[1][1]["y"] + tensor_y - scalar + dict_c["y"]
+
+    grad_fist_input_tensor_net = GradNet(FirstInputTensorNet(), get_all=False)
+    ret = grad_fist_input_tensor_net(z, arg_t0, arg_l0, w, sl, args_d0, flag_0)
+    assert np.allclose(ret.asnumpy(), np.ones((2, 2), np.float32))
+
+    grad_fist_input_tuple_net = GradNet(forward_net, get_all=False)
+    assert not grad_fist_input_tuple_net(arg_t0, z, arg_l0, w, sl, args_d0, flag_0)
 
 
 def test_net_inputs_including_str():
     with pytest.raises(TypeError) as err:
-        grad_net(arg_t0, s, arg_l0, w, sl, args_d0, flag_0)
+        grad_all_inputs_net(arg_t0, s, arg_l0, w, sl, args_d0, flag_0)
     assert "The inputs types of the outermost network support bool, int, float, tensor, " \
            "mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \
            "and tuple or list containing only these types, and dict whose values are these types, " \
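Note on the new test: of the seven arguments, only tensor_x and tensor_y are Tensors, so get_all=True yields a 2-tuple of grads, which is what the len(...) == 2 assertions check. With get_all=False and a non-tensor (tuple) first input, the kGradFirst branch above returns an empty ValueTuple, which arrives in Python as a falsy empty tuple; that is what "assert not grad_fist_input_tuple_net(...)" relies on. Condensed restatement, using the same fixtures:

    # First input is a Tensor -> a single grad tensor comes back.
    ret = GradNet(FirstInputTensorNet(), get_all=False)(z, arg_t0, arg_l0, w, sl, args_d0, flag_0)

    # First input is a tuple -> the kGradFirst fallback yields ().
    assert not GradNet(forward_net, get_all=False)(arg_t0, z, arg_l0, w, sl, args_d0, flag_0)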
@@ -117,7 +135,7 @@ def test_outermost_net_pass_list_including_parameter():
 
 
 def test_grad_net_pass_dict_including_parameter():
     with pytest.raises(TypeError) as err:
-        grad_net(arg_t0, z, arg_l0, {"x": z, "y": w, "z": p}, sl, args_d0, flag_0)
+        grad_all_inputs_net(arg_t0, z, arg_l0, {"x": z, "y": w, "z": p}, sl, args_d0, flag_0)
     assert "The inputs types of the outermost network support bool, int, float, tensor, " \
            "mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \
            "and tuple or list containing only these types, and dict whose values are these types, " \