!8801 gpu op for testing dynamic shape

From: @peilin-wang
Reviewed-by: @robingrosman
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2020-11-20 05:15:52 +08:00 committed by Gitee
commit 2e0981faec
9 changed files with 502 additions and 1 deletions

View File

@ -0,0 +1,66 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/other/gpu_convert_to_dynamic_shape_gpu_kernel.h"
#include <cstdint>
namespace mindspore {
namespace kernel {
// Kernel registrations for the GpuConvertToDynamicShape op.
// Each entry binds one element type: the single input and the single output always carry the same
// dtype, and the last macro argument is the C++ type used to instantiate
// GpuConvertToDynamicShapeGpuKernel<T> for that dtype.
// NOTE(review): MS_REG_GPU_KERNEL_ONE is defined outside this file; presumably it registers the
// instantiation with the GPU kernel factory - confirm against gpu_kernel_factory.h.
// Covered dtypes: bool, float16, float32, int8/16/32/64, uint8/16/32/64.
MS_REG_GPU_KERNEL_ONE(GpuConvertToDynamicShape,
KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
GpuConvertToDynamicShapeGpuKernel, bool)
MS_REG_GPU_KERNEL_ONE(GpuConvertToDynamicShape,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
GpuConvertToDynamicShapeGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(GpuConvertToDynamicShape,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
GpuConvertToDynamicShapeGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(GpuConvertToDynamicShape,
KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
GpuConvertToDynamicShapeGpuKernel, int8_t)
MS_REG_GPU_KERNEL_ONE(GpuConvertToDynamicShape,
KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
GpuConvertToDynamicShapeGpuKernel, int16_t)
MS_REG_GPU_KERNEL_ONE(GpuConvertToDynamicShape,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
GpuConvertToDynamicShapeGpuKernel, int32_t)
MS_REG_GPU_KERNEL_ONE(GpuConvertToDynamicShape,
KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
GpuConvertToDynamicShapeGpuKernel, int64_t)
MS_REG_GPU_KERNEL_ONE(GpuConvertToDynamicShape,
KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
GpuConvertToDynamicShapeGpuKernel, uint8_t)
MS_REG_GPU_KERNEL_ONE(GpuConvertToDynamicShape,
KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
GpuConvertToDynamicShapeGpuKernel, uint16_t)
MS_REG_GPU_KERNEL_ONE(GpuConvertToDynamicShape,
KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
GpuConvertToDynamicShapeGpuKernel, uint32_t)
MS_REG_GPU_KERNEL_ONE(GpuConvertToDynamicShape,
KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
GpuConvertToDynamicShapeGpuKernel, uint64_t)
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,105 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_OTHER_GPU_CONVERT_TO_DYNAMIC_SHAPE_GPU_KERNEL_H
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_OTHER_GPU_CONVERT_TO_DYNAMIC_SHAPE_GPU_KERNEL_H
#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
template <typename T>
class GpuConvertToDynamicShapeGpuKernel : public GpuKernel {
public:
GpuConvertToDynamicShapeGpuKernel() { ResetResource(); }
~GpuConvertToDynamicShapeGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
VARIABLE_NOT_USED(workspace);
T *input_device_address = GetDeviceAddress<T>(inputs, 0);
T *output_device_address = GetDeviceAddress<T>(outputs, 0);
cuda_stream_ptr_ = stream_ptr;
CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output_device_address, input_device_address, input_size_ * sizeof(T),
cudaMemcpyDeviceToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"Failed to copy gpu memory.");
return true;
}
void PostExecute() override {
CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(cuda_stream_ptr_)),
"cudaStreamSynchronized failed");
std::vector<TypeId> output_types = {AnfAlgo::GetOutputInferDataType(c_node_ptr_, 0)};
std::vector<std::vector<size_t>> output_shapes = {input_shape_};
AnfAlgo::SetOutputInferTypeAndShape(output_types, output_shapes, c_node_ptr_.get());
}
bool Init(const CNodePtr &kernel_node) override {
size_t input_count = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_count != 1) {
MS_LOG(ERROR) << input_count << "inputs were provided, but GpuConvertToDynamicShapeGpuKernel exepects 1.";
return false;
}
input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
for (const size_t &e : input_shape_) {
input_size_ *= e;
}
c_node_ptr_ = kernel_node;
InitSizeLists();
return true;
}
void ResetResource() noexcept override {
c_node_ptr_ = nullptr;
cuda_stream_ptr_ = nullptr;
input_shape_.clear();
input_size_ = 1;
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(input_size_ * sizeof(T));
output_size_list_.push_back(input_size_ * sizeof(T));
}
private:
void *cuda_stream_ptr_;
CNodePtr c_node_ptr_;
std::vector<size_t> input_shape_;
size_t input_size_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_OTHER_GPU_CONVERT_TO_DYNAMIC_SHAPE_GPU_KERNEL_H

View File

@ -249,6 +249,9 @@ AbstractBasePtr InferImplDivNoNan(const AnalysisEnginePtr &, const PrimitivePtr
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplExpandDims(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
// Abstract inference for the GpuConvertToDynamicShape primitive (dynamic shape testing):
// takes exactly one tensor argument and yields a tensor abstract with the same element type
// whose dimensions are all reported as unknown.
AbstractBasePtr InferImplGpuConvertToDynamicShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
template <typename T>
AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
// Inputs: a tuple or list or dict.

View File

@ -526,5 +526,20 @@ AbstractBasePtr InferImplExpandDims(const AnalysisEnginePtr &, const PrimitivePt
return ret;
}
// Infers the output abstract of GpuConvertToDynamicShape.
//
// The op takes exactly one tensor. Its output keeps the input's element type and rank, but every
// dimension is reported as Shape::SHP_ANY so downstream nodes see a dynamically shaped tensor.
// The accompanying bounds have the same rank as the inferred shape: each dimension ranges from 1
// up to the corresponding dimension of the (static) input shape.
//
// Args:
//   primitive:      the primitive being inferred; only its name is used, for error reporting.
//   args_spec_list: abstract arguments; must contain a single AbstractTensor.
// Returns:
//   An AbstractTensor with the input's element and the dynamic shape described above.
AbstractBasePtr InferImplGpuConvertToDynamicShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                                  const AbstractBasePtrList &args_spec_list) {
  const std::string &op_name = primitive->name();
  CheckArgsSize(op_name, args_spec_list, 1);
  AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);

  ShapeVector input_shape = input->shape()->shape();
  // Keep the rank as size_t: it is only used as an element count, so no narrowing is needed.
  const size_t input_rank = input_shape.size();
  ShapeVector inferred_shape(input_rank, Shape::SHP_ANY);
  // One lower bound per dimension; a fixed {1} would have rank 1 regardless of the input rank
  // and so disagree with inferred_shape / max_shape for any input that is not 1-D.
  ShapeVector min_shape(input_rank, 1);
  ShapeVector max_shape = input_shape;

  ShapePtr shape = std::make_shared<Shape>(inferred_shape, min_shape, max_shape);
  return std::make_shared<AbstractTensor>(input->element(), shape);
}
} // namespace abstract
} // namespace mindspore

View File

@ -121,6 +121,8 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimControlDepend, {InferImplControlDepend, true}},
// Debug
{prim::kPrimDebug, {InferImplDebug, true}},
// Dynamic shape testing
{prim::kPrimGpuConvertToDynamicShape, {InferImplGpuConvertToDynamicShape, true}},
// SparseTensor
{prim::kPrimMakeSparseTensor, {InferImplMakeSparseTensor, true}},
{prim::kPrimSparseTensorGetValues, {InferImplSparseTensorGetValues, true}},

View File

@ -271,6 +271,10 @@ inline const PrimitivePtr kPrimTensorSummary = std::make_shared<Primitive>("Tens
inline const PrimitivePtr kPrimHistogramSummary = std::make_shared<Primitive>("HistogramSummary");
inline const PrimitivePtr kPrimDebug = std::make_shared<Primitive>("Debug");
// Dynamic shape testing: primitives that exist to exercise dynamic-shape support in tests.
// GpuConvertToDynamicShape makes its output appear dynamically shaped; ErrorOnDynamicShapeInput
// is its counterpart that rejects dynamically shaped inputs.
inline const PrimitivePtr kPrimGpuConvertToDynamicShape = std::make_shared<Primitive>("GpuConvertToDynamicShape");
inline const PrimitivePtr kPrimErrorOnDynamicShapeInput = std::make_shared<Primitive>("ErrorOnDynamicShapeInput");
// Other miscellaneous
inline const PrimitivePtr kPrimDepend = std::make_shared<Primitive>("Depend");
inline const PrimitivePtr kPrimPartial = std::make_shared<Primitive>("Partial");

View File

@ -19,7 +19,7 @@ from ..._checkparam import Rel
from ..._checkparam import Validator as validator
from ... import context
from ...common import dtype as mstype
from ..primitive import PrimitiveWithInfer, prim_attr_register
from ..primitive import PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register
from ..operations.math_ops import _infer_shape_reduce
@ -666,3 +666,99 @@ class ConfusionMulGrad(PrimitiveWithInfer):
validator.check_subclass("input1_dtype", input1_dtype, mstype.tensor, self.name)
validator.check_subclass("input2_dtype", input2_dtype, mstype.tensor, self.name)
return input0_dtype, input1_dtype
class GpuConvertToDynamicShape(PrimitiveWithCheck):
    """
    Test-only operator that makes a tensor look dynamically shaped.

    The shape inferred for the output is unknown at compile time, so whatever
    consumes the output sees a dynamically shaped tensor. The data itself is
    passed through untouched. Insert this operator directly in front of the
    operator whose dynamic shape support is being tested.

    Inputs:
        - **input** (Tensor) - The tensor used for testing. Its rank must be
          greater than 0.

    Outputs:
        - **output** (Tensor) - Same shape, type and value as `input`.

    Examples:
        >>> # make a model, since dynamic shape operators must be in GRAPH_MODE
        >>> class TestDynamicShapeReshapeNet(nn.Cell):
        >>>     def __init__(self):
        >>>         super(TestDynamicShapeReshapeNet, self).__init__()
        >>>         self.convert_to_dynamic_shape = inner.GpuConvertToDynamicShape()
        >>>         # suppose we are testing Reshape op
        >>>         self.reshape = P.Reshape()
        >>>
        >>>     def construct(self, input, new_shape):
        >>>         dynamic_shape_input = self.convert_to_dynamic_shape(input)
        >>>         reshaped_input = self.reshape(dynamic_shape_input, new_shape)
        >>>         return reshaped_input
        >>>
        >>> context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
        >>> input = Tensor(np.array([0, 1, 2, 3]))
        >>> new_shape = (2, 2)
        >>> net = TestDynamicShapeReshapeNet()
        >>> output = net(input, new_shape)
        >>> print(output)
        [[0, 1], [2, 3]]
    """

    @prim_attr_register
    def __init__(self):
        # One named input and one named output.
        self.init_prim_io_names(inputs=["input"], outputs=["output"])

    def check_shape(self, input_shape):
        # The rank (number of dimensions) is validated against 0 with Rel.GT.
        rank = len(input_shape)
        validator.check("input_shape rank", rank, "", 0, Rel.GT, self.name)

    def check_dtype(self, input_dtype):
        # The input has to be a tensor type.
        validator.check_subclass("input_dtype", input_dtype, mstype.tensor, self.name)
class ErrorOnDynamicShapeInput(PrimitiveWithInfer):
    """
    Test-only operator that rejects dynamically shaped inputs.

    During shape inference it raises a ValueError as soon as any dimension of
    the input shape is -1; otherwise the input is passed through unchanged.

    Inputs:
        - **input** (Tensor) - The tensor used for testing.

    Outputs:
        - **output** (Tensor) - Same shape, type and value as `input`.

    Raises:
        ValueError: If any dimension of the input shape is -1.

    Examples:
        >>> # make a model, since dynamic shape operators must be in GRAPH_MODE
        >>> class AssertDynamicShapeNet(nn.Cell):
        >>>     def __init__(self):
        >>>         super(AssertDynamicShapeNet, self).__init__()
        >>>         self.convert_to_dynamic_shape = inner.GpuConvertToDynamicShape()
        >>>         self.error_on_dynamic_shape_input = inner.ErrorOnDynamicShapeInput()
        >>>
        >>>     def construct(self, input):
        >>>         dynamic_shape_input = self.convert_to_dynamic_shape(input)
        >>>         self.error_on_dynamic_shape_input(dynamic_shape_input)
        >>>
        >>> context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
        >>> input = Tensor(np.array([0]))
        >>> net = AssertDynamicShapeNet()
        >>> output = net(input)
        ValueError: Input is dynamically shaped.
    """

    @prim_attr_register
    def __init__(self):
        # One named input and one named output.
        self.init_prim_io_names(inputs=["input"], outputs=["output"])

    def infer_shape(self, input_shape):
        # A dimension of -1 marks the shape as dynamic.
        if any(dim == -1 for dim in list(input_shape)):
            raise ValueError("Input is dynamically shaped.")
        return input_shape

    def infer_type(self, input_dtype):
        # The input has to be a tensor type; the dtype itself is passed through.
        validator.check_subclass("input_dtype", input_dtype, mstype.tensor, self.name)
        return input_dtype

    def infer_value(self, input_tensor):
        # The value is returned exactly as received.
        return input_tensor

View File

@ -0,0 +1,58 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
from mindspore.ops.operations import _inner_ops as inner
import mindspore.context as context
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_error_on_dynamic_shape_input_is_dynamic():
    """infer_shape must raise ValueError whenever any dimension is -1."""
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    op = inner.ErrorOnDynamicShapeInput()

    # -1 alone, last, first, in the middle, and in every position.
    dynamic_shapes = (
        [-1],
        [1, 1, -1],
        [-1, 1, 1],
        [1, -1, 1],
        [-1, -1, -1],
    )
    for shape in dynamic_shapes:
        with pytest.raises(ValueError) as info:
            op.infer_shape(shape)
        assert "Input is dynamically shaped" in str(info.value)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_error_on_dynamic_shape_input_not_dynamic():
    """Invoking the op with fully static shapes must not raise."""
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    op = inner.ErrorOnDynamicShapeInput()

    static_shapes = ([1], [1, 1], [23, 12, 9712])
    for shape in static_shapes:
        op(shape)

View File

@ -0,0 +1,152 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
from mindspore import Tensor
from mindspore.ops.operations import _inner_ops as inner
import mindspore.nn as nn
import mindspore.context as context
# test to make sure this op actually generates a dynamically shaped output
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_convert_to_dyanamic_shape_confirm_dynamic():
    """The converted output must be rejected by ErrorOnDynamicShapeInput, i.e. it is dynamic."""
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

    class AssertDynamicShapeNet(nn.Cell):
        """Feeds the converted tensor straight into the dynamic-shape detector."""

        def __init__(self):
            super(AssertDynamicShapeNet, self).__init__()
            self.gpu_convert_to_dynamic_shape = inner.GpuConvertToDynamicShape()
            self.error_on_dynamic_shape_input = inner.ErrorOnDynamicShapeInput()

        def construct(self, x):
            dynamic_x = self.gpu_convert_to_dynamic_shape(x)
            self.error_on_dynamic_shape_input(dynamic_x)
            return dynamic_x

    net = AssertDynamicShapeNet()
    static_input = Tensor(np.array([0, 0, 0, 0], dtype=np.float32))

    with pytest.raises(ValueError) as info:
        net(static_input)
    assert "Input is dynamically shaped" in str(info.value)
def gpu_convert_to_dynamic_shape(x):
    """Run `x` through a one-op GpuConvertToDynamicShape network and return the result as numpy.

    Args:
        x: array-like data; it is wrapped in a `Tensor` before being fed to the network.

    Returns:
        The network output converted with `asnumpy()`.
    """

    class GpuConvertToDynamicShapeNet(nn.Cell):
        """Network consisting solely of the GpuConvertToDynamicShape op."""

        def __init__(self):
            super(GpuConvertToDynamicShapeNet, self).__init__()
            self.gpu_convert_to_dynamic_shape = inner.GpuConvertToDynamicShape()

        def construct(self, x):
            return self.gpu_convert_to_dynamic_shape(x)

    net = GpuConvertToDynamicShapeNet()
    result = net(Tensor(x))
    return result.asnumpy()
def gpu_convert_to_dynamic_shape_float(dtype):
    """Check that 12 random values spanning the full range of float `dtype` pass through unchanged.

    Args:
        dtype: numpy floating point type to test.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

    np.random.seed(0)
    limits = np.finfo(dtype)
    samples = np.random.uniform(low=limits.min, high=limits.max, size=12)
    expected = samples.astype(dtype)

    actual = gpu_convert_to_dynamic_shape(expected)
    np.testing.assert_array_equal(expected, actual)
def gpu_convert_to_dynamic_shape_int(dtype):
    """Check that 12 random values spanning the full range of integer `dtype` pass through unchanged.

    Args:
        dtype: numpy integer type to test.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

    np.random.seed(0)
    limits = np.iinfo(dtype)
    # Values are drawn as floats over [min, max) and then cast to the integer dtype.
    samples = np.random.uniform(low=limits.min, high=limits.max, size=12)
    expected = samples.astype(dtype)

    actual = gpu_convert_to_dynamic_shape(expected)
    np.testing.assert_array_equal(expected, actual)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_convert_to_dynamic_shape_bool():
    """Twelve random booleans must pass through the op unchanged."""
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

    np.random.seed(0)
    expected = np.random.choice([False, True], size=12)

    actual = gpu_convert_to_dynamic_shape(expected)
    np.testing.assert_array_equal(expected, actual)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_convert_to_dynamic_shape_float16():
    """float16 data must pass through the op unchanged."""
    gpu_convert_to_dynamic_shape_float(dtype=np.float16)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_convert_to_dynamic_shape_float32():
    """float32 data must pass through the op unchanged."""
    gpu_convert_to_dynamic_shape_float(dtype=np.float32)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_convert_to_dynamic_shape_int8():
    """int8 data must pass through the op unchanged."""
    gpu_convert_to_dynamic_shape_int(dtype=np.int8)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_convert_to_dynamic_shape_int16():
    """int16 data must pass through the op unchanged."""
    gpu_convert_to_dynamic_shape_int(dtype=np.int16)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_convert_to_dynamic_shape_int32():
    """int32 data must pass through the op unchanged."""
    gpu_convert_to_dynamic_shape_int(dtype=np.int32)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_convert_to_dynamic_shape_int64():
    """int64 data must pass through the op unchanged."""
    gpu_convert_to_dynamic_shape_int(dtype=np.int64)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_convert_to_dynamic_shape_uint8():
    """uint8 data must pass through the op unchanged."""
    gpu_convert_to_dynamic_shape_int(dtype=np.uint8)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_convert_to_dynamic_shape_uint16():
    """uint16 data must pass through the op unchanged."""
    gpu_convert_to_dynamic_shape_int(dtype=np.uint16)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_convert_to_dynamic_shape_uint32():
    """uint32 data must pass through the op unchanged."""
    gpu_convert_to_dynamic_shape_int(dtype=np.uint32)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_convert_to_dynamic_shape_uint64():
    """uint64 data must pass through the op unchanged."""
    gpu_convert_to_dynamic_shape_int(dtype=np.uint64)