forked from mindspore-Ecosystem/mindspore
!8801 gpu op for testing dynamic shape
From: @peilin-wang Reviewed-by: @robingrosman Signed-off-by:
commit 2e0981faec
@ -0,0 +1,66 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/kernel_compiler/gpu/other/gpu_convert_to_dynamic_shape_gpu_kernel.h"

#include <cstdint>

namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(GpuConvertToDynamicShape,
                      KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
                      GpuConvertToDynamicShapeGpuKernel, bool)

MS_REG_GPU_KERNEL_ONE(GpuConvertToDynamicShape,
                      KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                      GpuConvertToDynamicShapeGpuKernel, half)

MS_REG_GPU_KERNEL_ONE(GpuConvertToDynamicShape,
                      KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                      GpuConvertToDynamicShapeGpuKernel, float)

MS_REG_GPU_KERNEL_ONE(GpuConvertToDynamicShape,
                      KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
                      GpuConvertToDynamicShapeGpuKernel, int8_t)

MS_REG_GPU_KERNEL_ONE(GpuConvertToDynamicShape,
                      KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
                      GpuConvertToDynamicShapeGpuKernel, int16_t)

MS_REG_GPU_KERNEL_ONE(GpuConvertToDynamicShape,
                      KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
                      GpuConvertToDynamicShapeGpuKernel, int32_t)

MS_REG_GPU_KERNEL_ONE(GpuConvertToDynamicShape,
                      KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
                      GpuConvertToDynamicShapeGpuKernel, int64_t)

MS_REG_GPU_KERNEL_ONE(GpuConvertToDynamicShape,
                      KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
                      GpuConvertToDynamicShapeGpuKernel, uint8_t)

MS_REG_GPU_KERNEL_ONE(GpuConvertToDynamicShape,
                      KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
                      GpuConvertToDynamicShapeGpuKernel, uint16_t)

MS_REG_GPU_KERNEL_ONE(GpuConvertToDynamicShape,
                      KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
                      GpuConvertToDynamicShapeGpuKernel, uint32_t)

MS_REG_GPU_KERNEL_ONE(GpuConvertToDynamicShape,
                      KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
                      GpuConvertToDynamicShapeGpuKernel, uint64_t)
}  // namespace kernel
}  // namespace mindspore
@ -0,0 +1,105 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_OTHER_GPU_CONVERT_TO_DYNAMIC_SHAPE_GPU_KERNEL_H
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_OTHER_GPU_CONVERT_TO_DYNAMIC_SHAPE_GPU_KERNEL_H

#include <vector>

#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"

namespace mindspore {
namespace kernel {
template <typename T>
class GpuConvertToDynamicShapeGpuKernel : public GpuKernel {
 public:
  GpuConvertToDynamicShapeGpuKernel() { ResetResource(); }
  ~GpuConvertToDynamicShapeGpuKernel() override = default;

  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    VARIABLE_NOT_USED(workspace);
    T *input_device_address = GetDeviceAddress<T>(inputs, 0);
    T *output_device_address = GetDeviceAddress<T>(outputs, 0);
    cuda_stream_ptr_ = stream_ptr;

    CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output_device_address, input_device_address, input_size_ * sizeof(T),
                                              cudaMemcpyDeviceToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
                              "Failed to copy gpu memory.");

    return true;
  }

  void PostExecute() override {
    CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(cuda_stream_ptr_)),
                               "cudaStreamSynchronize failed");

    std::vector<TypeId> output_types = {AnfAlgo::GetOutputInferDataType(c_node_ptr_, 0)};
    std::vector<std::vector<size_t>> output_shapes = {input_shape_};
    AnfAlgo::SetOutputInferTypeAndShape(output_types, output_shapes, c_node_ptr_.get());
  }

  bool Init(const CNodePtr &kernel_node) override {
    size_t input_count = AnfAlgo::GetInputTensorNum(kernel_node);
    if (input_count != 1) {
      MS_LOG(ERROR) << input_count << " inputs were provided, but GpuConvertToDynamicShapeGpuKernel expects 1.";
      return false;
    }

    input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
    for (const size_t &e : input_shape_) {
      input_size_ *= e;
    }

    c_node_ptr_ = kernel_node;

    InitSizeLists();

    return true;
  }

  void ResetResource() noexcept override {
    c_node_ptr_ = nullptr;
    cuda_stream_ptr_ = nullptr;
    input_shape_.clear();
    input_size_ = 1;
  }

 protected:
  void InitSizeLists() override {
    input_size_list_.push_back(input_size_ * sizeof(T));
    output_size_list_.push_back(input_size_ * sizeof(T));
  }

 private:
  void *cuda_stream_ptr_;
  CNodePtr c_node_ptr_;
  std::vector<size_t> input_shape_;
  size_t input_size_;

  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_OTHER_GPU_CONVERT_TO_DYNAMIC_SHAPE_GPU_KERNEL_H
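At runtime the kernel above is a pure identity: Launch issues an asynchronous device-to-device copy on the given stream, and PostExecute synchronizes that stream before updating the node's inferred shape. A minimal standalone sketch of that copy-then-sync pattern, using only the plain CUDA runtime API (no MindSpore types; buffer names here are illustrative):

#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

int main() {
  const size_t n = 12;
  std::vector<float> host_in(n, 1.5f), host_out(n, 0.0f);

  float *device_in = nullptr, *device_out = nullptr;
  cudaMalloc(&device_in, n * sizeof(float));
  cudaMalloc(&device_out, n * sizeof(float));

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // Stage the input, copy device-to-device asynchronously (what Launch does),
  // then synchronize the stream (what PostExecute does).
  cudaMemcpyAsync(device_in, host_in.data(), n * sizeof(float), cudaMemcpyHostToDevice, stream);
  cudaMemcpyAsync(device_out, device_in, n * sizeof(float), cudaMemcpyDeviceToDevice, stream);
  cudaMemcpyAsync(host_out.data(), device_out, n * sizeof(float), cudaMemcpyDeviceToHost, stream);
  cudaStreamSynchronize(stream);

  printf("out[0] = %f\n", host_out[0]);  // expect 1.5: the value passes through unchanged

  cudaStreamDestroy(stream);
  cudaFree(device_in);
  cudaFree(device_out);
  return 0;
}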
@ -249,6 +249,9 @@ AbstractBasePtr InferImplDivNoNan(const AnalysisEnginePtr &, const PrimitivePtr
                                  const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplExpandDims(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplGpuConvertToDynamicShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                                  const AbstractBasePtrList &args_spec_list);

template <typename T>
AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
  // Inputs: a tuple or list or dict.
@ -526,5 +526,20 @@ AbstractBasePtr InferImplExpandDims(const AnalysisEnginePtr &, const PrimitivePt
  return ret;
}

AbstractBasePtr InferImplGpuConvertToDynamicShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                                  const AbstractBasePtrList &args_spec_list) {
  const std::string &op_name = primitive->name();
  CheckArgsSize(op_name, args_spec_list, 1);
  AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);

  ShapeVector input_shape = input->shape()->shape();
  int32_t input_rank = input_shape.size();
  ShapeVector inferred_shape(input_rank, Shape::SHP_ANY);
  ShapeVector min_shape = {1};
  ShapeVector max_shape = input_shape;

  ShapePtr shape = std::make_shared<Shape>(inferred_shape, min_shape, max_shape);
  return std::make_shared<AbstractTensor>(input->element(), shape);
}
}  // namespace abstract
}  // namespace mindspore
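The effect of this infer rule: every output dimension is marked unknown (Shape::SHP_ANY) while the real input shape is retained as the max shape, which is what makes downstream operators see a dynamically shaped tensor. A self-contained sketch of just the shape computation with plain vectors, assuming SHP_ANY is the usual -1 sentinel:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int64_t kShapeAny = -1;  // assumed value of Shape::SHP_ANY
  std::vector<int64_t> input_shape = {4, 3, 2};

  // Mirror of the logic above: all dims unknown, min shape {1}, max shape = real shape.
  std::vector<int64_t> inferred_shape(input_shape.size(), kShapeAny);  // {-1, -1, -1}
  std::vector<int64_t> min_shape = {1};
  std::vector<int64_t> max_shape = input_shape;                        // {4, 3, 2}

  for (int64_t d : inferred_shape) std::cout << d << ' ';  // prints: -1 -1 -1
  std::cout << '\n';
  return 0;
}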
@ -121,6 +121,8 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
    {prim::kPrimControlDepend, {InferImplControlDepend, true}},
    // Debug
    {prim::kPrimDebug, {InferImplDebug, true}},
    // Dynamic shape testing
    {prim::kPrimGpuConvertToDynamicShape, {InferImplGpuConvertToDynamicShape, true}},
    // SparseTensor
    {prim::kPrimMakeSparseTensor, {InferImplMakeSparseTensor, true}},
    {prim::kPrimSparseTensorGetValues, {InferImplSparseTensorGetValues, true}},
@ -271,6 +271,10 @@ inline const PrimitivePtr kPrimTensorSummary = std::make_shared<Primitive>("Tens
inline const PrimitivePtr kPrimHistogramSummary = std::make_shared<Primitive>("HistogramSummary");
inline const PrimitivePtr kPrimDebug = std::make_shared<Primitive>("Debug");

// Dynamic shape testing
inline const PrimitivePtr kPrimGpuConvertToDynamicShape = std::make_shared<Primitive>("GpuConvertToDynamicShape");
inline const PrimitivePtr kPrimErrorOnDynamicShapeInput = std::make_shared<Primitive>("ErrorOnDynamicShapeInput");

// Other miscellaneous
inline const PrimitivePtr kPrimDepend = std::make_shared<Primitive>("Depend");
inline const PrimitivePtr kPrimPartial = std::make_shared<Primitive>("Partial");
@ -19,7 +19,7 @@ from ..._checkparam import Rel
from ..._checkparam import Validator as validator
from ... import context
from ...common import dtype as mstype
-from ..primitive import PrimitiveWithInfer, prim_attr_register
+from ..primitive import PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register
from ..operations.math_ops import _infer_shape_reduce
@ -666,3 +666,99 @@ class ConfusionMulGrad(PrimitiveWithInfer):
        validator.check_subclass("input1_dtype", input1_dtype, mstype.tensor, self.name)
        validator.check_subclass("input2_dtype", input2_dtype, mstype.tensor, self.name)
        return input0_dtype, input1_dtype


class GpuConvertToDynamicShape(PrimitiveWithCheck):
    """
    This op is used for dynamic shape testing. Its inferred shape is unknown
    at compile time, so its output appears to be dynamically shaped. The input
    is not altered in any way. Put this operator before the operator being
    tested for dynamic shape support.

    Inputs:
        - **input** (Tensor) - The tensor used for testing.

    Outputs:
        - **output** (Tensor) - Same shape, type and value as `input`.

    Examples:
        >>> # make a model, since dynamic shape operators must be in GRAPH_MODE
        >>> class TestDynamicShapeReshapeNet(nn.Cell):
        >>>     def __init__(self):
        >>>         super(TestDynamicShapeReshapeNet, self).__init__()
        >>>         self.convert_to_dynamic_shape = inner.GpuConvertToDynamicShape()
        >>>         # suppose we are testing the Reshape op
        >>>         self.reshape = P.Reshape()
        >>>
        >>>     def construct(self, input, new_shape):
        >>>         dynamic_shape_input = self.convert_to_dynamic_shape(input)
        >>>         return self.reshape(dynamic_shape_input, new_shape)
        >>>
        >>> context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
        >>> input = Tensor(np.array([0, 1, 2, 3]))
        >>> new_shape = (2, 2)
        >>> net = TestDynamicShapeReshapeNet()
        >>> output = net(input, new_shape)
        >>> print(output)
        [[0, 1], [2, 3]]
    """

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=["input"], outputs=["output"])

    def check_shape(self, input_shape):
        validator.check("input_shape rank", len(input_shape), "", 0, Rel.GT, self.name)

    def check_dtype(self, input_dtype):
        validator.check_subclass("input_dtype", input_dtype, mstype.tensor, self.name)


class ErrorOnDynamicShapeInput(PrimitiveWithInfer):
    """
    This op is used for dynamic shape testing. Its only purpose is to raise a
    ValueError if its input is dynamically shaped.

    Inputs:
        - **input** (Tensor) - The tensor used for testing.

    Outputs:
        - **output** (Tensor) - Same shape, type and value as `input`.

    Examples:
        >>> # make a model, since dynamic shape operators must be in GRAPH_MODE
        >>> class AssertDynamicShapeNet(nn.Cell):
        >>>     def __init__(self):
        >>>         super(AssertDynamicShapeNet, self).__init__()
        >>>         self.convert_to_dynamic_shape = inner.GpuConvertToDynamicShape()
        >>>         self.error_on_dynamic_shape_input = inner.ErrorOnDynamicShapeInput()
        >>>
        >>>     def construct(self, input):
        >>>         dynamic_shape_input = self.convert_to_dynamic_shape(input)
        >>>         self.error_on_dynamic_shape_input(dynamic_shape_input)
        >>>
        >>> context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
        >>> input = Tensor(np.array([0]))
        >>> net = AssertDynamicShapeNet()
        >>> output = net(input)
        ValueError: Input is dynamically shaped.
    """

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=["input"], outputs=["output"])

    def infer_shape(self, input_shape):
        shape = list(input_shape)

        for dim in shape:
            if dim == -1:
                raise ValueError("Input is dynamically shaped.")

        return input_shape

    def infer_type(self, input_dtype):
        validator.check_subclass("input_dtype", input_dtype, mstype.tensor, self.name)
        return input_dtype

    def infer_value(self, input_tensor):
        return input_tensor
@ -0,0 +1,58 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import pytest

from mindspore.ops.operations import _inner_ops as inner
import mindspore.context as context


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_error_on_dynamic_shape_input_is_dynamic():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

    error_on_dynamic_shape_input = inner.ErrorOnDynamicShapeInput()

    with pytest.raises(ValueError) as info:
        error_on_dynamic_shape_input.infer_shape([-1])
    assert "Input is dynamically shaped" in str(info.value)

    with pytest.raises(ValueError) as info:
        error_on_dynamic_shape_input.infer_shape([1, 1, -1])
    assert "Input is dynamically shaped" in str(info.value)

    with pytest.raises(ValueError) as info:
        error_on_dynamic_shape_input.infer_shape([-1, 1, 1])
    assert "Input is dynamically shaped" in str(info.value)

    with pytest.raises(ValueError) as info:
        error_on_dynamic_shape_input.infer_shape([1, -1, 1])
    assert "Input is dynamically shaped" in str(info.value)

    with pytest.raises(ValueError) as info:
        error_on_dynamic_shape_input.infer_shape([-1, -1, -1])
    assert "Input is dynamically shaped" in str(info.value)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_error_on_dynamic_shape_input_not_dynamic():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

    error_on_dynamic_shape_input = inner.ErrorOnDynamicShapeInput()
    error_on_dynamic_shape_input([1])
    error_on_dynamic_shape_input([1, 1])
    error_on_dynamic_shape_input([23, 12, 9712])
@ -0,0 +1,152 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

from mindspore import Tensor
from mindspore.ops.operations import _inner_ops as inner
import mindspore.nn as nn
import mindspore.context as context


# test to make sure this op actually generates a dynamically shaped output
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_convert_to_dynamic_shape_confirm_dynamic():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

    class AssertDynamicShapeNet(nn.Cell):
        def __init__(self):
            super(AssertDynamicShapeNet, self).__init__()
            self.gpu_convert_to_dynamic_shape = inner.GpuConvertToDynamicShape()
            self.error_on_dynamic_shape_input = inner.ErrorOnDynamicShapeInput()

        def construct(self, x):
            output = self.gpu_convert_to_dynamic_shape(x)
            self.error_on_dynamic_shape_input(output)
            return output

    assert_dynamic_shape_net = AssertDynamicShapeNet()
    x = Tensor(np.array([0, 0, 0, 0]).astype(np.float32))

    with pytest.raises(ValueError) as info:
        assert_dynamic_shape_net(x)
    assert "Input is dynamically shaped" in str(info.value)


def gpu_convert_to_dynamic_shape(x):
    class GpuConvertToDynamicShapeNet(nn.Cell):
        def __init__(self):
            super(GpuConvertToDynamicShapeNet, self).__init__()
            self.gpu_convert_to_dynamic_shape = inner.GpuConvertToDynamicShape()

        def construct(self, x):
            return self.gpu_convert_to_dynamic_shape(x)

    gpu_convert_to_dynamic_shape_net = GpuConvertToDynamicShapeNet()
    return gpu_convert_to_dynamic_shape_net(Tensor(x)).asnumpy()


def gpu_convert_to_dynamic_shape_float(dtype):
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

    np.random.seed(0)
    finfo = np.finfo(dtype)
    float_min = finfo.min
    float_max = finfo.max
    x = np.random.uniform(low=float_min, high=float_max, size=12).astype(dtype)
    ms_out = gpu_convert_to_dynamic_shape(x)
    np.testing.assert_array_equal(x, ms_out)


def gpu_convert_to_dynamic_shape_int(dtype):
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

    np.random.seed(0)
    iinfo = np.iinfo(dtype)
    int_min = iinfo.min
    int_max = iinfo.max
    x = np.random.uniform(low=int_min, high=int_max, size=12).astype(dtype)
    ms_out = gpu_convert_to_dynamic_shape(x)
    np.testing.assert_array_equal(x, ms_out)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_convert_to_dynamic_shape_bool():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

    np.random.seed(0)
    x = np.random.choice([False, True], 12)
    ms_out = gpu_convert_to_dynamic_shape(x)
    np.testing.assert_array_equal(x, ms_out)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_convert_to_dynamic_shape_float16():
    gpu_convert_to_dynamic_shape_float(np.float16)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_convert_to_dynamic_shape_float32():
    gpu_convert_to_dynamic_shape_float(np.float32)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_convert_to_dynamic_shape_int8():
    gpu_convert_to_dynamic_shape_int(np.int8)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_convert_to_dynamic_shape_int16():
    gpu_convert_to_dynamic_shape_int(np.int16)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_convert_to_dynamic_shape_int32():
    gpu_convert_to_dynamic_shape_int(np.int32)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_convert_to_dynamic_shape_int64():
    gpu_convert_to_dynamic_shape_int(np.int64)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_convert_to_dynamic_shape_uint8():
    gpu_convert_to_dynamic_shape_int(np.uint8)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_convert_to_dynamic_shape_uint16():
    gpu_convert_to_dynamic_shape_int(np.uint16)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_convert_to_dynamic_shape_uint32():
    gpu_convert_to_dynamic_shape_int(np.uint32)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_convert_to_dynamic_shape_uint64():
    gpu_convert_to_dynamic_shape_int(np.uint64)