forked from mindspore-Ecosystem/mindspore
!3655 gpu support BroadcastTo kernel
Merge pull request !3655 from chenweifeng/broadcast_to
This commit is contained in:
commit
f1a39a0f72
|
@ -51,11 +51,12 @@ class BroadcastToGpuKernel : public GpuKernel {
|
|||
MS_LOG(EXCEPTION) << "BroadcastTo operation not support dim greater than 4";
|
||||
}
|
||||
|
||||
for (int i = input_shapes.size() - 1; i >= 0; i--) {
|
||||
input_shape_[i] = input_shapes[i];
|
||||
size_t offset = output_shapes.size() - input_shapes.size();
|
||||
for (size_t i = 0; i < input_shapes.size(); i++) {
|
||||
input_shape_[i + offset] = input_shapes[i];
|
||||
}
|
||||
|
||||
for (int j = output_shapes.size() - 1; j >= 0; j--) {
|
||||
for (size_t j = 0; j < output_shapes.size(); j++) {
|
||||
output_shape_[j] = output_shapes[j];
|
||||
}
|
||||
|
||||
|
|
|
@ -38,3 +38,9 @@ def test_broadcast():
|
|||
output = P.BroadcastTo(shape)(Tensor(x1_np))
|
||||
expect = np.broadcast_to(x1_np, shape)
|
||||
assert np.allclose(output.asnumpy(), expect)
|
||||
|
||||
x1_np = np.random.rand(4, 5).astype(np.float32)
|
||||
shape = (2, 3, 4, 5)
|
||||
output = P.BroadcastTo(shape)(Tensor(x1_np))
|
||||
expect = np.broadcast_to(x1_np, shape)
|
||||
assert np.allclose(output.asnumpy(), expect)
|
||||
|
|
Loading…
Reference in New Issue