!9252 [MS][GPU][DynamicShape] int32 scalar support for UnsortedSegMax/Min + CI alarm fix + UnsortedSegmentSum validation fix (detailed)
From: @danishnxt Reviewed-by: @robingrosman,@tom__chen Signed-off-by: @tom__chen
This commit is contained in:
commit
f6450a614b
|
@ -30,7 +30,14 @@ MS_REG_GPU_KERNEL_ONE(
|
|||
UnsortedSegmentMax,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
UnsortedSegmentMaxGpuKernel, int)
|
||||
// Dynamic Mode
|
||||
// Dynamic Mode - registered for int32/int64 3rd input
|
||||
MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMax,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
UnsortedSegmentMaxGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMax,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
|
@ -38,6 +45,13 @@ MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMax,
|
|||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
UnsortedSegmentMaxGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMax,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
UnsortedSegmentMaxGpuKernel, half)
|
||||
MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMax,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
|
@ -45,6 +59,13 @@ MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMax,
|
|||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
UnsortedSegmentMaxGpuKernel, half)
|
||||
MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMax,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeInt32),
|
||||
UnsortedSegmentMaxGpuKernel, int)
|
||||
MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMax,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
|
|
|
@ -69,7 +69,11 @@ class UnsortedSegmentMaxGpuKernel : public GpuKernel {
|
|||
} else {
|
||||
MS_LOG(INFO) << "UnsortedSegmentMax Kernel Input count is 2";
|
||||
}
|
||||
|
||||
auto value_count = AnfAlgo::GetOutputRealDeviceShapeIfExist(kernel_node, 0);
|
||||
if (value_count.size() != 1) {
|
||||
MS_LOG(ERROR) << "For UnsortedSegmentMax, output shape incorrect rank. Expect Rank: 1, got Rank: "
|
||||
<< value_count.size() << ".";
|
||||
}
|
||||
num_segments_ = output_shapes[0];
|
||||
input_size_ = 1;
|
||||
for (size_t i = 0; i < input_shapes.size(); i++) {
|
||||
|
@ -117,7 +121,7 @@ class UnsortedSegmentMaxGpuKernel : public GpuKernel {
|
|||
}
|
||||
|
||||
private:
|
||||
int num_segments_;
|
||||
int64_t num_segments_;
|
||||
size_t inner_size_;
|
||||
size_t outer_size_;
|
||||
size_t input_size_;
|
||||
|
|
|
@ -30,7 +30,14 @@ MS_REG_GPU_KERNEL_ONE(
|
|||
UnsortedSegmentMin,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
UnsortedSegmentMinGpuKernel, int)
|
||||
// Dynamic Mode
|
||||
// Dynamic Mode - registered for int32/int64 3rd input
|
||||
MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMin,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
UnsortedSegmentMinGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMin,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
|
@ -38,6 +45,13 @@ MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMin,
|
|||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
UnsortedSegmentMinGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMin,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
UnsortedSegmentMinGpuKernel, half)
|
||||
MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMin,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
|
@ -45,6 +59,13 @@ MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMin,
|
|||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
UnsortedSegmentMinGpuKernel, half)
|
||||
MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMin,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeInt32),
|
||||
UnsortedSegmentMinGpuKernel, int)
|
||||
MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMin,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
|
|
|
@ -65,7 +65,11 @@ class UnsortedSegmentMinGpuKernel : public GpuKernel {
|
|||
} else {
|
||||
MS_LOG(INFO) << "UnsortedSegmentMin Kernel Input count is 2";
|
||||
}
|
||||
|
||||
auto value_count = AnfAlgo::GetOutputRealDeviceShapeIfExist(kernel_node, 0);
|
||||
if (value_count.size() != 1) {
|
||||
MS_LOG(ERROR) << "For UnsortedSegmentMin, output shape incorrect rank. Expect Rank: 1, got Rank: "
|
||||
<< value_count.size() << ".";
|
||||
}
|
||||
num_segments_ = output_shapes[0];
|
||||
input_size_ = 1;
|
||||
for (size_t i = 0; i < input_shapes.size(); i++) {
|
||||
|
@ -113,7 +117,7 @@ class UnsortedSegmentMinGpuKernel : public GpuKernel {
|
|||
}
|
||||
|
||||
private:
|
||||
int num_segments_;
|
||||
int64_t num_segments_;
|
||||
size_t inner_size_;
|
||||
size_t outer_size_;
|
||||
size_t input_size_;
|
||||
|
|
|
@ -18,8 +18,8 @@
|
|||
#include <limits>
|
||||
|
||||
template <typename T>
|
||||
__global__ void UnsortedSegmentMax(const T *input, const int *segment_ids, const int num_segments, size_t outer_size,
|
||||
size_t inner_size, bool fp16_flag, T init_K, T *output) {
|
||||
__global__ void UnsortedSegmentMax(const T *input, const int *segment_ids, const int64_t num_segments,
|
||||
size_t outer_size, size_t inner_size, bool fp16_flag, T init_K, T *output) {
|
||||
if (fp16_flag) {
|
||||
init_K = __int2half_rd(-65504); // min value representable by float16
|
||||
}
|
||||
|
@ -57,7 +57,7 @@ __global__ void UnsortedSegmentMax(const T *input, const int *segment_ids, const
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void CalUnsortedSegmentMax(const T *input, const int *segment_ids, const int num_segments, size_t outer_size,
|
||||
void CalUnsortedSegmentMax(const T *input, const int *segment_ids, const int64_t num_segments, size_t outer_size,
|
||||
size_t inner_size, T *output, cudaStream_t stream) {
|
||||
int size = (inner_size * KWARPSIZE * num_segments);
|
||||
bool fp16_flag = false;
|
||||
|
@ -71,9 +71,9 @@ void CalUnsortedSegmentMax(const T *input, const int *segment_ids, const int num
|
|||
return;
|
||||
}
|
||||
|
||||
template void CalUnsortedSegmentMax<float>(const float *input, const int *segment_ids, const int num_segments,
|
||||
template void CalUnsortedSegmentMax<float>(const float *input, const int *segment_ids, const int64_t num_segments,
|
||||
size_t outer_size, size_t inner_size, float *output, cudaStream_t stream);
|
||||
template void CalUnsortedSegmentMax<half>(const half *input, const int *segment_ids, const int num_segments,
|
||||
template void CalUnsortedSegmentMax<half>(const half *input, const int *segment_ids, const int64_t num_segments,
|
||||
size_t outer_size, size_t inner_size, half *output, cudaStream_t stream);
|
||||
template void CalUnsortedSegmentMax<int>(const int *input, const int *segment_ids, const int num_segments,
|
||||
template void CalUnsortedSegmentMax<int>(const int *input, const int *segment_ids, const int64_t num_segments,
|
||||
size_t outer_size, size_t inner_size, int *output, cudaStream_t stream);
|
||||
|
|
|
@ -22,9 +22,8 @@
|
|||
|
||||
// Setting warp size to sync data across threads
|
||||
#define KWARPSIZE 32
|
||||
|
||||
template <typename T>
|
||||
void CalUnsortedSegmentMax(const T *input, const int *segment_ids, const int num_segments, size_t outer_size,
|
||||
void CalUnsortedSegmentMax(const T *input, const int *segment_ids, const int64_t num_segments, size_t outer_size,
|
||||
size_t inner_size, T *output, cudaStream_t stream);
|
||||
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNSORT_SEGMENT_MAX_H_
|
||||
|
|
|
@ -17,19 +17,19 @@
|
|||
#include "backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_min.cuh"
|
||||
#include <limits>
|
||||
|
||||
template<typename T>
|
||||
template <typename T>
|
||||
__device__ __forceinline__ void max_val_init(T *init_val) {
|
||||
*init_val = std::numeric_limits<T>::max();
|
||||
}
|
||||
// Handle fp16 differently for assignment
|
||||
template<>
|
||||
template <>
|
||||
__device__ __forceinline__ void max_val_init(half *init_val) {
|
||||
*init_val = __int2half_rd(65504); // Max value for Half
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void UnsortedSegmentMin(const T *input, const int *segment_ids, const int num_segments, size_t outer_size,
|
||||
size_t inner_size, T init_K, T *output) {
|
||||
__global__ void UnsortedSegmentMin(const T *input, const int *segment_ids, const int64_t num_segments,
|
||||
size_t outer_size, size_t inner_size, T init_K, T *output) {
|
||||
max_val_init(&init_K);
|
||||
for (int t_idx = blockIdx.x * blockDim.x + threadIdx.x; t_idx < KWARPSIZE * num_segments * inner_size;
|
||||
t_idx += blockDim.x * gridDim.x) {
|
||||
|
@ -62,18 +62,18 @@ __global__ void UnsortedSegmentMin(const T *input, const int *segment_ids, const
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void CalUnsortedSegmentMin(const T *input, const int *segment_ids, const int num_segments, size_t outer_size,
|
||||
void CalUnsortedSegmentMin(const T *input, const int *segment_ids, const int64_t num_segments, size_t outer_size,
|
||||
size_t inner_size, T *output, cudaStream_t stream) {
|
||||
int size = (inner_size * KWARPSIZE * num_segments);
|
||||
T init_K = std::numeric_limits<T>::lowest(); // only init here - overwritten later
|
||||
T init_K = std::numeric_limits<T>::lowest(); // only init here - overwritten later
|
||||
UnsortedSegmentMin<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(input, segment_ids, num_segments, outer_size,
|
||||
inner_size, init_K, output);
|
||||
return;
|
||||
}
|
||||
|
||||
template void CalUnsortedSegmentMin<float>(const float *input, const int *segment_ids, const int num_segments,
|
||||
template void CalUnsortedSegmentMin<float>(const float *input, const int *segment_ids, const int64_t num_segments,
|
||||
size_t outer_size, size_t inner_size, float *output, cudaStream_t stream);
|
||||
template void CalUnsortedSegmentMin<half>(const half *input, const int *segment_ids, const int num_segments,
|
||||
template void CalUnsortedSegmentMin<half>(const half *input, const int *segment_ids, const int64_t num_segments,
|
||||
size_t outer_size, size_t inner_size, half *output, cudaStream_t stream);
|
||||
template void CalUnsortedSegmentMin<int>(const int *input, const int *segment_ids, const int num_segments,
|
||||
template void CalUnsortedSegmentMin<int>(const int *input, const int *segment_ids, const int64_t num_segments,
|
||||
size_t outer_size, size_t inner_size, int *output, cudaStream_t stream);
|
||||
|
|
|
@ -23,6 +23,6 @@
|
|||
// Setting warp size to sync data across threads
|
||||
#define KWARPSIZE 32
|
||||
template <typename T>
|
||||
void CalUnsortedSegmentMin(const T *input, const int *segment_ids, const int num_segments, size_t outer_size,
|
||||
void CalUnsortedSegmentMin(const T *input, const int *segment_ids, const int64_t num_segments, size_t outer_size,
|
||||
size_t inner_size, T *output, cudaStream_t stream);
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNSORT_SEGMENT_MIN_H_
|
||||
|
|
|
@ -29,7 +29,7 @@ namespace kernel {
|
|||
template <typename T>
|
||||
class PadGpuFwdKernel : public GpuKernel {
|
||||
public:
|
||||
PadGpuFwdKernel() : shape_size_(0), temp(0), input_size_(0), output_size_(0), workspace_size_(0) {}
|
||||
PadGpuFwdKernel() : shape_size_(0), temp(0), input_size_(1), output_size_(1), workspace_size_(0) {}
|
||||
~PadGpuFwdKernel() override = default;
|
||||
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
|
@ -53,13 +53,11 @@ class PadGpuFwdKernel : public GpuKernel {
|
|||
}
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
// check number of inputs -> should be 1
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != 1) {
|
||||
MS_LOG(ERROR) << "Input number is " << input_num << ", but Pad needs 1 input.";
|
||||
return false;
|
||||
}
|
||||
// check number of output -> should be 1
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
if (output_num != 1) {
|
||||
MS_LOG(ERROR) << "Output number is " << output_num << ", but Pad needs 1 output.";
|
||||
|
@ -67,8 +65,7 @@ class PadGpuFwdKernel : public GpuKernel {
|
|||
}
|
||||
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
shape_size_ = input_shape.size();
|
||||
// shape adjustement -> from 2d/3d to 4d to standardize
|
||||
if (shape_size_ == 4) {
|
||||
if (shape_size_ == 4) { // shape adjustement from 2d/3d to 4d
|
||||
} else if (shape_size_ == 3) {
|
||||
auto it = input_shape.begin();
|
||||
input_shape.insert(it, 1); // batch padding
|
||||
|
@ -87,8 +84,7 @@ class PadGpuFwdKernel : public GpuKernel {
|
|||
[](const int64_t &value) { return static_cast<int>(value); });
|
||||
return shape;
|
||||
});
|
||||
// shape adjustement -> from 2d/3d to 4d to standardize
|
||||
if (paddings.size() == 4) {
|
||||
if (paddings.size() == 4) { // shape adjustement from 2d/3d to 4d
|
||||
} else if (paddings.size() == 3) {
|
||||
auto it = paddings.begin();
|
||||
paddings.insert(it, 1, {0, 0}); // batch padding
|
||||
|
@ -96,13 +92,11 @@ class PadGpuFwdKernel : public GpuKernel {
|
|||
auto it = paddings.begin();
|
||||
paddings.insert(it, 2, {0, 0}); // channel padding
|
||||
}
|
||||
input_size_ = 1;
|
||||
for (size_t i = 0; i < shape_size_; i++) {
|
||||
input_size_ *= input_shape[i];
|
||||
input_shape_.push_back(input_shape[i]);
|
||||
}
|
||||
input_size_ *= sizeof(T);
|
||||
output_size_ = 1;
|
||||
for (size_t i = 0; i < shape_size_; i++) {
|
||||
temp = input_shape[i] + (paddings[i][0] + paddings[i][1]); // compute new dim size
|
||||
output_size_ *= temp;
|
||||
|
|
|
@ -227,10 +227,18 @@ AbstractBasePtr InferImplUnsortedSegmentSum(const AnalysisEnginePtr &, const Pri
|
|||
MS_EXCEPTION_IF_NULL(num_segments_value_ptr);
|
||||
auto num_segments_tensor = num_segments_value_ptr->cast<tensor::TensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(num_segments_tensor);
|
||||
num_segments_value = *static_cast<int64_t *>(num_segments_tensor->data_c());
|
||||
if (num_segments->element()->GetTypeTrack()->type_id() == TypeId::kNumberTypeInt64) {
|
||||
num_segments_value = *static_cast<int64_t *>(num_segments_tensor->data_c());
|
||||
} else {
|
||||
num_segments_value = *static_cast<int32_t *>(num_segments_tensor->data_c());
|
||||
}
|
||||
} else if (args_spec_list[2]->isa<AbstractScalar>()) { // num_segments is Scalar
|
||||
auto num_segments = CheckArg<AbstractScalar>(op_name, args_spec_list, 2);
|
||||
num_segments_value = GetValue<int64_t>(num_segments->BuildValue());
|
||||
if (num_segments->GetTypeTrack()->type_id() == TypeId::kNumberTypeInt64) {
|
||||
num_segments_value = GetValue<int64_t>(num_segments->BuildValue());
|
||||
} else {
|
||||
num_segments_value = GetValue<int32_t>(num_segments->BuildValue());
|
||||
}
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "num_segments incorrect type in UnsortedSegmentSum";
|
||||
}
|
||||
|
@ -300,10 +308,19 @@ AbstractBasePtr InferImplUnsortedSegmentMax(const AnalysisEnginePtr &, const Pri
|
|||
MS_EXCEPTION_IF_NULL(num_segments_value_ptr);
|
||||
auto num_segments_tensor = num_segments_value_ptr->cast<tensor::TensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(num_segments_tensor);
|
||||
num_segments_value = *static_cast<int64_t *>(num_segments_tensor->data_c());
|
||||
if (num_segments->element()->GetTypeTrack()->type_id() == TypeId::kNumberTypeInt64) {
|
||||
num_segments_value = *static_cast<int64_t *>(num_segments_tensor->data_c());
|
||||
} else {
|
||||
num_segments_value = *static_cast<int32_t *>(num_segments_tensor->data_c());
|
||||
}
|
||||
// num_segments_value = *static_cast<int64_t *>(num_segments_tensor->data_c());
|
||||
} else if (args_spec_list[2]->isa<AbstractScalar>()) { // num_segments is Scalar
|
||||
auto num_segments = CheckArg<AbstractScalar>(op_name, args_spec_list, 2);
|
||||
num_segments_value = GetValue<int64_t>(num_segments->BuildValue());
|
||||
if (num_segments->GetTypeTrack()->type_id() == TypeId::kNumberTypeInt64) {
|
||||
num_segments_value = GetValue<int64_t>(num_segments->BuildValue());
|
||||
} else {
|
||||
num_segments_value = GetValue<int32_t>(num_segments->BuildValue());
|
||||
}
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "num_segments incorrect type in UnsortedSegmentMax";
|
||||
}
|
||||
|
@ -368,10 +385,18 @@ AbstractBasePtr InferImplUnsortedSegmentMin(const AnalysisEnginePtr &, const Pri
|
|||
MS_EXCEPTION_IF_NULL(num_segments_value_ptr);
|
||||
auto num_segments_tensor = num_segments_value_ptr->cast<tensor::TensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(num_segments_tensor);
|
||||
num_segments_value = *static_cast<int64_t *>(num_segments_tensor->data_c());
|
||||
if (num_segments->element()->GetTypeTrack()->type_id() == TypeId::kNumberTypeInt64) {
|
||||
num_segments_value = *static_cast<int64_t *>(num_segments_tensor->data_c());
|
||||
} else {
|
||||
num_segments_value = *static_cast<int32_t *>(num_segments_tensor->data_c());
|
||||
}
|
||||
} else if (args_spec_list[2]->isa<AbstractScalar>()) { // num_segments is Scalar
|
||||
auto num_segments = CheckArg<AbstractScalar>(op_name, args_spec_list, 2);
|
||||
num_segments_value = GetValue<int64_t>(num_segments->BuildValue());
|
||||
if (num_segments->GetTypeTrack()->type_id() == TypeId::kNumberTypeInt64) {
|
||||
num_segments_value = GetValue<int64_t>(num_segments->BuildValue());
|
||||
} else {
|
||||
num_segments_value = GetValue<int32_t>(num_segments->BuildValue());
|
||||
}
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "num_segments incorrect type in UnsortedSegmentMin";
|
||||
}
|
||||
|
|
|
@ -1850,8 +1850,10 @@ class UnsortedSegmentSum(PrimitiveWithInfer):
|
|||
validator.check_positive_int(segment_ids_shp_len, "rank of segment_ids", self.name)
|
||||
validator.check(f'rank of input_x', len(x_shp),
|
||||
'rank of segments_id', len(segment_ids_shp), Rel.GE, self.name)
|
||||
for i, value in enumerate(segment_ids_shp):
|
||||
validator.check("ids[%d]" % i, value, 'input[%d]' % i, x_shp[i], Rel.EQ, self.name)
|
||||
if (not -1 in x_shp and not -1 in segment_ids_shp):
|
||||
# only validate when both shapes fully known
|
||||
for i, value in enumerate(segment_ids_shp):
|
||||
validator.check("ids[%d]" % i, value, 'input[%d]' % i, x_shp[i], Rel.EQ, self.name)
|
||||
num_segments_v = num_segments['value']
|
||||
num_segments_type = num_segments['dtype']
|
||||
validator.check_subclass("num_segments", num_segments_type, [mstype.tensor, mstype.number], self.name)
|
||||
|
@ -1925,7 +1927,7 @@ class UnsortedSegmentMin(PrimitiveWithCheck):
|
|||
num_segments_type = num_segments['dtype']
|
||||
validator.check_subclass("num_segments", num_segments_type, [mstype.tensor, mstype.number], self.name)
|
||||
if isinstance(num_segments_type, type(mstype.tensor)):
|
||||
validator.check_tensor_dtype_valid("num_segments", num_segments_type, [mstype.int64],
|
||||
validator.check_tensor_dtype_valid("num_segments", num_segments_type, [mstype.int32, mstype.int64],
|
||||
self.name)
|
||||
else:
|
||||
validator.check_value_type('num_segments', num_segments['value'], [int], self.name)
|
||||
|
@ -1978,7 +1980,7 @@ class UnsortedSegmentMax(PrimitiveWithCheck):
|
|||
num_segments_type = num_segments['dtype']
|
||||
validator.check_subclass("num_segments", num_segments_type, [mstype.tensor, mstype.number], self.name)
|
||||
if isinstance(num_segments_type, type(mstype.tensor)):
|
||||
validator.check_tensor_dtype_valid("num_segments", num_segments_type, [mstype.int64],
|
||||
validator.check_tensor_dtype_valid("num_segments", num_segments_type, [mstype.int32, mstype.int64],
|
||||
self.name)
|
||||
else:
|
||||
validator.check_value_type('num_segments', num_segments['value'], [int], self.name)
|
||||
|
|
Loading…
Reference in New Issue