fix CopyWithSlice grad bug

This commit is contained in:
shaojunsong 2024-01-12 14:45:18 +08:00
parent a804015ed8
commit 0b9c3f0a40
4 changed files with 50 additions and 7 deletions

View File

@ -265,7 +265,7 @@ void UpdateStubTensor(const FrontendOpRunInfoPtr &op_run_info) {
}
KernelTaskType GetViewOpTaskType(const std::string &op_name) {
if (op_name == kCopyWithScileOpName) {
if (op_name == kCopyWithSliceOpName) {
return KernelTaskType::kCOPY_TASK;
}
return KernelTaskType::kNORMAL_VIEW_TASK;

View File

@ -1276,7 +1276,7 @@ static inline py::object SetitemCopyView(std::vector<pynative::SliceOpInfoPtr> *
(void)slice_op_infos->emplace_back(broadcastto_op_info);
auto copy_op_info = std::make_shared<pynative::SliceOpInfo>();
copy_op_info->slice_op_name = kCopyWithScileOpName;
copy_op_info->slice_op_name = kCopyWithSliceOpName;
copy_op_info->data_indexs = {0, 1};
(void)slice_op_infos->emplace_back(copy_op_info);
ValuePtr rdata_value;
@ -1830,13 +1830,16 @@ TypeId GetStubAbsTypeId(const AbstractBasePtr &abs) {
}
}
bool EnableView(bool is_pack_node, const TypeId &type_id, const py::bool_ &is_ascend) {
bool EnableView(bool is_pack_node, const TypeId &type_id, const py::bool_ &is_ascend, bool is_setitem = false) {
if (is_pack_node || pynative::PyNativeExecutor::GetInstance()->grad_executor()->is_high_order_top_cell()) {
// 1. Pack nodes will fail to slice when view is enabled.
// 2. SelectView and CopyWithSlice have no kernel, so view cannot be enabled in a high-order cell.
return false;
}
// For setitem, the grad of CopyWithSlice is erroneous. If we are in setitem and grad is required, disable view.
if (is_setitem && pynative::PyNativeExecutor::GetInstance()->grad_executor()->RequiresGrad()) return false;
if (is_ascend && (type_id == kNumberTypeComplex128 || type_id == kNumberTypeFloat64)) {
// AsStrided and ViewCopy do not support Complex128 and Float64, so disable view
return false;
@ -2187,13 +2190,13 @@ py::object TensorIndex::SetItemIndexInfo(const py::object &py_data, const py::ob
MS_EXCEPTION_IF_NULL(data_type);
const auto &type_id = GetStubAbsTypeId(abs);
if (EnableView(value_info.second, type_id, is_ascend)) {
if (EnableView(value_info.second, type_id, is_ascend, true)) {
data_value = value_info.first;
}
} else {
TensorPtr data = py_data.cast<TensorPtr>();
MS_EXCEPTION_IF_NULL(data);
if (EnableView(false, data->data_type(), is_ascend)) {
if (EnableView(false, data->data_type(), is_ascend, true)) {
data_value = data;
}
data_shape = data->shape();

View File

@ -174,7 +174,7 @@ constexpr auto kUnsortedSegmentMinOpName = "UnsortedSegmentMin";
constexpr auto kUnsortedSegmentMinDOpName = "UnsortedSegmentMinD";
constexpr auto kUpdateCacheOpName = "UpdateCache";
constexpr auto kBroadcastOpName = "Broadcast";
constexpr auto kCopyWithScileOpName = "CopyWithSlice";
constexpr auto kCopyWithSliceOpName = "CopyWithSlice";
} // namespace mindspore
#endif // MINDSPORE_CORE_BASE_ARRAY_OP_NAME_H_

View File

@ -16,7 +16,7 @@
import numpy as np
import pytest
from mindspore import Tensor, context
from mindspore import Tensor, context, ops
from mindspore.nn import Cell
from mindspore import dtype as mstype
@ -57,6 +57,46 @@ class NumpySetItemByList():
return x
class SetitemNet(Cell):
    """Network exercising tensor setitem: scales the input, then does an
    in-place slice accumulation (``out[:, 1:] += y[:, 1:]``)."""

    def construct(self, x, y):
        # Scale first so the result is a fresh tensor, then mutate a slice of it.
        out = x * 2
        out[:, 1:] += y[:, 1:]
        return out
class SetitemGradNet(Cell):
    """Wraps ``net`` and returns the gradients of its output w.r.t. all inputs.

    Args:
        net (Cell): the forward network to differentiate.
    """

    def __init__(self, net):
        super().__init__()
        self.net = net
        # get_all=True: differentiate w.r.t. every positional input, not just the first.
        self.grad_op = ops.GradOperation(get_all=True)

    def construct(self, x, y):
        # Build the gradient function inside construct so it is traced with the graph.
        grads_fn = self.grad_op(self.net)
        return grads_fn(x, y)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_setitem_grad():
    """
    Feature: Test setitem grad
    Description: setitem should return correct grad
    Expectation: success
    """
    forward = SetitemNet()
    shape = (2, 2, 2, 2)
    x_in = Tensor(np.random.randn(*shape), mstype.float32)
    y_in = Tensor(np.random.randn(*shape), mstype.float32)
    # Zero out y so only the structure of the grads matters, not y's values.
    y_in = ops.zeros_like(y_in)
    grads = SetitemGradNet(forward)(x_in, y_in)
    # d(out)/d(x) is 2 everywhere because out = x * 2 and the slice add is linear in y.
    expected_x_grad = np.full(shape, 2, np.float32)
    # d(out)/d(y) is 1 exactly on the [:, 1:] slice that was accumulated, 0 elsewhere.
    expected_y_grad = np.zeros(shape, np.float32)
    expected_y_grad[:, 1:] = 1
    assert np.array_equal(grads[0].asnumpy(), expected_x_grad)
    assert np.array_equal(grads[1].asnumpy(), expected_y_grad)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training