forked from mindspore-Ecosystem/mindspore
!26034 F.grad support sens_param and fix graph_mode bug
Merge pull request !26034 from zhang_sss/grad
This commit is contained in:
commit 0f07408425
@@ -366,6 +366,9 @@ FuncGraphPtr KPynativeCellImpl::Finish(const AnfNodePtrList &weights, const std:
 (void)BackPropagate(!build_formal_param);
 }
 // Return the gradient;
+if (grad_position.size() == 0) {
+  MS_LOG(EXCEPTION) << "grad_position in F.grad is empty!";
+}
 SetOutput(weights, grad_position, grad_inputs, grad_weights);
 // Replace Parameter of primal funcgraph with parameter of tape_;
 ReplacePrimalParameter(weights, has_sens_arg);
@@ -1020,7 +1023,7 @@ void KPynativeCellImpl::SetOutput(const AnfNodePtrList &weights, const std::vect
 }
 for (size_t i = 0; i < grad_list.size(); ++i) {
   if (grad_list[i] >= cell_inputs_.size()) {
-    MS_LOG(EXCEPTION) << "Position index is exceed input size!";
+    MS_LOG(EXCEPTION) << "Position index " << grad_list[i] << " is exceed input size!";
   }
   auto input = cell_inputs_[grad_list[i]];
   MS_EXCEPTION_IF_NULL(input);
@@ -1078,9 +1081,9 @@ void KPynativeCellImpl::SetOutput(const AnfNodePtrList &weights, const std::vect
 tape_output = tape_->NewCNode(grad_inputs_list);
 tape_output->set_abstract(grad_inputs_spec);
 } else {
-  size_t index = 0;
-  if (pos_size == 1) {
-    index = grad_position[0];
+  size_t index = grad_position[0];
+  if (index >= cell_inputs_.size()) {
+    MS_LOG(EXCEPTION) << "Position index " << index << " is exceed input size!";
   }
   auto input_adjoint_iter = anfnode_to_adjoin_.find(cell_inputs_[index]);
   if (input_adjoint_iter == anfnode_to_adjoin_.end()) {
@@ -446,20 +446,6 @@ class _Grad(GradOperation_):
         fn.set_grad(False)
 
     def __call__(self, fn, weights=None, grad_position=0):
-        if isinstance(grad_position, tuple):
-            for gp in grad_position:
-                if not isinstance(gp, int):
-                    raise TypeError(f"For '_Grad', the element in 'grad_position' should be int, "
-                                    f"but got {type(gp).__name__}")
-                if gp < 0:
-                    raise ValueError("The element in grad_position must be >= 0.")
-        elif isinstance(grad_position, int):
-            if grad_position < 0:
-                raise ValueError("grad_position must be >= 0.")
-            grad_position = (grad_position,)
-        else:
-            raise TypeError(f"For '_Grad', the 'grad_position' should be int or tuple, "
-                            f"but got {type(grad_position).__name__}")
         if self.grad_fn is not None and self.fn == fn:
            return self.grad_fn
         grad_ = _Grad(self.get_by_list, self.sens_param, self.get_by_position)
@@ -159,9 +159,30 @@ partial = P.Partial()
 depend = P.Depend()
 identity = P.identity()
 
-grad_by_position = _Grad(get_by_list=False, sens_param=False, get_by_position=True)
+@constexpr
+def _convert_grad_position_type(grad_position):
+    """Check and convert the type and size of grad position index."""
+    if isinstance(grad_position, tuple):
+        for gp in grad_position:
+            if not isinstance(gp, int):
+                raise TypeError(f"For 'F.grad', the element in 'grad_position' should be int, "
+                                f"but got {type(gp).__name__}")
+            if gp < 0:
+                raise ValueError("The element in grad_position must be >= 0.")
+    elif isinstance(grad_position, int):
+        if grad_position < 0:
+            raise ValueError("grad_position must be >= 0.")
+        grad_position = (grad_position,)
+    else:
+        raise TypeError(f"For 'F.grad', the 'grad_position' should be int or tuple, "
+                        f"but got {type(grad_position).__name__}")
+    return grad_position
 
-def grad(fn, grad_position=0):
+grad_by_position = _Grad(get_by_list=False, sens_param=False, get_by_position=True)
+grad_by_position_with_sens = _Grad(get_by_list=False, sens_param=True, get_by_position=True)
 
+def grad(fn, grad_position=0, sens_param=False):
     r"""
     A wrapper function to generate the gradient function for the input function.
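The @constexpr helper added above normalizes grad_position before it reaches the backend: a bare int becomes a one-element tuple, tuple elements are checked one by one, and any other type is rejected. A sketch of the intended mapping, read directly from the branches of the helper (it is an internal, compile-time check, so calling it by hand as below is only an illustration):

    _convert_grad_position_type(1)          # -> (1,)
    _convert_grad_position_type((0, 2))     # -> (0, 2)
    _convert_grad_position_type(-1)         # ValueError: grad_position must be >= 0.
    _convert_grad_position_type((0, 1.5))   # TypeError: the element in 'grad_position' should be int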
@@ -169,10 +190,15 @@ def grad(fn, grad_position=0):
         fn (Union(Cell, function)): Function to do GradOperation.
         grad_position (Union(int, tuple[int])): If int, get the gradient with respect to single input.
             If tuple, get the gradients with respect to selected inputs. 'grad_position' begins with 0. Default: 0.
+        sens_param (bool): Whether to append sensitivity (gradient with respect to output) as input.
+            If sens_param is False, a 'ones_like(outputs)' sensitivity will be attached automatically. Default: False.
 
     Returns:
         Function, returns the gradient function for the input function or cell.
     """
+    grad_position = _convert_grad_position_type(grad_position)
+    if sens_param:
+        return grad_by_position_with_sens(fn, None, grad_position)
     return grad_by_position(fn, None, grad_position)
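With sens_param=True, the returned gradient function takes the sensitivity (the gradient with respect to the outputs) as one extra trailing argument, packed as a tuple when the function has multiple outputs; with the default sens_param=False, a ones_like(outputs) sensitivity is attached automatically. A minimal usage sketch, mirroring the tests added later in this commit (function, x, y, z and v as defined there):

    # scale both outputs of `function` by v before back-propagating
    grads_with_sens = grad(function, grad_position=(1, 2), sens_param=True)(x, y, z, (v, v))
    # default behaviour: same call without the extra sensitivity argument
    grads_plain = grad(function, grad_position=(1, 2))(x, y, z)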
@@ -40,12 +40,7 @@ class MultipleInputsMultipleOutputsNet(nn.Cell):
     def construct(self, x, y, z):
         return x**2 + y**2 + z**2, x*y*z
 
-
-def function1(x, y, z):
-    return x**2 + y**2 + z**2, x*y*z
-
-@ms_function
-def function2(x, y, z):
+def function(x, y, z):
     return x**2 + y**2 + z**2, x*y*z
 
 def iteration_grad_function(x, y, z):
@@ -53,7 +48,7 @@ def iteration_grad_function(x, y, z):
 
 @ms_function
 def grad_warp_with_msfunction(x, y, z):
-    output = grad(function1, grad_position=(1, 2))(x, y, z)
+    output = grad(function)(x, y, z)
     return output
 
 
@@ -135,38 +130,19 @@ def test_grad_multiple_inputs_multiple_outputs_cell_graph():
 @pytest.mark.level0
 @pytest.mark.platform_x86_cpu
 @pytest.mark.env_onecard
-def test_grad_function_graph():
+def test_grad_function_with_sens_graph():
     """
     Features: Function grad.
-    Description: Test F.grad with function in graph mode.
+    Description: Test F.grad with function setting sens_param in graph mode.
     Expectation: No exception.
     """
     x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
     y = Tensor(np.array([[-2, 3], [-1, 2]]).astype(np.float32))
     z = Tensor(np.array([[0, 3], [5, -1]]).astype(np.float32))
-    expect_grad1 = Tensor(np.array([[-4, 12], [13, 0]]).astype(np.float32))
-    expect_grad2 = Tensor(np.array([[-2, 12], [7, 6]]).astype(np.float32))
-    real_grad = grad(function1, grad_position=(1, 2))(x, y, z)
-    assert isinstance(real_grad, tuple)
-    assert len(real_grad) == 2
-    assert np.allclose(real_grad[0].asnumpy(), expect_grad1.asnumpy())
-    assert np.allclose(real_grad[1].asnumpy(), expect_grad2.asnumpy())
-
-@pytest.mark.level0
-@pytest.mark.platform_x86_cpu
-@pytest.mark.env_onecard
-def test_grad_msfuntion_graph():
-    """
-    Features: Function grad.
-    Description: Test F.grad with ms_function in graph mode.
-    Expectation: No exception.
-    """
-    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
-    y = Tensor(np.array([[-2, 3], [-1, 2]]).astype(np.float32))
-    z = Tensor(np.array([[0, 3], [5, -1]]).astype(np.float32))
-    expect_grad1 = Tensor(np.array([[-4, 12], [13, 0]]).astype(np.float32))
-    expect_grad2 = Tensor(np.array([[-2, 12], [7, 6]]).astype(np.float32))
-    real_grad = grad(function2, grad_position=(1, 2))(x, y, z)
+    v = Tensor(np.array([[-1, 3], [2, 1]]).astype(np.float32))
+    expect_grad1 = Tensor(np.array([[4, 36], [26, 0]]).astype(np.float32))
+    expect_grad2 = Tensor(np.array([[2, 36], [14, 6]]).astype(np.float32))
+    real_grad = grad(function, grad_position=(1, 2), sens_param=True)(x, y, z, (v, v))
     assert isinstance(real_grad, tuple)
     assert len(real_grad) == 2
     assert np.allclose(real_grad[0].asnumpy(), expect_grad1.asnumpy())
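For reference, the expected tensors in this test follow from the two outputs (x**2 + y**2 + z**2, x*y*z) and the supplied sensitivity (v, v): the gradient with respect to y is 2*y*v + x*z*v and the gradient with respect to z is 2*z*v + x*y*v (all products element-wise), which evaluate to [[4, 36], [26, 0]] and [[2, 36], [14, 6]] for the tensors above.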
@@ -204,10 +180,6 @@ def test_grad_warp_with_msfunction_graph():
     x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
     y = Tensor(np.array([[-2, 3], [-1, 2]]).astype(np.float32))
     z = Tensor(np.array([[0, 3], [5, -1]]).astype(np.float32))
-    expect_grad1 = Tensor(np.array([[-4, 12], [13, 0]]).astype(np.float32))
-    expect_grad2 = Tensor(np.array([[-2, 12], [7, 6]]).astype(np.float32))
+    expect_grad = Tensor(np.array([[2, 13], [1, 6]]).astype(np.float32))
     real_grad = grad_warp_with_msfunction(x, y, z)
-    assert isinstance(real_grad, tuple)
-    assert len(real_grad) == 2
-    assert np.allclose(real_grad[0].asnumpy(), expect_grad1.asnumpy())
-    assert np.allclose(real_grad[1].asnumpy(), expect_grad2.asnumpy())
+    assert np.allclose(real_grad.asnumpy(), expect_grad.asnumpy())
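Likewise, the single expected tensor here is the gradient with respect to x under the defaults (grad_position=0, sensitivity of ones for both outputs): 2*x + y*z element-wise, which gives [[2, 13], [1, 6]] for the inputs above.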
@@ -41,11 +41,7 @@ class MultipleInputsMultipleOutputsNet(nn.Cell):
         return x**2 + y**2 + z**2, x*y*z
 
 
-def function1(x, y, z):
-    return x**2 + y**2 + z**2, x*y*z
-
-@ms_function
-def function2(x, y, z):
+def function(x, y, z):
     return x**2 + y**2 + z**2, x*y*z
 
 def iteration_grad_function(x, y, z):
@@ -53,7 +49,7 @@ def iteration_grad_function(x, y, z):
 
 @ms_function
 def grad_warp_with_msfunction(x, y, z):
-    output = grad(function1, grad_position=(1, 2))(x, y, z)
+    output = grad(function)(x, y, z)
     return output
 
 
@@ -135,38 +131,19 @@ def test_grad_multiple_inputs_multiple_outputs_cell_pynative():
 @pytest.mark.level0
 @pytest.mark.platform_x86_cpu
 @pytest.mark.env_onecard
-def test_grad_function_pynative():
+def test_grad_function_with_sens_pynative():
     """
     Features: Function grad.
-    Description: Test F.grad with function in pynative mode.
+    Description: Test F.grad with function setting sens_param in pynative mode.
     Expectation: No exception.
     """
     x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
     y = Tensor(np.array([[-2, 3], [-1, 2]]).astype(np.float32))
     z = Tensor(np.array([[0, 3], [5, -1]]).astype(np.float32))
-    expect_grad1 = Tensor(np.array([[-4, 12], [13, 0]]).astype(np.float32))
-    expect_grad2 = Tensor(np.array([[-2, 12], [7, 6]]).astype(np.float32))
-    real_grad = grad(function1, grad_position=(1, 2))(x, y, z)
-    assert isinstance(real_grad, tuple)
-    assert len(real_grad) == 2
-    assert np.allclose(real_grad[0].asnumpy(), expect_grad1.asnumpy())
-    assert np.allclose(real_grad[1].asnumpy(), expect_grad2.asnumpy())
-
-@pytest.mark.level0
-@pytest.mark.platform_x86_cpu
-@pytest.mark.env_onecard
-def test_grad_msfuntion_pynative():
-    """
-    Features: Function grad.
-    Description: Test F.grad with ms_function in pynative mode.
-    Expectation: No exception.
-    """
-    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
-    y = Tensor(np.array([[-2, 3], [-1, 2]]).astype(np.float32))
-    z = Tensor(np.array([[0, 3], [5, -1]]).astype(np.float32))
-    expect_grad1 = Tensor(np.array([[-4, 12], [13, 0]]).astype(np.float32))
-    expect_grad2 = Tensor(np.array([[-2, 12], [7, 6]]).astype(np.float32))
-    real_grad = grad(function2, grad_position=(1, 2))(x, y, z)
+    v = Tensor(np.array([[-1, 3], [2, 1]]).astype(np.float32))
+    expect_grad1 = Tensor(np.array([[4, 36], [26, 0]]).astype(np.float32))
+    expect_grad2 = Tensor(np.array([[2, 36], [14, 6]]).astype(np.float32))
+    real_grad = grad(function, grad_position=(1, 2), sens_param=True)(x, y, z, (v, v))
     assert isinstance(real_grad, tuple)
     assert len(real_grad) == 2
     assert np.allclose(real_grad[0].asnumpy(), expect_grad1.asnumpy())
@@ -204,10 +181,6 @@ def test_grad_warp_with_msfunction_pynative():
     x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
     y = Tensor(np.array([[-2, 3], [-1, 2]]).astype(np.float32))
    z = Tensor(np.array([[0, 3], [5, -1]]).astype(np.float32))
-    expect_grad1 = Tensor(np.array([[-4, 12], [13, 0]]).astype(np.float32))
-    expect_grad2 = Tensor(np.array([[-2, 12], [7, 6]]).astype(np.float32))
+    expect_grad = Tensor(np.array([[2, 13], [1, 6]]).astype(np.float32))
     real_grad = grad_warp_with_msfunction(x, y, z)
-    assert isinstance(real_grad, tuple)
-    assert len(real_grad) == 2
-    assert np.allclose(real_grad[0].asnumpy(), expect_grad1.asnumpy())
-    assert np.allclose(real_grad[1].asnumpy(), expect_grad2.asnumpy())
+    assert np.allclose(real_grad.asnumpy(), expect_grad.asnumpy())
@@ -98,7 +98,7 @@ class TestKPynative : public UT::Common {
       GradPynativeOp(k_pynative_cell, c_node, args, out);
     }
   }
-  auto bprop_fg = GradPynativeCellEnd(k_pynative_cell, AnfNodePtrList{}, std::vector<size_t>{}, true, false, false,
+  auto bprop_fg = GradPynativeCellEnd(k_pynative_cell, AnfNodePtrList{}, std::vector<size_t>{0}, true, false, false,
                                       true);
   return bprop_fg;
 }