!64340 【bugfix】Fix setitem with copyslice
Merge pull request !64340 from shaojunsong/fix/copyslice
This commit is contained in:
commit
2eb43ae04c
|
@ -265,7 +265,7 @@ void UpdateStubTensor(const FrontendOpRunInfoPtr &op_run_info) {
|
|||
}
|
||||
|
||||
KernelTaskType GetViewOpTaskType(const std::string &op_name) {
|
||||
if (op_name == kCopyWithScileOpName) {
|
||||
if (op_name == kCopyWithSliceOpName) {
|
||||
return KernelTaskType::kCOPY_TASK;
|
||||
}
|
||||
return KernelTaskType::kNORMAL_VIEW_TASK;
|
||||
|
|
|
@ -1276,7 +1276,7 @@ static inline py::object SetitemCopyView(std::vector<pynative::SliceOpInfoPtr> *
|
|||
(void)slice_op_infos->emplace_back(broadcastto_op_info);
|
||||
|
||||
auto copy_op_info = std::make_shared<pynative::SliceOpInfo>();
|
||||
copy_op_info->slice_op_name = kCopyWithScileOpName;
|
||||
copy_op_info->slice_op_name = kCopyWithSliceOpName;
|
||||
copy_op_info->data_indexs = {0, 1};
|
||||
(void)slice_op_infos->emplace_back(copy_op_info);
|
||||
ValuePtr rdata_value;
|
||||
|
@ -1830,13 +1830,16 @@ TypeId GetStubAbsTypeId(const AbstractBasePtr &abs) {
|
|||
}
|
||||
}
|
||||
|
||||
bool EnableView(bool is_pack_node, const TypeId &type_id, const py::bool_ &is_ascend) {
|
||||
bool EnableView(bool is_pack_node, const TypeId &type_id, const py::bool_ &is_ascend, bool is_setitem = false) {
|
||||
if (is_pack_node || pynative::PyNativeExecutor::GetInstance()->grad_executor()->is_high_order_top_cell()) {
|
||||
// 1. pack node will slice failed with view.
|
||||
// 2. SelectView and CopyWithSlice has no kernel, can not enable view in high order cell.
|
||||
return false;
|
||||
}
|
||||
|
||||
// For setitem, the grad of CopyWithSlice is erroneous. If we are in setitem and requires grad, disable view.
|
||||
if (is_setitem && pynative::PyNativeExecutor::GetInstance()->grad_executor()->RequiresGrad()) return false;
|
||||
|
||||
if (is_ascend && (type_id == kNumberTypeComplex128 || type_id == kNumberTypeFloat64)) {
|
||||
// AsStrided and ViewCopy is not support Complex128 and Float64, disable view
|
||||
return false;
|
||||
|
@ -2187,13 +2190,13 @@ py::object TensorIndex::SetItemIndexInfo(const py::object &py_data, const py::ob
|
|||
MS_EXCEPTION_IF_NULL(data_type);
|
||||
|
||||
const auto &type_id = GetStubAbsTypeId(abs);
|
||||
if (EnableView(value_info.second, type_id, is_ascend)) {
|
||||
if (EnableView(value_info.second, type_id, is_ascend, true)) {
|
||||
data_value = value_info.first;
|
||||
}
|
||||
} else {
|
||||
TensorPtr data = py_data.cast<TensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
if (EnableView(false, data->data_type(), is_ascend)) {
|
||||
if (EnableView(false, data->data_type(), is_ascend, true)) {
|
||||
data_value = data;
|
||||
}
|
||||
data_shape = data->shape();
|
||||
|
|
|
@ -174,7 +174,7 @@ constexpr auto kUnsortedSegmentMinOpName = "UnsortedSegmentMin";
|
|||
constexpr auto kUnsortedSegmentMinDOpName = "UnsortedSegmentMinD";
|
||||
constexpr auto kUpdateCacheOpName = "UpdateCache";
|
||||
constexpr auto kBroadcastOpName = "Broadcast";
|
||||
constexpr auto kCopyWithScileOpName = "CopyWithSlice";
|
||||
constexpr auto kCopyWithSliceOpName = "CopyWithSlice";
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_BASE_ARRAY_OP_NAME_H_
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from mindspore import Tensor, context
|
||||
from mindspore import Tensor, context, ops
|
||||
from mindspore.nn import Cell
|
||||
from mindspore import dtype as mstype
|
||||
|
||||
|
@ -57,6 +57,46 @@ class NumpySetItemByList():
|
|||
return x
|
||||
|
||||
|
||||
class SetitemNet(Cell):
|
||||
def construct(self, x, y):
|
||||
z = x * 2
|
||||
z[:, 1:] += y[:, 1:]
|
||||
return z
|
||||
|
||||
|
||||
class SetitemGradNet(Cell):
|
||||
def __init__(self, net):
|
||||
super().__init__()
|
||||
self.net = net
|
||||
self.grad_op = ops.GradOperation(get_all=True)
|
||||
|
||||
def construct(self, x, y):
|
||||
gradient_func = self.grad_op(self.net)
|
||||
return gradient_func(x, y)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_setitem_grad():
|
||||
"""
|
||||
Feature: Test setitem grad
|
||||
Description: setitem should return correct grad
|
||||
Expectation: success
|
||||
"""
|
||||
net = SetitemNet()
|
||||
a = Tensor(np.random.randn(2, 2, 2, 2), mstype.float32)
|
||||
b = Tensor(np.random.randn(2, 2, 2, 2), mstype.float32)
|
||||
b = ops.zeros_like(b)
|
||||
output = SetitemGradNet(net)(a, b)
|
||||
x_grad = np.ones((2, 2, 2, 2), np.float32) * 2
|
||||
y_grad = np.array([[[[0, 0], [0, 0]], [[1, 1], [1, 1]]], [[[0, 0], [0, 0]], [[1, 1], [1, 1]]]], np.float32)
|
||||
assert np.array_equal(output[0].asnumpy(), x_grad)
|
||||
assert np.array_equal(output[1].asnumpy(), y_grad)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
|
|
Loading…
Reference in New Issue