!26034 F.grad support sens_param and fix graph_mode bug

Merge pull request !26034 from zhang_sss/grad
i-robot 2021-11-12 06:38:15 +00:00 committed by Gitee
commit 0f07408425
6 changed files with 56 additions and 96 deletions

View File

@@ -366,6 +366,9 @@ FuncGraphPtr KPynativeCellImpl::Finish(const AnfNodePtrList &weights, const std:
(void)BackPropagate(!build_formal_param);
}
// Return the gradient;
if (grad_position.size() == 0) {
MS_LOG(EXCEPTION) << "grad_position in F.grad is empty!";
}
SetOutput(weights, grad_position, grad_inputs, grad_weights);
// Replace Parameter of primal funcgraph with parameter of tape_;
ReplacePrimalParameter(weights, has_sens_arg);
@@ -1020,7 +1023,7 @@ void KPynativeCellImpl::SetOutput(const AnfNodePtrList &weights, const std::vect
}
for (size_t i = 0; i < grad_list.size(); ++i) {
if (grad_list[i] >= cell_inputs_.size()) {
MS_LOG(EXCEPTION) << "Position index is exceed input size!";
MS_LOG(EXCEPTION) << "Position index " << grad_list[i] << " is exceed input size!";
}
auto input = cell_inputs_[grad_list[i]];
MS_EXCEPTION_IF_NULL(input);
@@ -1078,9 +1081,9 @@ void KPynativeCellImpl::SetOutput(const AnfNodePtrList &weights, const std::vect
tape_output = tape_->NewCNode(grad_inputs_list);
tape_output->set_abstract(grad_inputs_spec);
} else {
size_t index = 0;
if (pos_size == 1) {
index = grad_position[0];
size_t index = grad_position[0];
if (index >= cell_inputs_.size()) {
MS_LOG(EXCEPTION) << "Position index " << index << " is exceed input size!";
}
auto input_adjoint_iter = anfnode_to_adjoin_.find(cell_inputs_[index]);
if (input_adjoint_iter == anfnode_to_adjoin_.end()) {
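Seen from user code, the effect of the new empty-position guard is roughly the following. This is a hypothetical illustration, not part of the PR: it assumes grad is importable from mindspore.ops.functional (where this commit defines it) and that the MS_LOG(EXCEPTION) above surfaces as a Python exception carrying the same message.

# Hypothetical sketch: an empty grad_position tuple passes the Python-side
# type check added later in this commit, but is now rejected in Finish() with
# "grad_position in F.grad is empty!".
import numpy as np
import mindspore as ms
from mindspore.ops.functional import grad  # assumed import path

def square(x):
    return x ** 2

x = ms.Tensor(np.array([1.0, 2.0]).astype(np.float32))
try:
    grad(square, grad_position=())(x)  # no input positions requested
except Exception as e:  # exact exception type depends on how the backend maps MS_LOG(EXCEPTION)
    print(e)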

View File

@@ -446,20 +446,6 @@ class _Grad(GradOperation_):
fn.set_grad(False)
def __call__(self, fn, weights=None, grad_position=0):
if isinstance(grad_position, tuple):
for gp in grad_position:
if not isinstance(gp, int):
raise TypeError(f"For '_Grad', the element in 'grad_position' should be int, "
f"but got {type(gp).__name__}")
if gp < 0:
raise ValueError("The element in grad_position must be >= 0.")
elif isinstance(grad_position, int):
if grad_position < 0:
raise ValueError("grad_position must be >= 0.")
grad_position = (grad_position,)
else:
raise TypeError(f"For '_Grad', the 'grad_position' should be int or tuple, "
f"but got {type(grad_position).__name__}")
if self.grad_fn is not None and self.fn == fn:
return self.grad_fn
grad_ = _Grad(self.get_by_list, self.sens_param, self.get_by_position)

View File

@@ -159,9 +159,30 @@ partial = P.Partial()
depend = P.Depend()
identity = P.identity()
grad_by_position = _Grad(get_by_list=False, sens_param=False, get_by_position=True)
@constexpr
def _convert_grad_position_type(grad_position):
"""Check and convert the type and size of grad position index."""
if isinstance(grad_position, tuple):
for gp in grad_position:
if not isinstance(gp, int):
raise TypeError(f"For 'F.grad', the element in 'grad_position' should be int, "
f"but got {type(gp).__name__}")
if gp < 0:
raise ValueError("The element in grad_position must be >= 0.")
elif isinstance(grad_position, int):
if grad_position < 0:
raise ValueError("grad_position must be >= 0.")
grad_position = (grad_position,)
else:
raise TypeError(f"For 'F.grad', the 'grad_position' should be int or tuple, "
f"but got {type(grad_position).__name__}")
return grad_position
def grad(fn, grad_position=0):
grad_by_position = _Grad(get_by_list=False, sens_param=False, get_by_position=True)
grad_by_position_with_sens = _Grad(get_by_list=False, sens_param=True, get_by_position=True)
def grad(fn, grad_position=0, sens_param=False):
r"""
A wrapper function to generate the gradient function for the input function.
@@ -169,10 +190,15 @@ def grad(fn, grad_position=0):
fn (Union(Cell, function)): Function to do GradOperation.
grad_position (Union(int, tuple[int])): If int, get the gradient with respect to single input.
If tuple, get the gradients with respect to selected inputs. 'grad_position' begins with 0. Default: 0.
sens_param (bool): Whether to append sensitivity (gradient with respect to output) as input.
If sens_param is False, a 'ones_like(outputs)' sensitivity will be attached automatically. Default: False.
Returns:
Function, returns the gradient function for the input function or cell.
"""
grad_position = _convert_grad_position_type(grad_position)
if sens_param:
return grad_by_position_with_sens(fn, None, grad_position)
return grad_by_position(fn, None, grad_position)
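A minimal usage sketch of the extended API, mirroring the tests later in this commit (the import path below is an assumption; grad is defined in the functional module shown above). With the default sens_param=False an implicit ones_like(outputs) sensitivity is used; with sens_param=True the caller appends the per-output sensitivities as an extra argument:

# Sketch based on the tests in this PR.
import numpy as np
from mindspore import Tensor
from mindspore.ops.functional import grad  # assumed import path

def fn(x, y, z):
    return x**2 + y**2 + z**2, x * y * z

x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
y = Tensor(np.array([[-2, 3], [-1, 2]]).astype(np.float32))
z = Tensor(np.array([[0, 3], [5, -1]]).astype(np.float32))
v = Tensor(np.array([[-1, 3], [2, 1]]).astype(np.float32))

# Default: implicit ones_like sensitivity, gradients w.r.t. inputs 1 and 2 (y and z).
dy, dz = grad(fn, grad_position=(1, 2))(x, y, z)

# sens_param=True: the sensitivity for each output is passed explicitly.
dy_s, dz_s = grad(fn, grad_position=(1, 2), sens_param=True)(x, y, z, (v, v))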

View File

@@ -40,12 +40,7 @@ class MultipleInputsMultipleOutputsNet(nn.Cell):
def construct(self, x, y, z):
return x**2 + y**2 + z**2, x*y*z
def function1(x, y, z):
return x**2 + y**2 + z**2, x*y*z
@ms_function
def function2(x, y, z):
def function(x, y, z):
return x**2 + y**2 + z**2, x*y*z
def iteration_grad_function(x, y, z):
@@ -53,7 +48,7 @@ def iteration_grad_function(x, y, z):
@ms_function
def grad_warp_with_msfunction(x, y, z):
output = grad(function1, grad_position=(1, 2))(x, y, z)
output = grad(function)(x, y, z)
return output
@@ -135,38 +130,19 @@ def test_grad_multiple_inputs_multiple_outputs_cell_graph():
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_function_graph():
def test_grad_function_with_sens_graph():
"""
Features: Function grad.
Description: Test F.grad with function in graph mode.
Description: Test F.grad on a function with sens_param set in graph mode.
Expectation: No exception.
"""
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
y = Tensor(np.array([[-2, 3], [-1, 2]]).astype(np.float32))
z = Tensor(np.array([[0, 3], [5, -1]]).astype(np.float32))
expect_grad1 = Tensor(np.array([[-4, 12], [13, 0]]).astype(np.float32))
expect_grad2 = Tensor(np.array([[-2, 12], [7, 6]]).astype(np.float32))
real_grad = grad(function1, grad_position=(1, 2))(x, y, z)
assert isinstance(real_grad, tuple)
assert len(real_grad) == 2
assert np.allclose(real_grad[0].asnumpy(), expect_grad1.asnumpy())
assert np.allclose(real_grad[1].asnumpy(), expect_grad2.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_msfuntion_graph():
"""
Features: Function grad.
Description: Test F.grad with ms_function in graph mode.
Expectation: No exception.
"""
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
y = Tensor(np.array([[-2, 3], [-1, 2]]).astype(np.float32))
z = Tensor(np.array([[0, 3], [5, -1]]).astype(np.float32))
expect_grad1 = Tensor(np.array([[-4, 12], [13, 0]]).astype(np.float32))
expect_grad2 = Tensor(np.array([[-2, 12], [7, 6]]).astype(np.float32))
real_grad = grad(function2, grad_position=(1, 2))(x, y, z)
v = Tensor(np.array([[-1, 3], [2, 1]]).astype(np.float32))
expect_grad1 = Tensor(np.array([[4, 36], [26, 0]]).astype(np.float32))
expect_grad2 = Tensor(np.array([[2, 36], [14, 6]]).astype(np.float32))
real_grad = grad(function, grad_position=(1, 2), sens_param=True)(x, y, z, (v, v))
assert isinstance(real_grad, tuple)
assert len(real_grad) == 2
assert np.allclose(real_grad[0].asnumpy(), expect_grad1.asnumpy())
@@ -204,10 +180,6 @@ def test_grad_warp_with_msfunction_graph():
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
y = Tensor(np.array([[-2, 3], [-1, 2]]).astype(np.float32))
z = Tensor(np.array([[0, 3], [5, -1]]).astype(np.float32))
expect_grad1 = Tensor(np.array([[-4, 12], [13, 0]]).astype(np.float32))
expect_grad2 = Tensor(np.array([[-2, 12], [7, 6]]).astype(np.float32))
expect_grad = Tensor(np.array([[2, 13], [1, 6]]).astype(np.float32))
real_grad = grad_warp_with_msfunction(x, y, z)
assert isinstance(real_grad, tuple)
assert len(real_grad) == 2
assert np.allclose(real_grad[0].asnumpy(), expect_grad1.asnumpy())
assert np.allclose(real_grad[1].asnumpy(), expect_grad2.asnumpy())
assert np.allclose(real_grad.asnumpy(), expect_grad.asnumpy())
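The expected constants in test_grad_function_with_sens_graph follow from the chain rule for fn(x, y, z) = (x**2 + y**2 + z**2, x*y*z): with sensitivities (v, v), the gradient w.r.t. y is v * (2*y + x*z) and w.r.t. z is v * (2*z + x*y), element-wise. A quick NumPy cross-check of the values used above:

# NumPy cross-check of the expected gradients in the graph-mode tests above.
import numpy as np

x = np.array([[1, 2], [3, 4]], dtype=np.float32)
y = np.array([[-2, 3], [-1, 2]], dtype=np.float32)
z = np.array([[0, 3], [5, -1]], dtype=np.float32)
v = np.array([[-1, 3], [2, 1]], dtype=np.float32)

grad_y = v * (2 * y + x * z)  # sens * (d(out1)/dy + d(out2)/dy)
grad_z = v * (2 * z + x * y)
assert np.allclose(grad_y, [[4, 36], [26, 0]])
assert np.allclose(grad_z, [[2, 36], [14, 6]])

# Default grad_position=0 with implicit ones_like sensitivity (grad_warp_with_msfunction):
grad_x = 2 * x + y * z
assert np.allclose(grad_x, [[2, 13], [1, 6]])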

View File

@@ -41,11 +41,7 @@ class MultipleInputsMultipleOutputsNet(nn.Cell):
return x**2 + y**2 + z**2, x*y*z
def function1(x, y, z):
return x**2 + y**2 + z**2, x*y*z
@ms_function
def function2(x, y, z):
def function(x, y, z):
return x**2 + y**2 + z**2, x*y*z
def iteration_grad_function(x, y, z):
@@ -53,7 +49,7 @@ def iteration_grad_function(x, y, z):
@ms_function
def grad_warp_with_msfunction(x, y, z):
output = grad(function1, grad_position=(1, 2))(x, y, z)
output = grad(function)(x, y, z)
return output
@@ -135,38 +131,19 @@ def test_grad_multiple_inputs_multiple_outputs_cell_pynative():
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_function_pynative():
def test_grad_function_with_sens_pynative():
"""
Features: Function grad.
Description: Test F.grad with function in pynative mode.
Description: Test F.grad on a function with sens_param set in pynative mode.
Expectation: No exception.
"""
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
y = Tensor(np.array([[-2, 3], [-1, 2]]).astype(np.float32))
z = Tensor(np.array([[0, 3], [5, -1]]).astype(np.float32))
expect_grad1 = Tensor(np.array([[-4, 12], [13, 0]]).astype(np.float32))
expect_grad2 = Tensor(np.array([[-2, 12], [7, 6]]).astype(np.float32))
real_grad = grad(function1, grad_position=(1, 2))(x, y, z)
assert isinstance(real_grad, tuple)
assert len(real_grad) == 2
assert np.allclose(real_grad[0].asnumpy(), expect_grad1.asnumpy())
assert np.allclose(real_grad[1].asnumpy(), expect_grad2.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_msfuntion_pynative():
"""
Features: Function grad.
Description: Test F.grad with ms_function in pynative mode.
Expectation: No exception.
"""
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
y = Tensor(np.array([[-2, 3], [-1, 2]]).astype(np.float32))
z = Tensor(np.array([[0, 3], [5, -1]]).astype(np.float32))
expect_grad1 = Tensor(np.array([[-4, 12], [13, 0]]).astype(np.float32))
expect_grad2 = Tensor(np.array([[-2, 12], [7, 6]]).astype(np.float32))
real_grad = grad(function2, grad_position=(1, 2))(x, y, z)
v = Tensor(np.array([[-1, 3], [2, 1]]).astype(np.float32))
expect_grad1 = Tensor(np.array([[4, 36], [26, 0]]).astype(np.float32))
expect_grad2 = Tensor(np.array([[2, 36], [14, 6]]).astype(np.float32))
real_grad = grad(function, grad_position=(1, 2), sens_param=True)(x, y, z, (v, v))
assert isinstance(real_grad, tuple)
assert len(real_grad) == 2
assert np.allclose(real_grad[0].asnumpy(), expect_grad1.asnumpy())
@@ -204,10 +181,6 @@ def test_grad_warp_with_msfunction_pynative():
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
y = Tensor(np.array([[-2, 3], [-1, 2]]).astype(np.float32))
z = Tensor(np.array([[0, 3], [5, -1]]).astype(np.float32))
expect_grad1 = Tensor(np.array([[-4, 12], [13, 0]]).astype(np.float32))
expect_grad2 = Tensor(np.array([[-2, 12], [7, 6]]).astype(np.float32))
expect_grad = Tensor(np.array([[2, 13], [1, 6]]).astype(np.float32))
real_grad = grad_warp_with_msfunction(x, y, z)
assert isinstance(real_grad, tuple)
assert len(real_grad) == 2
assert np.allclose(real_grad[0].asnumpy(), expect_grad1.asnumpy())
assert np.allclose(real_grad[1].asnumpy(), expect_grad2.asnumpy())
assert np.allclose(real_grad.asnumpy(), expect_grad.asnumpy())

View File

@@ -98,7 +98,7 @@ class TestKPynative : public UT::Common {
GradPynativeOp(k_pynative_cell, c_node, args, out);
}
}
auto bprop_fg = GradPynativeCellEnd(k_pynative_cell, AnfNodePtrList{}, std::vector<size_t>{}, true, false, false,
auto bprop_fg = GradPynativeCellEnd(k_pynative_cell, AnfNodePtrList{}, std::vector<size_t>{0}, true, false, false,
true);
return bprop_fg;
}