!8861 [MS][GPU][DynamicShapeUpdate] Converting UnsortedSegmentMax from normal to dynamic shape op

From: @danishnxt
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2020-11-28 06:49:39 +08:00 committed by Gitee
commit 442217314d
7 changed files with 238 additions and 31 deletions

View File

@ -22,15 +22,35 @@ MS_REG_GPU_KERNEL_ONE(
UnsortedSegmentMax,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
UnsortedSegmentMaxGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
UnsortedSegmentMax,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
UnsortedSegmentMaxGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
UnsortedSegmentMax,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
UnsortedSegmentMaxGpuKernel, int)
// Dynamic Mode
MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMax,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
UnsortedSegmentMaxGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMax,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat16),
UnsortedSegmentMaxGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMax,
KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt32),
UnsortedSegmentMaxGpuKernel, int)
} // namespace kernel
} // namespace mindspore

View File

@ -28,14 +28,7 @@ namespace kernel {
template <typename T>
class UnsortedSegmentMaxGpuKernel : public GpuKernel {
public:
UnsortedSegmentMaxGpuKernel()
: num_segments_(1),
inner_size_(1),
outer_size_(1),
input_size_(1),
segment_ids_size_(1),
output_size_(1),
is_null_input_(false) {}
UnsortedSegmentMaxGpuKernel() { ResetResource(); }
~UnsortedSegmentMaxGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@ -60,18 +53,24 @@ class UnsortedSegmentMaxGpuKernel : public GpuKernel {
}
bool Init(const CNodePtr &kernel_node) override {
auto input_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
auto input_shapes = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0);
is_null_input_ = CHECK_NULL_INPUT(input_shapes);
if (is_null_input_) {
MS_LOG(WARNING) << "UnsortedSegmentMax input is null";
InitSizeLists();
return true;
}
auto segment_ids_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
auto output_shapes = AnfAlgo::GetOutputInferShape(kernel_node, 0); // we get that from computation
auto segment_ids_shapes = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 1);
auto output_shapes = AnfAlgo::GetOutputRealDeviceShapeIfExist(kernel_node, 0);
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num == 3) {
MS_LOG(INFO) << "UnsortedSegmentMax Kernel Input count is 3 - dynamic mode";
} else {
MS_LOG(INFO) << "UnsortedSegmentMax Kernel Input count is 2";
}
num_segments_ = output_shapes[0];
input_size_ = 1;
for (size_t i = 0; i < input_shapes.size(); i++) {
input_size_ *= input_shapes[i];
@ -97,6 +96,19 @@ class UnsortedSegmentMaxGpuKernel : public GpuKernel {
return true;
}
void ResetResource() noexcept override {
num_segments_ = 1;
inner_size_ = 1;
outer_size_ = 1;
input_size_ = 1;
segment_ids_size_ = 1;
output_size_ = 1;
is_null_input_ = false;
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(input_size_ * sizeof(T));

View File

@ -113,6 +113,8 @@ AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplUnsortedSegmentSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplUnsortedSegmentMax(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMakeSlice(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMakeKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive,

View File

@ -273,6 +273,74 @@ AbstractBasePtr InferImplUnsortedSegmentSum(const AnalysisEnginePtr &, const Pri
return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
}
AbstractBasePtr InferImplUnsortedSegmentMax(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 3);
auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
MS_EXCEPTION_IF_NULL(x);
MS_EXCEPTION_IF_NULL(x->shape());
auto segment_ids = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
MS_EXCEPTION_IF_NULL(segment_ids);
MS_EXCEPTION_IF_NULL(segment_ids->shape());
auto segment_ids_shape = segment_ids->shape()->shape();
(void)CheckTensorDType(x, {kFloat16, kFloat32, kInt32}, "Input 0 (x) for UnsortedSegmentMax should be %s");
(void)CheckTensorDType(segment_ids, {kInt32}, "Input 1 (segment_ids) for UnsortedSegmentMax should be %s");
// check if dynamic shape
bool x_is_dyn = (!x->shape()->min_shape().empty() && !x->shape()->max_shape().empty());
bool ids_is_dyn = (!segment_ids->shape()->min_shape().empty() && !segment_ids->shape()->max_shape().empty());
bool op_is_dynamic = x_is_dyn && ids_is_dyn;
auto x_shape = x->shape()->shape();
ShapeVector shape;
int64_t num_segments_value = 0;
if (args_spec_list[2]->isa<AbstractTensor>()) { // num_segments is Tensor
auto num_segments = args_spec_list[2]->cast<AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(num_segments);
auto num_segments_value_ptr = num_segments->BuildValue();
MS_EXCEPTION_IF_NULL(num_segments_value_ptr);
auto num_segments_tensor = num_segments_value_ptr->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(num_segments_tensor);
num_segments_value = *static_cast<int64_t *>(num_segments_tensor->data_c());
} else if (args_spec_list[2]->isa<AbstractScalar>()) { // num_segments is Scalar
auto num_segments = CheckArg<AbstractScalar>(op_name, args_spec_list, 2);
num_segments_value = GetValue<int64_t>(num_segments->BuildValue());
} else {
MS_LOG(EXCEPTION) << "num_segments incorrect type in UnsortedSegmentMax";
}
if (num_segments_value <= 0) {
MS_LOG(EXCEPTION) << "num_segments must be > 0 in UnsortedSegmentMax";
}
shape.emplace_back(num_segments_value);
shape.insert(shape.end(), x_shape.begin() + segment_ids_shape.size(), x_shape.end());
if (!op_is_dynamic) {
if (x_shape[0] != segment_ids_shape[0]) {
MS_LOG(EXCEPTION) << "Length of segment_ids must match first value of x shape UnsortedSegmentMax";
}
return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape));
}
// is dynamic
ShapeVector min_shape;
ShapeVector max_shape;
min_shape.emplace_back(num_segments_value);
max_shape.emplace_back(num_segments_value);
// only run validation if shape values are known
bool x_any_shape = std::any_of(x_shape.begin(), x_shape.end(), [](int64_t dim) { return dim == Shape::SHP_ANY; });
bool ids_any_shape =
std::any_of(segment_ids_shape.begin(), segment_ids_shape.end(), [](int64_t dim) { return dim == Shape::SHP_ANY; });
if (!x_any_shape && !ids_any_shape) {
if (x_shape[0] != segment_ids_shape[0]) {
MS_LOG(EXCEPTION) << "Length of segment_ids must match first value of x shape UnsortedSegmentMax";
}
}
ShapeVector x_shape_min;
ShapeVector x_shape_max;
x_shape_min = (x_is_dyn) ? x->shape()->min_shape() : x->shape()->shape();
x_shape_max = (x_is_dyn) ? x->shape()->max_shape() : x->shape()->shape();
min_shape.insert(min_shape.end(), x_shape_min.begin() + segment_ids_shape.size(), x_shape_min.end());
max_shape.insert(max_shape.end(), x_shape_max.begin() + segment_ids_shape.size(), x_shape_max.end());
return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
}
AbstractBasePtr InferImplScatterAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string op_name = primitive->name();

View File

@ -58,6 +58,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimSparseGatherV2, {InferImplGatherV2, true}},
{prim::kPrimEmbeddingLookup, {InferImplEmbeddingLookup, true}},
{prim::kPrimUnsortedSegmentSum, {InferImplUnsortedSegmentSum, true}},
{prim::kPrimUnsortedSegmentMax, {InferImplUnsortedSegmentMax, true}},
{prim::kPrimScatterAdd, {InferImplScatterAdd, true}},
{prim::kPrimScatterUpdate, {InferImplScatterUpdate, true}},
{prim::kPrimMapCacheIdx, {InferImplMapCacheIdx, true}},

View File

@ -1981,7 +1981,7 @@ class UnsortedSegmentMin(PrimitiveWithInfer):
return out
class UnsortedSegmentMax(PrimitiveWithInfer):
class UnsortedSegmentMax(PrimitiveWithCheck):
"""
Computes the maximum along segments of a tensor.
@ -2017,27 +2017,21 @@ class UnsortedSegmentMax(PrimitiveWithInfer):
def __init__(self):
"""Initialize UnsortedSegmentMax"""
self.init_prim_io_names(inputs=['x', 'segment_ids', 'num_segments'], outputs=['y'])
self.add_prim_attr("dynamic_shape_depends", [2])
def __infer__(self, x, segment_ids, num_segments):
x_type = x['dtype']
x_shape = x['shape']
def __check__(self, x, segment_ids, num_segments):
segment_ids_shape = segment_ids['shape']
valid_type = [mstype.float16, mstype.float32, mstype.int32]
validator.check_tensor_dtype_valid("x", x['dtype'], valid_type, self.name)
validator.check_tensors_dtypes_same_and_valid({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name)
validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name)
validator.check(f'first shape of input_x', x_shape[0],
'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name)
num_segments_v = num_segments['value']
validator.check_value_type('num_segments', num_segments_v, [int], self.name)
validator.check_positive_int(num_segments_v, "num_segments", self.name)
segment_ids_shape_len = len(segment_ids_shape)
out_shape = [num_segments_v]
out_shape += x_shape[segment_ids_shape_len:]
out = {'shape': out_shape,
'dtype': x_type,
'value': None}
return out
num_segments_type = num_segments['dtype']
validator.check_subclass("num_segments", num_segments_type, [mstype.tensor, mstype.number], self.name)
if isinstance(num_segments_type, type(mstype.tensor)):
validator.check_tensor_dtype_valid("num_segments", num_segments_type, [mstype.int64],
self.name)
else:
validator.check_value_type('num_segments', num_segments['value'], [int], self.name)
class UnsortedSegmentProd(PrimitiveWithInfer):

View File

@ -20,6 +20,7 @@ import mindspore
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops.operations import _inner_ops as inner
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
@ -139,7 +140,7 @@ def test_3d_float32():
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_3d_single_init():
context.set_context(device_target='GPU')
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
input_x = Tensor(np.arange(
4 * 5 * 3, dtype=np.float32).reshape(4, 5, 3), dtype=mindspore.float32)
segment_ids = Tensor([3, 0, 1, -1], mstype.int32)
@ -202,3 +203,112 @@ def test_3d_single_init():
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38]]]).astype(np.float32)
np.testing.assert_array_almost_equal(output, expect)
# For testing Dynamic Shape operation
class UnsortedSegmentMaxDynNet(nn.Cell):
def __init__(self, num_segments):
super(UnsortedSegmentMaxDynNet, self).__init__()
self.unsorted_segment_max = P.UnsortedSegmentMax()
self.gpu_convert_to_dynamic_shape = inner.GpuConvertToDynamicShape()
self.num_segments = num_segments
def construct(self, data, ids):
dyn_data = self.gpu_convert_to_dynamic_shape(data)
dyn_ids = self.gpu_convert_to_dynamic_shape(ids)
return self.unsorted_segment_max(dyn_data, dyn_ids, self.num_segments)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_3d_float32_dyn():
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
input_x = Tensor(np.arange(
4 * 5 * 3, dtype=np.float32).reshape(4, 5, 3), dtype=mindspore.float32)
segment_ids = Tensor([2, 1, 1, -1], mstype.int32)
num_segments = 3
net = UnsortedSegmentMaxDynNet(num_segments)
output = net(input_x, segment_ids).asnumpy()
expect = np.array([[[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38]],
[[3.0000000e+01, 3.1000000e+01, 3.2000000e+01],
[3.3000000e+01, 3.4000000e+01, 3.5000000e+01],
[3.6000000e+01, 3.7000000e+01, 3.8000000e+01],
[3.9000000e+01, 4.0000000e+01, 4.1000000e+01],
[4.2000000e+01, 4.3000000e+01, 4.4000000e+01]],
[[0.0000000e+00, 1.0000000e+00, 2.0000000e+00],
[3.0000000e+00, 4.0000000e+00, 5.0000000e+00],
[6.0000000e+00, 7.0000000e+00, 8.0000000e+00],
[9.0000000e+00, 1.0000000e+01, 1.1000000e+01],
[1.2000000e+01, 1.3000000e+01, 1.4000000e+01]]]).astype(np.float32)
np.testing.assert_array_almost_equal(output, expect)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_3d_single_init_dyn():
context.set_context(device_target='GPU')
input_x = Tensor(np.arange(
4 * 5 * 3, dtype=np.float32).reshape(4, 5, 3), dtype=mindspore.float32)
segment_ids = Tensor([3, 0, 1, -1], mstype.int32)
num_segments = 4
net = UnsortedSegmentMaxDynNet(num_segments)
output = net(input_x, segment_ids).asnumpy()
expect = np.array([[[1.5000000e+01, 1.6000000e+01, 1.7000000e+01],
[1.8000000e+01, 1.9000000e+01, 2.0000000e+01],
[2.1000000e+01, 2.2000000e+01, 2.3000000e+01],
[2.4000000e+01, 2.5000000e+01, 2.6000000e+01],
[2.7000000e+01, 2.8000000e+01, 2.9000000e+01]],
[[3.0000000e+01, 3.1000000e+01, 3.2000000e+01],
[3.3000000e+01, 3.4000000e+01, 3.5000000e+01],
[3.6000000e+01, 3.7000000e+01, 3.8000000e+01],
[3.9000000e+01, 4.0000000e+01, 4.1000000e+01],
[4.2000000e+01, 4.3000000e+01, 4.4000000e+01]],
[[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38]],
[[0.0000000e+00, 1.0000000e+00, 2.0000000e+00],
[3.0000000e+00, 4.0000000e+00, 5.0000000e+00],
[6.0000000e+00, 7.0000000e+00, 8.0000000e+00],
[9.0000000e+00, 1.0000000e+01, 1.1000000e+01],
[1.2000000e+01, 1.3000000e+01, 1.4000000e+01]]]).astype(np.float32)
np.testing.assert_array_almost_equal(output, expect)
num_segments = 6
net = UnsortedSegmentMaxDynNet(num_segments)
output = net(input_x, segment_ids).asnumpy()
expect = np.array([[[1.5000000e+01, 1.6000000e+01, 1.7000000e+01],
[1.8000000e+01, 1.9000000e+01, 2.0000000e+01],
[2.1000000e+01, 2.2000000e+01, 2.3000000e+01],
[2.4000000e+01, 2.5000000e+01, 2.6000000e+01],
[2.7000000e+01, 2.8000000e+01, 2.9000000e+01]],
[[3.0000000e+01, 3.1000000e+01, 3.2000000e+01],
[3.3000000e+01, 3.4000000e+01, 3.5000000e+01],
[3.6000000e+01, 3.7000000e+01, 3.8000000e+01],
[3.9000000e+01, 4.0000000e+01, 4.1000000e+01],
[4.2000000e+01, 4.3000000e+01, 4.4000000e+01]],
[[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38]],
[[0.0000000e+00, 1.0000000e+00, 2.0000000e+00],
[3.0000000e+00, 4.0000000e+00, 5.0000000e+00],
[6.0000000e+00, 7.0000000e+00, 8.0000000e+00],
[9.0000000e+00, 1.0000000e+01, 1.1000000e+01],
[1.2000000e+01, 1.3000000e+01, 1.4000000e+01]],
[[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38]],
[[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38]]]).astype(np.float32)
np.testing.assert_array_almost_equal(output, expect)