!3655 gpu support BroadcastTo kernel

Merge pull request !3655 from chenweifeng/broadcast_to
This commit is contained in:
mindspore-ci-bot 2020-07-30 14:23:18 +08:00 committed by Gitee
commit f1a39a0f72
2 changed files with 10 additions and 3 deletions

View File

@ -51,11 +51,12 @@ class BroadcastToGpuKernel : public GpuKernel {
MS_LOG(EXCEPTION) << "BroadcastTo operation not support dim greater than 4";
}
for (int i = input_shapes.size() - 1; i >= 0; i--) {
input_shape_[i] = input_shapes[i];
size_t offset = output_shapes.size() - input_shapes.size();
for (size_t i = 0; i < input_shapes.size(); i++) {
input_shape_[i + offset] = input_shapes[i];
}
for (int j = output_shapes.size() - 1; j >= 0; j--) {
for (size_t j = 0; j < output_shapes.size(); j++) {
output_shape_[j] = output_shapes[j];
}

View File

@ -38,3 +38,9 @@ def test_broadcast():
output = P.BroadcastTo(shape)(Tensor(x1_np))
expect = np.broadcast_to(x1_np, shape)
assert np.allclose(output.asnumpy(), expect)
x1_np = np.random.rand(4, 5).astype(np.float32)
shape = (2, 3, 4, 5)
output = P.BroadcastTo(shape)(Tensor(x1_np))
expect = np.broadcast_to(x1_np, shape)
assert np.allclose(output.asnumpy(), expect)