forked from mindspore-Ecosystem/mindspore
!9096 DynamicShapeOp int64 support and bug fix
From: @peilin-wang
Reviewed-by: @robingrosman, @tom__chen
Signed-off-by: @robingrosman
commit 9b0ec824c4
@@ -19,13 +19,28 @@
 namespace mindspore {
 namespace kernel {
-MS_REG_GPU_KERNEL_ONE(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-                      DynamicShapeGpuKernel, int32_t)
-MS_REG_GPU_KERNEL_ONE(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-                      DynamicShapeGpuKernel, half)
-MS_REG_GPU_KERNEL_ONE(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-                      DynamicShapeGpuKernel, float)
-MS_REG_GPU_KERNEL_ONE(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-                      DynamicShapeGpuKernel, bool)
+MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+                      DynamicShapeGpuKernel, int32_t, int32_t)
+
+MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32),
+                      DynamicShapeGpuKernel, half, int32_t)
+
+MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32),
+                      DynamicShapeGpuKernel, float, int32_t)
+
+MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt32),
+                      DynamicShapeGpuKernel, bool, int32_t)
+
+MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
+                      DynamicShapeGpuKernel, int32_t, int64_t)
+
+MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt64),
+                      DynamicShapeGpuKernel, half, int64_t)
+
+MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt64),
+                      DynamicShapeGpuKernel, float, int64_t)
+
+MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt64),
+                      DynamicShapeGpuKernel, bool, int64_t)
 }  // namespace kernel
 }  // namespace mindspore
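This registration table is the int64 half of the change: the kernel moves from MS_REG_GPU_KERNEL_ONE to MS_REG_GPU_KERNEL_TWO, pairing each input dtype T (int32, float16, float32, bool) with an output index dtype S (int32 or int64). Note that the removed entries all carried AddInputAttr(kNumberTypeInt32), even for the half, float, and bool kernels, which is presumably part of the "bug fix" in the title. A rough user-level sketch of what the op does, mirroring the test file added later in this commit (an assumption, inferred from the table alone, is that the framework's dtype inference, not the caller, selects between the int32 and int64 kernels):

import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P

class ShapeNet(nn.Cell):
    """Minimal wrapper: returns the shape of its input as a 1-D integer tensor."""
    def __init__(self):
        super(ShapeNet, self).__init__()
        self.shape_op = P.DynamicShape()

    def construct(self, x):
        return self.shape_op(x)

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
# Any of the registered input dtypes works; float32 selects the float kernel.
x = Tensor(np.zeros((2, 4), dtype=np.float32))
print(ShapeNet()(x).asnumpy())  # expected: [2 4]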
@@ -26,7 +26,7 @@
 namespace mindspore {
 namespace kernel {
-template <typename T>
+template <typename T, typename S>
 class DynamicShapeGpuKernel : public GpuKernel {
  public:
   DynamicShapeGpuKernel() { ResetResource(); }
@@ -38,8 +38,8 @@ class DynamicShapeGpuKernel : public GpuKernel {
   bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
               const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
-    int *output_device_address = GetDeviceAddress<int>(outputs, 0);
-    size_t prev_node_output_shape_size = prev_node_output_shape_.size() * sizeof(int);
+    S *output_device_address = GetDeviceAddress<S>(outputs, 0);
+    size_t prev_node_output_shape_size = prev_node_output_shape_.size() * sizeof(S);
     CHECK_CUDA_RET_WITH_EXCEPT(
       cudaMemcpyAsync(output_device_address, prev_node_output_shape_.data(), prev_node_output_shape_size,
                       cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
@@ -58,9 +58,10 @@
     input_size_ = 1;
     for (const size_t &e : prev_node_output_shape_tmp) {
       input_size_ *= e;
-      // shapes are Tensors with elements of type int32, but GetPrevNodeOutputInferShape returns vector of size_t,
-      // so we use an int* for allocated output memory and cast to an int here, otherwise the memcpy will fail with a
-      // silently.
+      // shapes are Tensors with elements of type S (int32, or int64) but
+      // GetPrevNodeOutputInferShape returns vector of size_t, so we use
+      // an S* for allocated output memory and cast to an integral type here,
+      // otherwise the memcpy will fail silently.
       prev_node_output_shape_.push_back(e);
     }
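The rewritten comment is the heart of the silent-failure fix: GetPrevNodeOutputInferShape hands back 64-bit size_t elements on the host, while the device output buffer holds S-typed elements, so a byte-for-byte copy would scramble the values whenever sizeof(size_t) != sizeof(S). A small numpy sketch of the failure mode the comment describes (numpy arrays stand in for the host vector and device buffer; the names are illustrative only):

import numpy as np

# Host-side shape as the framework reports it: 64-bit unsigned, like size_t.
shape_host = np.array([2, 3, 1, 3, 1], dtype=np.uint64)

# Wrong: reinterpret the raw 8-byte elements as int32, the moral equivalent
# of memcpy-ing a size_t buffer into an int32 buffer. On a little-endian
# machine every other int32 is the high half of a size_t, so the real values
# interleave with zeros.
wrong = shape_host.view(np.int32)
print(wrong)   # [2 0 3 0 1 0 3 0 1 0]

# Right: convert element by element into the output element type S first,
# which is what pushing each size_t into std::vector<S> achieves.
right = shape_host.astype(np.int32)
print(right)   # [2 3 1 3 1]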
@@ -83,13 +84,13 @@
  protected:
   void InitSizeLists() override {
     input_size_list_.push_back(input_size_ * sizeof(T));
-    output_size_list_.push_back(output_size_ * sizeof(int));
+    output_size_list_.push_back(output_size_ * sizeof(S));
   }

  private:
   size_t input_size_;
   size_t output_size_;
-  std::vector<int> prev_node_output_shape_;
+  std::vector<S> prev_node_output_shape_;

   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
@@ -0,0 +1,117 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+from mindspore import Tensor
+from mindspore.ops import operations as P
+from mindspore.ops.operations import _inner_ops as inner
+import mindspore.nn as nn
+import mindspore.context as context
+
+
+class DynamicShapeNet(nn.Cell):
+    def __init__(self):
+        super(DynamicShapeNet, self).__init__()
+        self.convert_to_dynamic_shape_op = inner.GpuConvertToDynamicShape()
+        self.dynamic_shape_op = P.DynamicShape()
+
+    def construct(self, x):
+        x_dynamic_shape = self.convert_to_dynamic_shape_op(x)
+        return self.dynamic_shape_op(x_dynamic_shape)
+
+
+def dynamic_shape(np_type):
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+
+    dynamic_shape_net = DynamicShapeNet()
+
+    shape = (1,)
+    x = Tensor(np.zeros(shape).astype(np_type))
+    ms_out = dynamic_shape_net(x).asnumpy()
+    expected = np.array(shape)
+    np.testing.assert_array_equal(ms_out, expected)
+
+    shape = (7,)
+    x = Tensor(np.zeros(shape).astype(np_type))
+    ms_out = dynamic_shape_net(x).asnumpy()
+    expected = np.array(shape)
+    np.testing.assert_array_equal(ms_out, expected)
+
+    shape = (1, 1)
+    x = Tensor(np.zeros(shape).astype(np_type))
+    ms_out = dynamic_shape_net(x).asnumpy()
+    expected = np.array(shape)
+    np.testing.assert_array_equal(ms_out, expected)
+
+    shape = (1, 7)
+    x = Tensor(np.zeros(shape).astype(np_type))
+    ms_out = dynamic_shape_net(x).asnumpy()
+    expected = np.array(shape)
+    np.testing.assert_array_equal(ms_out, expected)
+
+    shape = (3, 1)
+    x = Tensor(np.zeros(shape).astype(np_type))
+    ms_out = dynamic_shape_net(x).asnumpy()
+    expected = np.array(shape)
+    np.testing.assert_array_equal(ms_out, expected)
+
+    shape = (2, 4)
+    x = Tensor(np.zeros(shape).astype(np_type))
+    ms_out = dynamic_shape_net(x).asnumpy()
+    expected = np.array(shape)
+    np.testing.assert_array_equal(ms_out, expected)
+
+    shape = (1, 1, 1)
+    x = Tensor(np.zeros(shape).astype(np_type))
+    ms_out = dynamic_shape_net(x).asnumpy()
+    expected = np.array(shape)
+    np.testing.assert_array_equal(ms_out, expected)
+
+    shape = (1, 5, 3)
+    x = Tensor(np.zeros(shape).astype(np_type))
+    ms_out = dynamic_shape_net(x).asnumpy()
+    expected = np.array(shape)
+    np.testing.assert_array_equal(ms_out, expected)
+
+    shape = (2, 3, 1, 3, 1)
+    x = Tensor(np.zeros(shape).astype(np_type))
+    ms_out = dynamic_shape_net(x).asnumpy()
+    expected = np.array(shape)
+    np.testing.assert_array_equal(ms_out, expected)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_dynamic_shape_int32():
+    dynamic_shape(np.int32)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_dynamic_shape_float16():
+    dynamic_shape(np.float16)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_dynamic_shape_float32():
+    dynamic_shape(np.float32)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_dynamic_shape_bool():
+    dynamic_shape(np.bool)
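One reviewer-style note on the new test: the nine shape cases in dynamic_shape() repeat the same four lines. Behavior-equivalent and shorter, assuming the same imports and DynamicShapeNet from the file above, a loop version could look like this:

def dynamic_shape(np_type):
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    dynamic_shape_net = DynamicShapeNet()

    # The same nine shapes the file above checks one block at a time.
    shapes = [(1,), (7,), (1, 1), (1, 7), (3, 1), (2, 4),
              (1, 1, 1), (1, 5, 3), (2, 3, 1, 3, 1)]
    for shape in shapes:
        x = Tensor(np.zeros(shape).astype(np_type))
        ms_out = dynamic_shape_net(x).asnumpy()
        np.testing.assert_array_equal(ms_out, np.array(shape))

Separately, on numpy 1.24 and later the np.bool alias in the last test no longer exists and would need to be np.bool_; it was still valid when this 2020 commit landed.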