!27203 while inplace op for cpu backend

Merge pull request !27203 from zhuzhongrui/gmres
This commit is contained in:
i-robot 2021-12-07 06:59:34 +00:00 committed by Gitee
commit 96f56b827e
4 changed files with 26 additions and 24 deletions

View File

@ -141,7 +141,11 @@ bool ScatterNdUpdateCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp
template <typename T>
void ScatterNdUpdateCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
auto x = reinterpret_cast<T *>(inputs[0]->addr);
auto x = reinterpret_cast<T *>(outputs[0]->addr);
auto ret = memcpy_s(x, outputs[0]->size, inputs[0]->addr, inputs[0]->size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', memcpy_s error. Error no: " << ret;
}
ComputeParams<T> params;
params.x_ = x;
params.indices_ = reinterpret_cast<int *>(inputs[1]->addr);
@ -165,11 +169,6 @@ void ScatterNdUpdateCPUKernel::LaunchKernel(const std::vector<AddressPtr> &input
start += once_compute_size;
}
(void)common::ThreadPool::GetInstance().SyncRun(tasks);
auto ret = memcpy_s(outputs[0]->addr, outputs[0]->size, x, inputs[0]->size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', memcpy_s error. Error no: " << ret;
}
}
} // namespace kernel
} // namespace mindspore

View File

@ -56,12 +56,12 @@ class ScatterNdFunctorKernel : public GpuKernel {
cudaMemcpyAsync(indices_stride, &out_strides_[0], indices_len, cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync failed in ScatterNdFunctorGpuFwdKernel::Launch.");
CalScatterNdFunctor(scatter_nd_functor_type_, unit_size_, num_units_, index_depth_, indices_stride, indices,
updates, input, reinterpret_cast<cudaStream_t>(stream_ptr));
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(&output[0], &input[0], input_size_ * sizeof(T), cudaMemcpyDeviceToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync output failed");
CalScatterNdFunctor(scatter_nd_functor_type_, unit_size_, num_units_, index_depth_, indices_stride, indices,
updates, output, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}

View File

@ -36,6 +36,7 @@ def test_op1(dtype):
Description: test cases for updating float values
Expectation: the result matches scipy
"""
class ScatterNdUpdate(nn.Cell):
def __init__(self):
super(ScatterNdUpdate, self).__init__()
@ -50,10 +51,10 @@ def test_op1(dtype):
update = Tensor(np.array([1.0, 2.2], dtype=dtype))
scatter_nd_update = ScatterNdUpdate()
scatter_nd_update(indices, update)
print("x:\n", scatter_nd_update.x.data.asnumpy())
output = scatter_nd_update(indices, update)
print("x:\n", output.asnumpy())
expect = [[1.0, 0.3, 3.6], [0.4, 2.2, -3.2]]
assert np.allclose(scatter_nd_update.x.data.asnumpy(),
assert np.allclose(output.asnumpy(),
np.array(expect, dtype=dtype))
@ -67,6 +68,7 @@ def test_op2(dtype):
Description: test cases for updating int values
Expectation: the result matches scipy
"""
class ScatterNdUpdate(nn.Cell):
def __init__(self):
super(ScatterNdUpdate, self).__init__()
@ -81,10 +83,10 @@ def test_op2(dtype):
update = Tensor(np.array([9, 10, 11, 12], dtype=dtype))
scatter_nd_update = ScatterNdUpdate()
scatter_nd_update(indices, update)
print("x:\n", scatter_nd_update.x.data.asnumpy())
output = scatter_nd_update(indices, update)
print("x:\n", output.asnumpy())
expect = [1, 11, 3, 10, 9, 6, 7, 12]
assert np.allclose(scatter_nd_update.x.data.asnumpy(),
assert np.allclose(output.asnumpy(),
np.array(expect, dtype=dtype))
@ -98,6 +100,7 @@ def test_op3(dtype):
Description: test cases for updating int values
Expectation: the result matches scipy
"""
class ScatterNdUpdate(nn.Cell):
def __init__(self):
super(ScatterNdUpdate, self).__init__()
@ -114,13 +117,13 @@ def test_op3(dtype):
[7, 7, 7, 7], [8, 8, 8, 8]]], dtype=dtype))
scatter_nd_update = ScatterNdUpdate()
scatter_nd_update(indices, update)
print("x:\n", scatter_nd_update.x.data.asnumpy())
output = scatter_nd_update(indices, update)
print("x:\n", output.asnumpy())
expect = [[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]]
assert np.allclose(scatter_nd_update.x.data.asnumpy(), np.array(expect, dtype=dtype))
assert np.allclose(output.asnumpy(), np.array(expect, dtype=dtype))
@pytest.mark.level0
@ -133,6 +136,7 @@ def test_op4(dtype):
Description: test cases for updating single float value
Expectation: the result matches scipy
"""
class ScatterNdUpdate(nn.Cell):
def __init__(self):
super(ScatterNdUpdate, self).__init__()
@ -148,6 +152,5 @@ def test_op4(dtype):
scatter_nd_update = ScatterNdUpdate()
out = scatter_nd_update(indices, update)
print("x:\n", out)
assert np.allclose(out.asnumpy(), scatter_nd_update.x.data.asnumpy())
expect = [[-0.1, 1.0, 3.6], [0.4, 0.5, -3.2]]
assert np.allclose(out.asnumpy(), np.array(expect, dtype=dtype))

View File

@ -92,21 +92,21 @@ def test_scatter_nd_func_input_updated():
# update
net = TestScatterNdFuncNet("update", lock, inputx, indices, updates)
net()
output = net()
expected = np.array([[1.0, 0.3, 3.6], [0.4, 2.2, -3.2]])
np.testing.assert_array_almost_equal(net.inputx.asnumpy(), expected)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
# add
net = TestScatterNdFuncNet("add", lock, inputx, indices, updates)
net()
output = net()
expected = np.array([[0.9, 0.3, 3.6], [0.4, 2.7, -3.2]])
np.testing.assert_array_almost_equal(net.inputx.asnumpy(), expected)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
# sub
net = TestScatterNdFuncNet("sub", lock, inputx, indices, updates)
net()
output = net()
expected = np.array([[-1.1, 0.3, 3.6], [0.4, -1.7, -3.2]])
np.testing.assert_array_almost_equal(net.inputx.asnumpy(), expected)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
@pytest.mark.level0