!27203 while inplace op for cpu backend
Merge pull request !27203 from zhuzhongrui/gmres
This commit is contained in:
commit
96f56b827e
|
@ -141,7 +141,11 @@ bool ScatterNdUpdateCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp
|
|||
template <typename T>
|
||||
void ScatterNdUpdateCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
auto x = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
auto x = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
auto ret = memcpy_s(x, outputs[0]->size, inputs[0]->addr, inputs[0]->size);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', memcpy_s error. Error no: " << ret;
|
||||
}
|
||||
ComputeParams<T> params;
|
||||
params.x_ = x;
|
||||
params.indices_ = reinterpret_cast<int *>(inputs[1]->addr);
|
||||
|
@ -165,11 +169,6 @@ void ScatterNdUpdateCPUKernel::LaunchKernel(const std::vector<AddressPtr> &input
|
|||
start += once_compute_size;
|
||||
}
|
||||
(void)common::ThreadPool::GetInstance().SyncRun(tasks);
|
||||
|
||||
auto ret = memcpy_s(outputs[0]->addr, outputs[0]->size, x, inputs[0]->size);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', memcpy_s error. Error no: " << ret;
|
||||
}
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -56,12 +56,12 @@ class ScatterNdFunctorKernel : public GpuKernel {
|
|||
cudaMemcpyAsync(indices_stride, &out_strides_[0], indices_len, cudaMemcpyHostToDevice,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cudaMemcpyAsync failed in ScatterNdFunctorGpuFwdKernel::Launch.");
|
||||
CalScatterNdFunctor(scatter_nd_functor_type_, unit_size_, num_units_, index_depth_, indices_stride, indices,
|
||||
updates, input, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
|
||||
cudaMemcpyAsync(&output[0], &input[0], input_size_ * sizeof(T), cudaMemcpyDeviceToDevice,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cudaMemcpyAsync output failed");
|
||||
CalScatterNdFunctor(scatter_nd_functor_type_, unit_size_, num_units_, index_depth_, indices_stride, indices,
|
||||
updates, output, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -36,6 +36,7 @@ def test_op1(dtype):
|
|||
Description: test cases for updating float values
|
||||
Expectation: the result matches scipy
|
||||
"""
|
||||
|
||||
class ScatterNdUpdate(nn.Cell):
|
||||
def __init__(self):
|
||||
super(ScatterNdUpdate, self).__init__()
|
||||
|
@ -50,10 +51,10 @@ def test_op1(dtype):
|
|||
update = Tensor(np.array([1.0, 2.2], dtype=dtype))
|
||||
|
||||
scatter_nd_update = ScatterNdUpdate()
|
||||
scatter_nd_update(indices, update)
|
||||
print("x:\n", scatter_nd_update.x.data.asnumpy())
|
||||
output = scatter_nd_update(indices, update)
|
||||
print("x:\n", output.asnumpy())
|
||||
expect = [[1.0, 0.3, 3.6], [0.4, 2.2, -3.2]]
|
||||
assert np.allclose(scatter_nd_update.x.data.asnumpy(),
|
||||
assert np.allclose(output.asnumpy(),
|
||||
np.array(expect, dtype=dtype))
|
||||
|
||||
|
||||
|
@ -67,6 +68,7 @@ def test_op2(dtype):
|
|||
Description: test cases for updating int values
|
||||
Expectation: the result matches scipy
|
||||
"""
|
||||
|
||||
class ScatterNdUpdate(nn.Cell):
|
||||
def __init__(self):
|
||||
super(ScatterNdUpdate, self).__init__()
|
||||
|
@ -81,10 +83,10 @@ def test_op2(dtype):
|
|||
update = Tensor(np.array([9, 10, 11, 12], dtype=dtype))
|
||||
|
||||
scatter_nd_update = ScatterNdUpdate()
|
||||
scatter_nd_update(indices, update)
|
||||
print("x:\n", scatter_nd_update.x.data.asnumpy())
|
||||
output = scatter_nd_update(indices, update)
|
||||
print("x:\n", output.asnumpy())
|
||||
expect = [1, 11, 3, 10, 9, 6, 7, 12]
|
||||
assert np.allclose(scatter_nd_update.x.data.asnumpy(),
|
||||
assert np.allclose(output.asnumpy(),
|
||||
np.array(expect, dtype=dtype))
|
||||
|
||||
|
||||
|
@ -98,6 +100,7 @@ def test_op3(dtype):
|
|||
Description: test cases for updating int values
|
||||
Expectation: the result matches scipy
|
||||
"""
|
||||
|
||||
class ScatterNdUpdate(nn.Cell):
|
||||
def __init__(self):
|
||||
super(ScatterNdUpdate, self).__init__()
|
||||
|
@ -114,13 +117,13 @@ def test_op3(dtype):
|
|||
[7, 7, 7, 7], [8, 8, 8, 8]]], dtype=dtype))
|
||||
|
||||
scatter_nd_update = ScatterNdUpdate()
|
||||
scatter_nd_update(indices, update)
|
||||
print("x:\n", scatter_nd_update.x.data.asnumpy())
|
||||
output = scatter_nd_update(indices, update)
|
||||
print("x:\n", output.asnumpy())
|
||||
expect = [[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
|
||||
[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
|
||||
[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
|
||||
[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]]
|
||||
assert np.allclose(scatter_nd_update.x.data.asnumpy(), np.array(expect, dtype=dtype))
|
||||
assert np.allclose(output.asnumpy(), np.array(expect, dtype=dtype))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
|
@ -133,6 +136,7 @@ def test_op4(dtype):
|
|||
Description: test cases for updating single float value
|
||||
Expectation: the result matches scipy
|
||||
"""
|
||||
|
||||
class ScatterNdUpdate(nn.Cell):
|
||||
def __init__(self):
|
||||
super(ScatterNdUpdate, self).__init__()
|
||||
|
@ -148,6 +152,5 @@ def test_op4(dtype):
|
|||
scatter_nd_update = ScatterNdUpdate()
|
||||
out = scatter_nd_update(indices, update)
|
||||
print("x:\n", out)
|
||||
assert np.allclose(out.asnumpy(), scatter_nd_update.x.data.asnumpy())
|
||||
expect = [[-0.1, 1.0, 3.6], [0.4, 0.5, -3.2]]
|
||||
assert np.allclose(out.asnumpy(), np.array(expect, dtype=dtype))
|
||||
|
|
|
@ -92,21 +92,21 @@ def test_scatter_nd_func_input_updated():
|
|||
|
||||
# update
|
||||
net = TestScatterNdFuncNet("update", lock, inputx, indices, updates)
|
||||
net()
|
||||
output = net()
|
||||
expected = np.array([[1.0, 0.3, 3.6], [0.4, 2.2, -3.2]])
|
||||
np.testing.assert_array_almost_equal(net.inputx.asnumpy(), expected)
|
||||
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||
|
||||
# add
|
||||
net = TestScatterNdFuncNet("add", lock, inputx, indices, updates)
|
||||
net()
|
||||
output = net()
|
||||
expected = np.array([[0.9, 0.3, 3.6], [0.4, 2.7, -3.2]])
|
||||
np.testing.assert_array_almost_equal(net.inputx.asnumpy(), expected)
|
||||
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||
|
||||
# sub
|
||||
net = TestScatterNdFuncNet("sub", lock, inputx, indices, updates)
|
||||
net()
|
||||
output = net()
|
||||
expected = np.array([[-1.1, 0.3, 3.6], [0.4, -1.7, -3.2]])
|
||||
np.testing.assert_array_almost_equal(net.inputx.asnumpy(), expected)
|
||||
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
|
|
Loading…
Reference in New Issue