!33227 [assistant][ops] Add SegmentMean, SegmentProd

Merge pull request !33227 from 王乐/SegmentMean
This commit is contained in:
i-robot 2022-06-18 03:43:43 +00:00 committed by Gitee
commit 5bd1964bf7
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
17 changed files with 1383 additions and 0 deletions

View File

@ -36,6 +36,8 @@ constexpr auto kFractionalAvgPoolGradOpName = "FractionalAvgPoolGrad";
constexpr auto kSegmentMaxOpName = "SegmentMax";
constexpr auto kSegmentMinOpName = "SegmentMin";
constexpr auto kSegmentSumOpName = "SegmentSum";
constexpr auto kSegmentMeanOpName = "SegmentMean";
constexpr auto kSegmentProdOpName = "SegmentProd";
constexpr auto kConcatOpName = "Concat";
constexpr auto kListDiffOpName = "ListDiff";
constexpr auto kUniqueOpName = "Unique";
@ -847,6 +849,8 @@ const std::set<std::string> kComputeDepend = {kUniqueOpName,
kSegmentMaxOpName,
kSegmentMinOpName,
kSegmentSumOpName,
kSegmentMeanOpName,
kSegmentProdOpName,
kNonZeroOpName};
const std::set<std::string> k3DFormatSet = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D,

View File

@ -0,0 +1,310 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/cpu/kernel/segment_mean_cpu_kernel.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
namespace {
// Below this many segments the per-segment loop runs serially and only the
// inner per-element loop may be parallelized; at or above it, whole segments
// are split across threads instead.
const size_t kSegmentsThreshold = 2 * 1024;
// Minimum inner-dimension size before the per-element loop is parallelized.
const size_t kDataSizeThreshold = 2 * 1024;
// Dispatches one input dtype case to the arithmetic LaunchKernel instantiation.
#define SEGMENTMEAN_COMPUTE_CASE(DTYPE, TYPE1, TYPE2) \
  case (DTYPE): { \
    ret = LaunchKernel<TYPE1, TYPE2>(inputs, outputs); \
    break; \
  }
// Dispatches one complex input dtype case to LaunchKernelComplex.
#define SEGMENTMEAN_COMPUTE_CASE_CP(DTYPE, TYPE1, TYPE2) \
  case (DTYPE): { \
    ret = LaunchKernelComplex<TYPE1, TYPE2>(inputs, outputs); \
    break; \
  }
}  // namespace
namespace mindspore {
namespace kernel {
// Divides a complex accumulated sum by the element count of a segment,
// averaging the real and imaginary components independently.
// Raises ValueError when the divisor is zero.
template <typename T>
T ComplexDiv(T sum, size_t num) {
  if (num == 0) {
    MS_EXCEPTION(ValueError) << "For 'SegmentMean', divisor can not be 0.";
  }
  T quotient;
  quotient.real(sum.real() / num);
  quotient.imag(sum.imag() / num);
  return quotient;
}
// Caches the shapes, dtypes and flattened element counts of input_x,
// segment_ids and the output so Launch() can dispatch without re-querying
// the graph node.
void SegmentMeanCPUKernelMod::InitKernel(const CNodePtr &kernel_node) {
  input_x_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
  segment_ids_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
  output_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
  input_x_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
  segment_ids_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 1);
  output_dtype_ = AnfAlgo::GetOutputDeviceDataType(kernel_node, 0);
  input_x_num_ = CPUKernelUtils::CalcElementNum(input_x_shape_);
  segment_ids_num_ = CPUKernelUtils::CalcElementNum(segment_ids_shape_);
  output_num_ = CPUKernelUtils::CalcElementNum(output_shape_);
}
// Declares the supported dtype combinations: every numeric/complex input type
// paired with int32 or int64 segment_ids; the output dtype always matches the
// input dtype.
std::vector<KernelAttr> SegmentMeanCPUKernelMod::GetOpSupport() {
  static std::vector<KernelAttr> support_list = {
    KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
    KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
    KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
    KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
    KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16),
    KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
    KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64),
    KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
    KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
    KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
    KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
    KernelAttr().AddInputAttr(kNumberTypeComplex64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeComplex64),
    KernelAttr()
      .AddInputAttr(kNumberTypeComplex128)
      .AddInputAttr(kNumberTypeInt32)
      .AddOutputAttr(kNumberTypeComplex128),
    KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
    KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
    KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
    KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
    KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16),
    KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
    KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64),
    KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
    KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
    KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
    KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
    KernelAttr().AddInputAttr(kNumberTypeComplex64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeComplex64),
    KernelAttr()
      .AddInputAttr(kNumberTypeComplex128)
      .AddInputAttr(kNumberTypeInt64)
      .AddOutputAttr(kNumberTypeComplex128)};
  return support_list;
}
// Dispatches to the LaunchKernel instantiation matching the runtime dtypes:
// outer switch on segment_ids dtype (int32/int64), inner switch on input_x
// dtype. Complex inputs are routed to LaunchKernelComplex via the _CP macro;
// any unsupported dtype raises TypeError.
bool SegmentMeanCPUKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
                                     const std::vector<kernel::AddressPtr> &,
                                     const std::vector<kernel::AddressPtr> &outputs) {
  bool ret = true;
  switch (segment_ids_dtype_) {
    case kNumberTypeInt32: {
      switch (input_x_dtype_) {
        SEGMENTMEAN_COMPUTE_CASE_CP(kNumberTypeComplex64, std::complex<float>, int32_t)
        SEGMENTMEAN_COMPUTE_CASE_CP(kNumberTypeComplex128, std::complex<double>, int32_t)
        SEGMENTMEAN_COMPUTE_CASE(kNumberTypeInt8, int8_t, int32_t)
        SEGMENTMEAN_COMPUTE_CASE(kNumberTypeInt16, int16_t, int32_t)
        SEGMENTMEAN_COMPUTE_CASE(kNumberTypeInt32, int32_t, int32_t)
        SEGMENTMEAN_COMPUTE_CASE(kNumberTypeInt64, int64_t, int32_t)
        SEGMENTMEAN_COMPUTE_CASE(kNumberTypeUInt8, uint8_t, int32_t)
        SEGMENTMEAN_COMPUTE_CASE(kNumberTypeUInt16, uint16_t, int32_t)
        SEGMENTMEAN_COMPUTE_CASE(kNumberTypeUInt32, uint32_t, int32_t)
        SEGMENTMEAN_COMPUTE_CASE(kNumberTypeUInt64, uint64_t, int32_t)
        SEGMENTMEAN_COMPUTE_CASE(kNumberTypeFloat16, float16, int32_t)
        SEGMENTMEAN_COMPUTE_CASE(kNumberTypeFloat32, float, int32_t)
        SEGMENTMEAN_COMPUTE_CASE(kNumberTypeFloat64, double, int32_t)
        default:
          MS_EXCEPTION(TypeError) << "For 'SegmentMean', unsupported input_x data type: " << input_x_dtype_;
      }
      break;
    }
    case kNumberTypeInt64: {
      switch (input_x_dtype_) {
        SEGMENTMEAN_COMPUTE_CASE_CP(kNumberTypeComplex64, std::complex<float>, int64_t)
        SEGMENTMEAN_COMPUTE_CASE_CP(kNumberTypeComplex128, std::complex<double>, int64_t)
        SEGMENTMEAN_COMPUTE_CASE(kNumberTypeInt8, int8_t, int64_t)
        SEGMENTMEAN_COMPUTE_CASE(kNumberTypeInt16, int16_t, int64_t)
        SEGMENTMEAN_COMPUTE_CASE(kNumberTypeInt32, int32_t, int64_t)
        SEGMENTMEAN_COMPUTE_CASE(kNumberTypeInt64, int64_t, int64_t)
        SEGMENTMEAN_COMPUTE_CASE(kNumberTypeUInt8, uint8_t, int64_t)
        SEGMENTMEAN_COMPUTE_CASE(kNumberTypeUInt16, uint16_t, int64_t)
        SEGMENTMEAN_COMPUTE_CASE(kNumberTypeUInt32, uint32_t, int64_t)
        SEGMENTMEAN_COMPUTE_CASE(kNumberTypeUInt64, uint64_t, int64_t)
        SEGMENTMEAN_COMPUTE_CASE(kNumberTypeFloat16, float16, int64_t)
        SEGMENTMEAN_COMPUTE_CASE(kNumberTypeFloat32, float, int64_t)
        SEGMENTMEAN_COMPUTE_CASE(kNumberTypeFloat64, double, int64_t)
        default:
          MS_EXCEPTION(TypeError) << "For 'SegmentMean', unsupported input_x data type: " << input_x_dtype_;
      }
      break;
    }
    default:
      MS_EXCEPTION(TypeError) << "For 'SegmentMean', unsupported segment_ids data type: " << segment_ids_dtype_;
  }
  return ret;
}
// Computes the per-segment mean of input_x rows for arithmetic element types.
// T1 is the element dtype, T2 the segment_ids dtype. segment_ids is assumed
// sorted (equal ids are contiguous), matching the Segment* op contract.
// Fixes vs. original: the inner element index used `int` (narrowing from
// size_t, overflows on large tensors) and segment start offsets were rescanned
// per segment (O(num_segments^2)); offsets are now prefix-summed once.
template <typename T1, typename T2>
bool SegmentMeanCPUKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
                                           const std::vector<kernel::AddressPtr> &outputs) {
  auto input_x_data_addr = reinterpret_cast<T1 *>(inputs[0]->addr);
  auto segment_ids_data_addr = reinterpret_cast<T2 *>(inputs[1]->addr);
  auto output_data_addr = reinterpret_cast<T1 *>(outputs[0]->addr);
  if (segment_ids_num_ == 0) {
    // Guard: the run-length loop below would underflow size_t on empty ids.
    return true;
  }
  // Run-length encode segment_ids: segments[i] is the number of input rows
  // belonging to the i-th distinct id.
  std::vector<int64_t> segments;
  int64_t seg_tmp = 1;
  for (size_t i = 0; i + 1 < segment_ids_num_; ++i) {
    if (segment_ids_data_addr[i] == segment_ids_data_addr[i + 1]) {
      seg_tmp++;
    } else {
      segments.push_back(seg_tmp);
      seg_tmp = 1;
    }
  }
  segments.push_back(seg_tmp);  // close the final run (also covers a single id)
  for (size_t i = 0; i < output_num_; ++i) {
    output_data_addr[i] = static_cast<T1>(0);
  }
  // Number of elements in one input row (everything after axis 0).
  const size_t num_compare_per = input_x_num_ / input_x_shape_[0];
  const size_t num_segments = segments.size();
  // offsets[i] is the first input row of segment i (prefix sums of lengths).
  std::vector<size_t> offsets(num_segments, 0);
  for (size_t i = 1; i < num_segments; ++i) {
    offsets[i] = offsets[i - 1] + static_cast<size_t>(segments[i - 1]);
  }
  if (num_segments < kSegmentsThreshold) {
    // Few segments: iterate segments serially, parallelize over row elements.
    for (size_t i = 0; i < num_segments; ++i) {
      const size_t count = static_cast<size_t>(segments[i]);
      const size_t count_no = offsets[i];
      const size_t input_addr_base = count_no * num_compare_per;
      auto task = [&](size_t start, size_t end) {
        for (size_t j = start; j < end; ++j) {
          const size_t mean_init_addr = input_addr_base + j;
          T1 sum_value = input_x_data_addr[mean_init_addr];
          for (size_t k = 1; k < count; ++k) {
            // size_t (not int) avoids narrowing/overflow on large inputs.
            const size_t tmp_addr = mean_init_addr + k * num_compare_per;
            sum_value += input_x_data_addr[tmp_addr];
          }
          output_data_addr[segment_ids_data_addr[count_no] * num_compare_per + j] = sum_value / static_cast<T1>(count);
        }
      };
      if (num_compare_per < kDataSizeThreshold) {
        task(0, num_compare_per);
      } else {
        CPUKernelUtils::ParallelFor(task, num_compare_per);
      }
    }
  } else {
    // Many segments: parallelize over segments instead.
    auto task = [&](size_t start, size_t end) {
      for (size_t i = start; i < end; ++i) {
        const size_t count = static_cast<size_t>(segments[i]);
        const size_t count_no = offsets[i];
        const size_t input_addr_base = count_no * num_compare_per;
        for (size_t j = 0; j < num_compare_per; ++j) {
          const size_t mean_init_addr = input_addr_base + j;
          T1 sum_value = input_x_data_addr[mean_init_addr];
          for (size_t k = 1; k < count; ++k) {
            const size_t tmp_addr = mean_init_addr + k * num_compare_per;
            sum_value += input_x_data_addr[tmp_addr];
          }
          output_data_addr[segment_ids_data_addr[count_no] * num_compare_per + j] = sum_value / static_cast<T1>(count);
        }
      }
    };
    CPUKernelUtils::ParallelFor(task, num_segments);
  }
  return true;
}
// Complex-valued variant of LaunchKernel: identical segment bookkeeping, but
// the final division goes through ComplexDiv so real/imaginary parts are
// averaged independently. Fixes vs. original: `int` element index (narrowing
// from size_t) and O(num_segments^2) per-segment offset rescans.
template <typename T1, typename T2>
bool SegmentMeanCPUKernelMod::LaunchKernelComplex(const std::vector<kernel::AddressPtr> &inputs,
                                                  const std::vector<kernel::AddressPtr> &outputs) {
  auto input_x_data_addr = reinterpret_cast<T1 *>(inputs[0]->addr);
  auto segment_ids_data_addr = reinterpret_cast<T2 *>(inputs[1]->addr);
  auto output_data_addr = reinterpret_cast<T1 *>(outputs[0]->addr);
  if (segment_ids_num_ == 0) {
    // Guard: the run-length loop below would underflow size_t on empty ids.
    return true;
  }
  // Run-length encode segment_ids: segments[i] is the number of input rows
  // belonging to the i-th distinct id.
  std::vector<int64_t> segments;
  int64_t seg_tmp = 1;
  for (size_t i = 0; i + 1 < segment_ids_num_; ++i) {
    if (segment_ids_data_addr[i] == segment_ids_data_addr[i + 1]) {
      seg_tmp++;
    } else {
      segments.push_back(seg_tmp);
      seg_tmp = 1;
    }
  }
  segments.push_back(seg_tmp);  // close the final run (also covers a single id)
  for (size_t i = 0; i < output_num_; ++i) {
    output_data_addr[i] = static_cast<T1>(0);
  }
  // Number of elements in one input row (everything after axis 0).
  const size_t num_compare_per = input_x_num_ / input_x_shape_[0];
  const size_t num_segments = segments.size();
  // offsets[i] is the first input row of segment i (prefix sums of lengths).
  std::vector<size_t> offsets(num_segments, 0);
  for (size_t i = 1; i < num_segments; ++i) {
    offsets[i] = offsets[i - 1] + static_cast<size_t>(segments[i - 1]);
  }
  if (num_segments < kSegmentsThreshold) {
    // Few segments: iterate segments serially, parallelize over row elements.
    for (size_t i = 0; i < num_segments; ++i) {
      const size_t count = static_cast<size_t>(segments[i]);
      const size_t count_no = offsets[i];
      const size_t input_addr_base = count_no * num_compare_per;
      auto task = [&](size_t start, size_t end) {
        for (size_t j = start; j < end; ++j) {
          const size_t mean_init_addr = input_addr_base + j;
          T1 sum_value = input_x_data_addr[mean_init_addr];
          for (size_t k = 1; k < count; ++k) {
            // size_t (not int) avoids narrowing/overflow on large inputs.
            const size_t tmp_addr = mean_init_addr + k * num_compare_per;
            sum_value += input_x_data_addr[tmp_addr];
          }
          output_data_addr[segment_ids_data_addr[count_no] * num_compare_per + j] = ComplexDiv(sum_value, count);
        }
      };
      if (num_compare_per < kDataSizeThreshold) {
        task(0, num_compare_per);
      } else {
        CPUKernelUtils::ParallelFor(task, num_compare_per);
      }
    }
  } else {
    // Many segments: parallelize over segments instead.
    auto task = [&](size_t start, size_t end) {
      for (size_t i = start; i < end; ++i) {
        const size_t count = static_cast<size_t>(segments[i]);
        const size_t count_no = offsets[i];
        const size_t input_addr_base = count_no * num_compare_per;
        for (size_t j = 0; j < num_compare_per; ++j) {
          const size_t mean_init_addr = input_addr_base + j;
          T1 sum_value = input_x_data_addr[mean_init_addr];
          for (size_t k = 1; k < count; ++k) {
            const size_t tmp_addr = mean_init_addr + k * num_compare_per;
            sum_value += input_x_data_addr[tmp_addr];
          }
          output_data_addr[segment_ids_data_addr[count_no] * num_compare_per + j] = ComplexDiv(sum_value, count);
        }
      }
    };
    CPUKernelUtils::ParallelFor(task, num_segments);
  }
  return true;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, SegmentMean, SegmentMeanCPUKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,61 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SEGMENT_MEAN_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SEGMENT_MEAN_CPU_KERNEL_H_
#include <vector>
#include <memory>
#include <string>
#include <complex>
#include <utility>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
// CPU kernel for the SegmentMean op: averages rows of input_x that share the
// same (sorted) segment id. Dispatch on dtypes happens in Launch().
class SegmentMeanCPUKernelMod : public DeprecatedNativeCpuKernelMod {
 public:
  SegmentMeanCPUKernelMod() = default;
  ~SegmentMeanCPUKernelMod() override = default;

  // Caches shapes, dtypes and element counts from the kernel node.
  void InitKernel(const CNodePtr &kernel_node) override;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override;

  // Per-segment mean for arithmetic element type T1 with T2 segment ids.
  template <typename T1, typename T2>
  bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);

  // Complex variant: averages real and imaginary components separately.
  template <typename T1, typename T2>
  bool LaunchKernelComplex(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);

 protected:
  std::vector<KernelAttr> GetOpSupport() override;

 private:
  std::vector<size_t> input_x_shape_;
  std::vector<size_t> segment_ids_shape_;
  std::vector<size_t> output_shape_;
  // Fix: value-initialize the element counters — they were previously left
  // indeterminate until InitKernel() ran.
  size_t input_x_num_{0};
  size_t segment_ids_num_{0};
  size_t output_num_{0};
  TypeId input_x_dtype_{kTypeUnknown};
  TypeId segment_ids_dtype_{kTypeUnknown};
  TypeId output_dtype_{kTypeUnknown};
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SEGMENT_MEAN_CPU_KERNEL_H_

View File

@ -0,0 +1,309 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/cpu/kernel/segment_prod_cpu_kernel.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
namespace {
const size_t kSegmentsThreshold = 2 * 1024;
const size_t kDataSizeThreshold = 2 * 1024;
// Multiplies two complex values component-wise:
// (a + bi)(x + yi) = (ax - by) + (bx + ay)i.
template <typename T>
T ComputeProd(const T num_1, const T num_2) {
  const auto a = num_1.real();
  const auto b = num_1.imag();
  const auto x = num_2.real();
  const auto y = num_2.imag();
  T product;
  product.real(a * x - b * y);
  product.imag(b * x + a * y);
  return product;
}
// Dispatches one input dtype case to the arithmetic LaunchKernel instantiation.
#define SEGMENTPROD_COMPUTE_CASE(DTYPE, TYPE1, TYPE2) \
  case (DTYPE): { \
    ret = LaunchKernel<TYPE1, TYPE2>(inputs, outputs); \
    break; \
  }
// Dispatches one complex input dtype case to LaunchKernel_CP (component-wise
// product via ComputeProd).
#define SEGMENTPROD_COMPUTE_CASE_CP(DTYPE, TYPE1, TYPE2) \
  case (DTYPE): { \
    ret = LaunchKernel_CP<TYPE1, TYPE2>(inputs, outputs); \
    break; \
  }
} // namespace
namespace mindspore {
namespace kernel {
// Caches the shapes, dtypes and flattened element counts of input_x,
// segment_ids and the output so Launch() can dispatch without re-querying
// the graph node.
void SegmentProdCPUKernelMod::InitKernel(const CNodePtr &kernel_node) {
  input_x_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
  segment_ids_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
  output_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
  input_x_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
  segment_ids_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 1);
  output_dtype_ = AnfAlgo::GetOutputDeviceDataType(kernel_node, 0);
  input_x_num_ = CPUKernelUtils::CalcElementNum(input_x_shape_);
  segment_ids_num_ = CPUKernelUtils::CalcElementNum(segment_ids_shape_);
  output_num_ = CPUKernelUtils::CalcElementNum(output_shape_);
}
// Declares the supported dtype combinations: every numeric/complex input type
// paired with int32 or int64 segment_ids; the output dtype always matches the
// input dtype.
std::vector<KernelAttr> SegmentProdCPUKernelMod::GetOpSupport() {
  static std::vector<KernelAttr> support_list = {
    KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
    KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
    KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
    KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
    KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16),
    KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
    KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64),
    KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
    KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
    KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
    KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
    KernelAttr().AddInputAttr(kNumberTypeComplex64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeComplex64),
    KernelAttr()
      .AddInputAttr(kNumberTypeComplex128)
      .AddInputAttr(kNumberTypeInt32)
      .AddOutputAttr(kNumberTypeComplex128),
    KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
    KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
    KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
    KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
    KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16),
    KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
    KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64),
    KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
    KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
    KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
    KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
    KernelAttr().AddInputAttr(kNumberTypeComplex64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeComplex64),
    KernelAttr()
      .AddInputAttr(kNumberTypeComplex128)
      .AddInputAttr(kNumberTypeInt64)
      .AddOutputAttr(kNumberTypeComplex128)};
  return support_list;
}
// Dispatches to the LaunchKernel instantiation matching the runtime dtypes:
// outer switch on segment_ids dtype (int32/int64), inner switch on input_x
// dtype. Complex inputs are routed to LaunchKernel_CP via the _CP macro.
// Fix: the int64 branch previously dispatched Complex64/Complex128 through
// the arithmetic SEGMENTPROD_COMPUTE_CASE macro — inconsistent with the int32
// branch and with SegmentMean; both complex cases now use the _CP macro.
bool SegmentProdCPUKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
                                     const std::vector<kernel::AddressPtr> &,
                                     const std::vector<kernel::AddressPtr> &outputs) {
  bool ret = true;
  switch (segment_ids_dtype_) {
    case kNumberTypeInt32: {
      switch (input_x_dtype_) {
        SEGMENTPROD_COMPUTE_CASE_CP(kNumberTypeComplex64, std::complex<float>, int32_t)
        SEGMENTPROD_COMPUTE_CASE_CP(kNumberTypeComplex128, std::complex<double>, int32_t)
        SEGMENTPROD_COMPUTE_CASE(kNumberTypeInt8, int8_t, int32_t)
        SEGMENTPROD_COMPUTE_CASE(kNumberTypeInt16, int16_t, int32_t)
        SEGMENTPROD_COMPUTE_CASE(kNumberTypeInt32, int32_t, int32_t)
        SEGMENTPROD_COMPUTE_CASE(kNumberTypeInt64, int64_t, int32_t)
        SEGMENTPROD_COMPUTE_CASE(kNumberTypeUInt8, uint8_t, int32_t)
        SEGMENTPROD_COMPUTE_CASE(kNumberTypeUInt16, uint16_t, int32_t)
        SEGMENTPROD_COMPUTE_CASE(kNumberTypeUInt32, uint32_t, int32_t)
        SEGMENTPROD_COMPUTE_CASE(kNumberTypeUInt64, uint64_t, int32_t)
        SEGMENTPROD_COMPUTE_CASE(kNumberTypeFloat16, float16, int32_t)
        SEGMENTPROD_COMPUTE_CASE(kNumberTypeFloat32, float, int32_t)
        SEGMENTPROD_COMPUTE_CASE(kNumberTypeFloat64, double, int32_t)
        default:
          MS_EXCEPTION(TypeError) << "For 'SegmentProd', unsupported input_x data type: " << input_x_dtype_;
      }
      break;
    }
    case kNumberTypeInt64: {
      switch (input_x_dtype_) {
        SEGMENTPROD_COMPUTE_CASE_CP(kNumberTypeComplex64, std::complex<float>, int64_t)
        SEGMENTPROD_COMPUTE_CASE_CP(kNumberTypeComplex128, std::complex<double>, int64_t)
        SEGMENTPROD_COMPUTE_CASE(kNumberTypeInt8, int8_t, int64_t)
        SEGMENTPROD_COMPUTE_CASE(kNumberTypeInt16, int16_t, int64_t)
        SEGMENTPROD_COMPUTE_CASE(kNumberTypeInt32, int32_t, int64_t)
        SEGMENTPROD_COMPUTE_CASE(kNumberTypeInt64, int64_t, int64_t)
        SEGMENTPROD_COMPUTE_CASE(kNumberTypeUInt8, uint8_t, int64_t)
        SEGMENTPROD_COMPUTE_CASE(kNumberTypeUInt16, uint16_t, int64_t)
        SEGMENTPROD_COMPUTE_CASE(kNumberTypeUInt32, uint32_t, int64_t)
        SEGMENTPROD_COMPUTE_CASE(kNumberTypeUInt64, uint64_t, int64_t)
        SEGMENTPROD_COMPUTE_CASE(kNumberTypeFloat16, float16, int64_t)
        SEGMENTPROD_COMPUTE_CASE(kNumberTypeFloat32, float, int64_t)
        SEGMENTPROD_COMPUTE_CASE(kNumberTypeFloat64, double, int64_t)
        default:
          MS_EXCEPTION(TypeError) << "For 'SegmentProd', unsupported input_x data type: " << input_x_dtype_;
      }
      break;
    }
    default:
      MS_EXCEPTION(TypeError) << "For 'SegmentProd', unsupported segment_ids data type: " << segment_ids_dtype_;
  }
  return ret;
}
// Computes the per-segment product of input_x rows for arithmetic element
// types. T1 is the element dtype, T2 the segment_ids dtype. segment_ids is
// assumed sorted (equal ids are contiguous). Fixes vs. original: the inner
// element index used `int` (narrowing from size_t, overflows on large
// tensors) and segment start offsets were rescanned per segment
// (O(num_segments^2)); offsets are now prefix-summed once.
template <typename T1, typename T2>
bool SegmentProdCPUKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
                                           const std::vector<kernel::AddressPtr> &outputs) {
  auto input_x_data_addr = reinterpret_cast<T1 *>(inputs[0]->addr);
  auto segment_ids_data_addr = reinterpret_cast<T2 *>(inputs[1]->addr);
  auto output_data_addr = reinterpret_cast<T1 *>(outputs[0]->addr);
  if (segment_ids_num_ == 0) {
    // Guard: the run-length loop below would underflow size_t on empty ids.
    return true;
  }
  // Run-length encode segment_ids: segments[i] is the number of input rows
  // belonging to the i-th distinct id.
  std::vector<int64_t> segments;
  int64_t seg_tmp = 1;
  for (size_t i = 0; i + 1 < segment_ids_num_; ++i) {
    if (segment_ids_data_addr[i] == segment_ids_data_addr[i + 1]) {
      seg_tmp++;
    } else {
      segments.push_back(seg_tmp);
      seg_tmp = 1;
    }
  }
  segments.push_back(seg_tmp);  // close the final run (also covers a single id)
  // Multiplicative identity so untouched (gap) output rows stay 1.
  for (size_t i = 0; i < output_num_; ++i) {
    output_data_addr[i] = static_cast<T1>(1);
  }
  // Number of elements in one input row (everything after axis 0).
  const size_t num_compare_per = input_x_num_ / input_x_shape_[0];
  const size_t num_segments = segments.size();
  // offsets[i] is the first input row of segment i (prefix sums of lengths).
  std::vector<size_t> offsets(num_segments, 0);
  for (size_t i = 1; i < num_segments; ++i) {
    offsets[i] = offsets[i - 1] + static_cast<size_t>(segments[i - 1]);
  }
  if (num_segments < kSegmentsThreshold) {
    // Few segments: iterate segments serially, parallelize over row elements.
    for (size_t i = 0; i < num_segments; ++i) {
      const size_t count = static_cast<size_t>(segments[i]);
      const size_t count_no = offsets[i];
      const size_t input_addr_base = count_no * num_compare_per;
      auto task = [&](size_t start, size_t end) {
        for (size_t j = start; j < end; ++j) {
          const size_t prod_init_addr = input_addr_base + j;
          T1 prod_value = input_x_data_addr[prod_init_addr];
          for (size_t k = 1; k < count; ++k) {
            // size_t (not int) avoids narrowing/overflow on large inputs.
            const size_t tmp_addr = prod_init_addr + k * num_compare_per;
            prod_value *= input_x_data_addr[tmp_addr];
          }
          output_data_addr[segment_ids_data_addr[count_no] * num_compare_per + j] = prod_value;
        }
      };
      if (num_compare_per < kDataSizeThreshold) {
        task(0, num_compare_per);
      } else {
        CPUKernelUtils::ParallelFor(task, num_compare_per);
      }
    }
  } else {
    // Many segments: parallelize over segments instead.
    auto task = [&](size_t start, size_t end) {
      for (size_t i = start; i < end; ++i) {
        const size_t count = static_cast<size_t>(segments[i]);
        const size_t count_no = offsets[i];
        const size_t input_addr_base = count_no * num_compare_per;
        for (size_t j = 0; j < num_compare_per; ++j) {
          const size_t prod_init_addr = input_addr_base + j;
          T1 prod_value = input_x_data_addr[prod_init_addr];
          for (size_t k = 1; k < count; ++k) {
            const size_t tmp_addr = prod_init_addr + k * num_compare_per;
            prod_value *= input_x_data_addr[tmp_addr];
          }
          output_data_addr[segment_ids_data_addr[count_no] * num_compare_per + j] = prod_value;
        }
      }
    };
    CPUKernelUtils::ParallelFor(task, num_segments);
  }
  return true;
}
// Complex-valued variant of LaunchKernel: identical segment bookkeeping, but
// element products are accumulated through ComputeProd (component-wise complex
// multiplication). Fixes vs. original: `int` element index (narrowing from
// size_t) and O(num_segments^2) per-segment offset rescans.
template <typename T1, typename T2>
bool SegmentProdCPUKernelMod::LaunchKernel_CP(const std::vector<kernel::AddressPtr> &inputs,
                                              const std::vector<kernel::AddressPtr> &outputs) {
  auto input_x_data_addr = reinterpret_cast<T1 *>(inputs[0]->addr);
  auto segment_ids_data_addr = reinterpret_cast<T2 *>(inputs[1]->addr);
  auto output_data_addr = reinterpret_cast<T1 *>(outputs[0]->addr);
  if (segment_ids_num_ == 0) {
    // Guard: the run-length loop below would underflow size_t on empty ids.
    return true;
  }
  // Run-length encode segment_ids: segments[i] is the number of input rows
  // belonging to the i-th distinct id.
  std::vector<int64_t> segments;
  int64_t seg_tmp = 1;
  for (size_t i = 0; i + 1 < segment_ids_num_; ++i) {
    if (segment_ids_data_addr[i] == segment_ids_data_addr[i + 1]) {
      seg_tmp++;
    } else {
      segments.push_back(seg_tmp);
      seg_tmp = 1;
    }
  }
  segments.push_back(seg_tmp);  // close the final run (also covers a single id)
  // Multiplicative identity so untouched (gap) output rows stay 1.
  for (size_t i = 0; i < output_num_; ++i) {
    output_data_addr[i] = static_cast<T1>(1);
  }
  // Number of elements in one input row (everything after axis 0).
  const size_t num_compare_per = input_x_num_ / input_x_shape_[0];
  const size_t num_segments = segments.size();
  // offsets[i] is the first input row of segment i (prefix sums of lengths).
  std::vector<size_t> offsets(num_segments, 0);
  for (size_t i = 1; i < num_segments; ++i) {
    offsets[i] = offsets[i - 1] + static_cast<size_t>(segments[i - 1]);
  }
  if (num_segments < kSegmentsThreshold) {
    // Few segments: iterate segments serially, parallelize over row elements.
    for (size_t i = 0; i < num_segments; ++i) {
      const size_t count = static_cast<size_t>(segments[i]);
      const size_t count_no = offsets[i];
      const size_t input_addr_base = count_no * num_compare_per;
      auto task = [&](size_t start, size_t end) {
        for (size_t j = start; j < end; ++j) {
          const size_t prod_init_addr = input_addr_base + j;
          T1 prod_value = input_x_data_addr[prod_init_addr];
          for (size_t k = 1; k < count; ++k) {
            // size_t (not int) avoids narrowing/overflow on large inputs.
            const size_t tmp_addr = prod_init_addr + k * num_compare_per;
            prod_value = ComputeProd(prod_value, input_x_data_addr[tmp_addr]);
          }
          output_data_addr[segment_ids_data_addr[count_no] * num_compare_per + j] = prod_value;
        }
      };
      if (num_compare_per < kDataSizeThreshold) {
        task(0, num_compare_per);
      } else {
        CPUKernelUtils::ParallelFor(task, num_compare_per);
      }
    }
  } else {
    // Many segments: parallelize over segments instead.
    auto task = [&](size_t start, size_t end) {
      for (size_t i = start; i < end; ++i) {
        const size_t count = static_cast<size_t>(segments[i]);
        const size_t count_no = offsets[i];
        const size_t input_addr_base = count_no * num_compare_per;
        for (size_t j = 0; j < num_compare_per; ++j) {
          const size_t prod_init_addr = input_addr_base + j;
          T1 prod_value = input_x_data_addr[prod_init_addr];
          for (size_t k = 1; k < count; ++k) {
            const size_t tmp_addr = prod_init_addr + k * num_compare_per;
            prod_value = ComputeProd(prod_value, input_x_data_addr[tmp_addr]);
          }
          output_data_addr[segment_ids_data_addr[count_no] * num_compare_per + j] = prod_value;
        }
      }
    };
    CPUKernelUtils::ParallelFor(task, num_segments);
  }
  return true;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, SegmentProd, SegmentProdCPUKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,61 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SEGMENT_PROD_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SEGMENT_PROD_CPU_KERNEL_H_
#include <vector>
#include <memory>
#include <string>
#include <complex>
#include <utility>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
// CPU kernel for the SegmentProd op: multiplies together rows of input_x that
// share the same (sorted) segment id. Dispatch on dtypes happens in Launch().
class SegmentProdCPUKernelMod : public DeprecatedNativeCpuKernelMod {
 public:
  SegmentProdCPUKernelMod() = default;
  ~SegmentProdCPUKernelMod() override = default;

  // Caches shapes, dtypes and element counts from the kernel node.
  void InitKernel(const CNodePtr &kernel_node) override;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override;

  // Per-segment product for arithmetic element type T1 with T2 segment ids.
  template <typename T1, typename T2>
  bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);

  // Complex variant: multiplies elements component-wise via ComputeProd.
  template <typename T1, typename T2>
  bool LaunchKernel_CP(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);

 protected:
  std::vector<KernelAttr> GetOpSupport() override;

 private:
  std::vector<size_t> input_x_shape_;
  std::vector<size_t> segment_ids_shape_;
  std::vector<size_t> output_shape_;
  // Fix: value-initialize the element counters — they were previously left
  // indeterminate until InitKernel() ran.
  size_t input_x_num_{0};
  size_t segment_ids_num_{0};
  size_t output_num_{0};
  TypeId input_x_dtype_{kTypeUnknown};
  TypeId segment_ids_dtype_{kTypeUnknown};
  TypeId output_dtype_{kTypeUnknown};
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SEGMENT_PROD_CPU_KERNEL_H_

View File

@ -93,10 +93,14 @@ PrimShapeDependMap &GetHostDependsMap() {
static const auto &kBlackmanWindow = prim::kPrimBlackmanWindow->name();
static const auto &kExpand = prim::kPrimExpand->name();
static const auto &kSspaddmm = prim::kPrimSspaddmm->name();
static const auto &kSegmentMean = prim::kPrimSegmentMean->name();
static const auto &kSegmentProd = prim::kPrimSegmentProd->name();
// Common host depends.
static PrimShapeDependMap host_depends{{kSegmentMax, ShapeSet{1}},
{kSegmentMin, ShapeSet{1}},
{kSegmentSum, ShapeSet{1}},
{kSegmentMean, ShapeSet{1}},
{kSegmentProd, ShapeSet{1}},
{kUnsortedSegmentSum, ShapeSet{2}},
{kFractionalAvgPoolGrad, ShapeSet{0}},
{kUnsortedSegmentMin, ShapeSet{2}},

View File

@ -113,6 +113,8 @@ constexpr auto kFillDiagonal = "FillDiagonal";
constexpr auto kSegmentMax = "SegmentMax";
constexpr auto kSegmentSum = "SegmentSum";
constexpr auto kSegmentMin = "SegmentMin";
constexpr auto kSegmentMean = "SegmentMean";
constexpr auto kSegmentProd = "SegmentProd";
constexpr auto kDynamicShape = "DynamicShape";
constexpr auto kTensorShape = "TensorShape";
constexpr auto kCheckNumerics = "CheckNumerics";
@ -481,6 +483,8 @@ GVAR_DEF(PrimitivePtr, kPrimMeshgrid, std::make_shared<Primitive>(kMeshgrid));
GVAR_DEF(PrimitivePtr, kPrimSegmentMax, std::make_shared<Primitive>(kSegmentMax));
GVAR_DEF(PrimitivePtr, kPrimSegmentMin, std::make_shared<Primitive>(kSegmentMin));
GVAR_DEF(PrimitivePtr, kPrimSegmentSum, std::make_shared<Primitive>(kSegmentSum));
GVAR_DEF(PrimitivePtr, kPrimSegmentMean, std::make_shared<Primitive>(kSegmentMean));
GVAR_DEF(PrimitivePtr, kPrimSegmentProd, std::make_shared<Primitive>(kSegmentProd));
// image
GVAR_DEF(PrimitivePtr, kPrimCropAndResizeGradBoxes, std::make_shared<Primitive>(kCropAndResizeGradBoxes));

View File

@ -0,0 +1,144 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/segment_mean.h"
#include <algorithm>
#include <set>
#include "ops/op_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "utils/check_convert_utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr SegmentMeanInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
auto max_length_ptr = primitive->GetAttr("max_length");
MS_EXCEPTION_IF_NULL(max_length_ptr);
int64_t max_length = GetValue<int64_t>(max_length_ptr);
auto x_shape_ptr = input_args[0]->BuildShape();
MS_EXCEPTION_IF_NULL(x_shape_ptr);
auto segment_ids_shape_ptr = input_args[1]->BuildShape();
MS_EXCEPTION_IF_NULL(segment_ids_shape_ptr);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(x_shape_ptr)[kShape];
auto prim_name = primitive->name();
if (x_shape.size() == 0) {
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the rank of input_x must not be less than 1, but got 0.";
}
auto segment_ids_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(segment_ids_shape_ptr)[kShape];
ShapeVector out_shape(x_shape);
auto segment_ids_ptr = input_args[1]->BuildValue();
MS_EXCEPTION_IF_NULL(segment_ids_ptr);
if (!segment_ids_ptr->isa<AnyValue>() && !segment_ids_ptr->isa<None>()) {
if (segment_ids_shape.size() != 1) {
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', segment_ids must be a 1D tensor, but got "
<< segment_ids_shape.size() << "D tensor";
}
if (segment_ids_shape[0] != x_shape[0]) {
MS_EXCEPTION(ValueError) << "For '" << prim_name
<< "', the amount of data for segment_ids must be equal to the first dimension of the "
"shape of input_x, but got "
<< segment_ids_shape[0] << " and " << x_shape[0] << ".";
}
auto segment_ids_tensor = segment_ids_ptr->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(segment_ids_tensor);
auto data_size = segment_ids_tensor->DataSize();
auto segment_ids_type_id = segment_ids_tensor->data_type();
if (segment_ids_type_id == kNumberTypeInt64) {
int64_t *segment_ids_data = reinterpret_cast<int64_t *>(segment_ids_tensor->data_c());
if (segment_ids_data[0] < 0) {
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the values of segment_ids must be nonnegative. but got "
<< segment_ids_data[0] << ".";
}
for (size_t i = 0; i < data_size - 1; ++i) {
if (segment_ids_data[i] > segment_ids_data[i + 1]) {
MS_EXCEPTION(ValueError) << "For '" << prim_name
<< "', segment_ids must be a tensor with element values sorted in ascending order.";
break;
}
}
out_shape[0] = static_cast<size_t>(segment_ids_data[data_size - 1] + 1);
} else if (segment_ids_type_id == kNumberTypeInt32) {
int32_t *segment_ids_data = reinterpret_cast<int32_t *>(segment_ids_tensor->data_c());
if (segment_ids_data[0] < 0) {
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the values of segment_ids must be nonnegative. but got "
<< segment_ids_data[0] << ".";
}
for (size_t i = 0; i < data_size - 1; ++i) {
if (segment_ids_data[i] > segment_ids_data[i + 1]) {
MS_EXCEPTION(ValueError) << "For '" << prim_name
<< "', segment_ids must be a tensor with element values sorted in ascending order.";
break;
}
}
out_shape[0] = static_cast<size_t>(segment_ids_data[data_size - 1] + 1);
}
uint32_t length = 1;
for (size_t i = 0; i < out_shape.size(); ++i) {
length *= out_shape[i];
}
if (length > max_length) {
MS_EXCEPTION(ValueError) << "For '" << prim_name
<< "', The number of elements of output must be less than max length: " << max_length
<< ", but got " << length
<< "! The shape of output should be reduced or max_length should be increased";
}
return std::make_shared<abstract::Shape>(out_shape);
} else {
uint32_t length = 1;
for (size_t i = 1; i < x_shape.size(); ++i) {
length *= x_shape[i];
}
const uint32_t max_shape_value = static_cast<uint32_t>(max_length) / length;
ShapeVector min_shape(x_shape);
ShapeVector max_shape(x_shape);
out_shape[0] = abstract::Shape::SHP_ANY;
min_shape[0] = 1;
max_shape[0] = max_shape_value;
return std::make_shared<abstract::Shape>(out_shape, min_shape, max_shape);
}
}
// Checks the dtypes of both inputs and returns input_x's dtype as the output dtype.
TypePtr SegmentMeanInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  const auto op_name = primitive->name();
  // SegmentMean accepts all real and complex numeric tensors; segment_ids must be integral.
  const std::set<TypePtr> valid_data_types = {kFloat16, kFloat32, kFloat64, kInt8,   kInt16,  kComplex128, kInt32,
                                              kInt64,   kUInt8,   kUInt16,  kUInt32, kUInt64, kComplex64};
  const std::set<TypePtr> valid_ids_types = {kInt32, kInt64};
  auto data_type = input_args[0]->BuildType();
  auto ids_type = input_args[1]->BuildType();
  (void)CheckAndConvertUtils::CheckTensorTypeValid("x_type", data_type, valid_data_types, op_name);
  (void)CheckAndConvertUtils::CheckTensorTypeValid("segment_ids_type", ids_type, valid_ids_types, op_name);
  return data_type;
}
} // namespace
MIND_API_BASE_IMPL(SegmentMean, PrimitiveC, BaseOperator);
// Top-level infer entry for SegmentMean: validates the argument count and
// combines dtype and shape inference into a single abstract value.
AbstractBasePtr SegmentMeanInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                 const std::vector<AbstractBasePtr> &input_args) {
  std::for_each(input_args.begin(), input_args.end(), [](const AbstractBasePtr &arg) { MS_EXCEPTION_IF_NULL(arg); });
  constexpr int64_t kSegmentMeanInputNum = 2;
  CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kSegmentMeanInputNum, primitive->name());
  auto inferred_type = SegmentMeanInferType(primitive, input_args);
  auto inferred_shape = SegmentMeanInferShape(primitive, input_args);
  return abstract::MakeAbstract(inferred_shape, inferred_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(SegmentMean, prim::kPrimSegmentMean, SegmentMeanInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,39 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_SEGMENT_MEAN_H_
#define MINDSPORE_CORE_OPS_SEGMENT_MEAN_H_
#include <vector>
#include <memory>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameSegmentMean = "SegmentMean";
/// \brief Computes the mean along segments of a tensor.
/// Inputs are "input_x" and "segment_ids"; the single output is "output".
class MIND_API SegmentMean : public BaseOperator {
 public:
  /// \brief Constructor. Registers the operator's input/output names.
  SegmentMean() : BaseOperator(kNameSegmentMean) { InitIOName({"input_x", "segment_ids"}, {"output"}); }
  MIND_API_BASE_MEMBER(SegmentMean);
};
abstract::AbstractBasePtr SegmentMeanInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using PrimSegmentMeanPtr = std::shared_ptr<SegmentMean>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_SEGMENT_MEAN_H_

View File

@ -0,0 +1,144 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/segment_prod.h"
#include <algorithm>
#include <set>
#include "ops/op_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "utils/check_convert_utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr SegmentProdInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
auto max_length_ptr = primitive->GetAttr("max_length");
MS_EXCEPTION_IF_NULL(max_length_ptr);
int64_t max_length = GetValue<int64_t>(max_length_ptr);
auto x_shape_ptr = input_args[0]->BuildShape();
MS_EXCEPTION_IF_NULL(x_shape_ptr);
auto segment_ids_shape_ptr = input_args[1]->BuildShape();
MS_EXCEPTION_IF_NULL(segment_ids_shape_ptr);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(x_shape_ptr)[kShape];
auto prim_name = primitive->name();
if (x_shape.size() == 0) {
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the rank of input_x must not be less than 1, but got 0.";
}
auto segment_ids_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(segment_ids_shape_ptr)[kShape];
ShapeVector out_shape(x_shape);
auto segment_ids_ptr = input_args[1]->BuildValue();
MS_EXCEPTION_IF_NULL(segment_ids_ptr);
if (!segment_ids_ptr->isa<AnyValue>() && !segment_ids_ptr->isa<None>()) {
if (segment_ids_shape.size() != 1) {
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', segment_ids must be a 1D tensor, but got "
<< segment_ids_shape.size() << "D tensor";
}
if (segment_ids_shape[0] != x_shape[0]) {
MS_EXCEPTION(ValueError) << "For '" << prim_name
<< "', the amount of data for segment_ids must be equal to the first dimension of the "
"shape of input_x, but got "
<< segment_ids_shape[0] << " and " << x_shape[0] << ".";
}
auto segment_ids_tensor = segment_ids_ptr->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(segment_ids_tensor);
auto data_size = segment_ids_tensor->DataSize();
auto segment_ids_type_id = segment_ids_tensor->data_type();
if (segment_ids_type_id == kNumberTypeInt64) {
int64_t *segment_ids_data = reinterpret_cast<int64_t *>(segment_ids_tensor->data_c());
if (segment_ids_data[0] < 0) {
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the values of segment_ids must be nonnegative. but got "
<< segment_ids_data[0] << ".";
}
for (size_t i = 0; i < data_size - 1; ++i) {
if (segment_ids_data[i] > segment_ids_data[i + 1]) {
MS_EXCEPTION(ValueError) << "For '" << prim_name
<< "', segment_ids must be a tensor with element values sorted in ascending order.";
break;
}
}
out_shape[0] = static_cast<size_t>(segment_ids_data[data_size - 1] + 1);
} else if (segment_ids_type_id == kNumberTypeInt32) {
int32_t *segment_ids_data = reinterpret_cast<int32_t *>(segment_ids_tensor->data_c());
if (segment_ids_data[0] < 0) {
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the values of segment_ids must be nonnegative. but got "
<< segment_ids_data[0] << ".";
}
for (size_t i = 0; i < data_size - 1; ++i) {
if (segment_ids_data[i] > segment_ids_data[i + 1]) {
MS_EXCEPTION(ValueError) << "For '" << prim_name
<< "', segment_ids must be a tensor with element values sorted in ascending order.";
break;
}
}
out_shape[0] = static_cast<size_t>(segment_ids_data[data_size - 1] + 1);
}
uint32_t length = 1;
for (size_t i = 0; i < out_shape.size(); ++i) {
length *= out_shape[i];
}
if (length > max_length) {
MS_EXCEPTION(ValueError) << "For '" << prim_name
<< "', the number of elements of output must be less than max length: " << max_length
<< ", but got " << length
<< "! The shape of output should be reduced or max_length should be increased";
}
return std::make_shared<abstract::Shape>(out_shape);
} else {
uint32_t length = 1;
for (size_t i = 1; i < x_shape.size(); ++i) {
length *= x_shape[i];
}
const uint32_t max_shape_value = static_cast<uint32_t>(max_length) / length;
ShapeVector min_shape(x_shape);
ShapeVector max_shape(x_shape);
out_shape[0] = abstract::Shape::SHP_ANY;
min_shape[0] = 1;
max_shape[0] = max_shape_value;
return std::make_shared<abstract::Shape>(out_shape, min_shape, max_shape);
}
}
// Checks the dtypes of both inputs and returns input_x's dtype as the output dtype.
TypePtr SegmentProdInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  const auto op_name = primitive->name();
  // SegmentProd accepts all real and complex numeric tensors; segment_ids must be integral.
  const std::set<TypePtr> valid_data_types = {kFloat16, kFloat32, kFloat64, kInt8,   kInt16,  kComplex128, kInt32,
                                              kInt64,   kUInt8,   kUInt16,  kUInt32, kUInt64, kComplex64};
  const std::set<TypePtr> valid_ids_types = {kInt32, kInt64};
  auto data_type = input_args[0]->BuildType();
  auto ids_type = input_args[1]->BuildType();
  (void)CheckAndConvertUtils::CheckTensorTypeValid("x_type", data_type, valid_data_types, op_name);
  (void)CheckAndConvertUtils::CheckTensorTypeValid("segment_ids_type", ids_type, valid_ids_types, op_name);
  return data_type;
}
} // namespace
MIND_API_BASE_IMPL(SegmentProd, PrimitiveC, BaseOperator);
// Top-level infer entry for SegmentProd: validates the argument count and
// combines dtype and shape inference into a single abstract value.
AbstractBasePtr SegmentProdInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                 const std::vector<AbstractBasePtr> &input_args) {
  std::for_each(input_args.begin(), input_args.end(), [](const AbstractBasePtr &arg) { MS_EXCEPTION_IF_NULL(arg); });
  constexpr int64_t kSegmentProdInputNum = 2;
  CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kSegmentProdInputNum, primitive->name());
  auto inferred_type = SegmentProdInferType(primitive, input_args);
  auto inferred_shape = SegmentProdInferShape(primitive, input_args);
  return abstract::MakeAbstract(inferred_shape, inferred_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(SegmentProd, prim::kPrimSegmentProd, SegmentProdInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,39 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_SEGMENT_PROD_H_
#define MINDSPORE_CORE_OPS_SEGMENT_PROD_H_
#include <vector>
#include <memory>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameSegmentProd = "SegmentProd";
/// \brief Computes the product along segments of a tensor.
/// Inputs are "input_x" and "segment_ids"; the single output is "output".
class MIND_API SegmentProd : public BaseOperator {
 public:
  /// \brief Constructor. Registers the operator's input/output names.
  SegmentProd() : BaseOperator(kNameSegmentProd) { InitIOName({"input_x", "segment_ids"}, {"output"}); }
  MIND_API_BASE_MEMBER(SegmentProd);
};
abstract::AbstractBasePtr SegmentProdInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using PrimSegmentProdPtr = std::shared_ptr<SegmentProd>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_SEGMENT_PROD_H_

View File

@ -32,6 +32,7 @@ from ..operations.array_ops import SegmentMax
from ..operations.array_ops import SegmentMin
from ..operations.array_ops import SegmentSum
from ..operations.array_ops import Expand
from ..operations.array_ops import SegmentMean
from .. import functional as F
from .. import operations as P
from .._utils.utils import is_shape_unknown
@ -394,3 +395,29 @@ def get_bprop_expand(self):
return dx, dshape
return bprop
@bprop_getters.register(SegmentMean)
def get_bprop_segment_mean(self):
    """Generate bprop for SegmentMean."""
    rank = P.Rank()
    shape = P.Shape()
    fill = P.Fill()
    divide = P.Div()
    segment_sum = SegmentSum()
    gather = P.Gather()
    cast = P.Cast()

    def bprop(input_x, segment_ids, output, dout):
        # Remember the original dtype so the gradient can be cast back at the end;
        # work in float32 so the division below is well-defined for integer inputs.
        input_x_type = F.dtype(input_x)
        input_x = cast(input_x, mstype.float32)
        dout = cast(dout, mstype.float32)
        dout_type = F.dtype(dout)
        # Build a ones tensor of shape (num_ids, 1, ..., 1) matching input rank,
        # then segment-sum it to obtain the element count of each segment.
        input_rank = rank(input_x)
        ones_shape = shape(segment_ids)
        ones_shape = ones_shape + (1,) * (input_rank - 1)
        ones = fill(dout_type, ones_shape, 1)
        # Each segment's output gradient divided by its segment size, gathered
        # back to the rows of input_x. segment_ids is integral, so its grad is zeros.
        scaled_grad = divide(dout, segment_sum(ones, segment_ids))
        return cast(gather(scaled_grad, segment_ids, 0), input_x_type), zeros_like(segment_ids)

    return bprop

View File

@ -206,7 +206,9 @@ from .adjust_saturation import _adjust_saturation_aicpu
from .grid_sampler_2d import _grid_sampler_2d_aicpu
from .grid_sampler_2d_grad import _grid_sampler_2d_grad_aicpu
from .segment_max import _segment_max_aicpu
from .segment_mean import _segment_mean_aicpu
from .segment_min import _segment_min_aicpu
from .segment_prod import _segment_prod_aicpu
from .segment_sum import _segment_sum_aicpu
from .scatter_nd_max import _scatter_nd_max_aicpu
from .conj import _conj_aicpu

View File

@ -0,0 +1,56 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""SegmentMean op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
# AiCPU registration for SegmentMean: enumerates every supported
# (input_x, segment_ids, output) dtype combination. The output dtype always
# matches input_x; segment_ids may be int32 or int64.
segment_mean_op_info = AiCPURegOp("SegmentMean") \
    .fusion_type("OPAQUE") \
    .input(0, "input_x", "required") \
    .input(1, "segment_ids", "required") \
    .output(0, "output", "required") \
    .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default) \
    .dtype_format(DataType.I16_Default, DataType.I32_Default, DataType.I16_Default) \
    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.I64_Default) \
    .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default) \
    .dtype_format(DataType.U16_Default, DataType.I32_Default, DataType.U16_Default) \
    .dtype_format(DataType.U32_Default, DataType.I32_Default, DataType.U32_Default) \
    .dtype_format(DataType.U64_Default, DataType.I32_Default, DataType.U64_Default) \
    .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.F64_Default) \
    .dtype_format(DataType.C64_Default, DataType.I32_Default, DataType.C64_Default) \
    .dtype_format(DataType.C128_Default, DataType.I32_Default, DataType.C128_Default)\
    .dtype_format(DataType.I8_Default, DataType.I64_Default, DataType.I8_Default) \
    .dtype_format(DataType.I16_Default, DataType.I64_Default, DataType.I16_Default) \
    .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default) \
    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
    .dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.U8_Default) \
    .dtype_format(DataType.U16_Default, DataType.I64_Default, DataType.U16_Default) \
    .dtype_format(DataType.U32_Default, DataType.I64_Default, DataType.U32_Default) \
    .dtype_format(DataType.U64_Default, DataType.I64_Default, DataType.U64_Default) \
    .dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \
    .dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.F64_Default) \
    .dtype_format(DataType.C64_Default, DataType.I64_Default, DataType.C64_Default) \
    .dtype_format(DataType.C128_Default, DataType.I64_Default, DataType.C128_Default)\
    .get_op_info()


@op_info_register(segment_mean_op_info)
def _segment_mean_aicpu():
    """SegmentMean AiCPU register"""
    return

View File

@ -0,0 +1,56 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""SegmentProd op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
# AiCPU registration for SegmentProd: enumerates every supported
# (input_x, segment_ids, output) dtype combination. The output dtype always
# matches input_x; segment_ids may be int32 or int64.
segment_prod_op_info = AiCPURegOp("SegmentProd") \
    .fusion_type("OPAQUE") \
    .input(0, "input_x", "required") \
    .input(1, "segment_ids", "required") \
    .output(0, "output", "required") \
    .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default) \
    .dtype_format(DataType.I16_Default, DataType.I32_Default, DataType.I16_Default) \
    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.I64_Default) \
    .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default) \
    .dtype_format(DataType.U16_Default, DataType.I32_Default, DataType.U16_Default) \
    .dtype_format(DataType.U32_Default, DataType.I32_Default, DataType.U32_Default) \
    .dtype_format(DataType.U64_Default, DataType.I32_Default, DataType.U64_Default) \
    .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.F64_Default) \
    .dtype_format(DataType.C64_Default, DataType.I32_Default, DataType.C64_Default) \
    .dtype_format(DataType.C128_Default, DataType.I32_Default, DataType.C128_Default)\
    .dtype_format(DataType.I8_Default, DataType.I64_Default, DataType.I8_Default) \
    .dtype_format(DataType.I16_Default, DataType.I64_Default, DataType.I16_Default) \
    .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default) \
    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
    .dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.U8_Default) \
    .dtype_format(DataType.U16_Default, DataType.I64_Default, DataType.U16_Default) \
    .dtype_format(DataType.U32_Default, DataType.I64_Default, DataType.U32_Default) \
    .dtype_format(DataType.U64_Default, DataType.I64_Default, DataType.U64_Default) \
    .dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \
    .dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.F64_Default) \
    .dtype_format(DataType.C64_Default, DataType.I64_Default, DataType.C64_Default) \
    .dtype_format(DataType.C128_Default, DataType.I64_Default, DataType.C128_Default)\
    .get_op_info()


@op_info_register(segment_prod_op_info)
def _segment_prod_aicpu():
    """SegmentProd AiCPU register"""
    return

View File

@ -7835,3 +7835,111 @@ class FillDiagonal(Primitive):
self.fill_value = fill_value
validator.check_value_type('wrap', wrap, [bool], self.name)
self.init_prim_io_names(inputs=['input_x'], outputs=['y'])
class SegmentMean(Primitive):
    r"""
    Computes the mean along segments of a tensor.

    Computes a tensor such that :math:`output_i = \frac{\sum_j data_j}{N_i}` where the sum is over :math:`j` such
    that :math:`segment\_ids[j] == i` and :math:`N_i` is the number of such :math:`j`. If the mean is empty for a
    given segment ID :math:`i`, :math:`output[i] = 0`.

    .. warning::
        If the dtype of `input_x` is complex number, the gradient can not be calculated.

    Inputs:
        - **input_x** (Tensor) - The input tensor whose dtype is real number or complex number and whose rank is not
          less than 1.
        - **segment_ids** (Tensor) - A 1-D tensor whose dtype is int32 or int64. The size of tensor must be equal to
          the first dimension of the shape of `input_x`. Values must be sorted in ascending order and need not cover
          all values in the full range of valid values, but must be non-negative integers. Only constant values are
          allowed.

    Outputs:
        Tensor, whose dtype and the dimension of the shape is the same as `input_x`. The first dimension of the shape
        is equal to the value of the last element of `segment_ids` plus one, and the other dimensions are the same as
        those of `input_x`.

    Raises:
        TypeError: If `input_x` is not a Tensor.
        TypeError: If `segment_ids` is not a Tensor.
        TypeError: If the dtype of `input_x` is invalid.
        TypeError: If the dtype of `segment_ids` is invalid.
        ValueError: If the rank of `input_x` is less than 1.
        ValueError: If the rank of `segment_ids` is not equal to 1.
        ValueError: If the size of `segment_ids` is not equal to the first dimension of the shape of `input_x`.
        ValueError: If the values of `segment_ids` are negative.
        ValueError: If the values of `segment_ids` are not sorted in ascending order.

    Supported Platforms:
        ``Ascend`` ``CPU``

    Examples:
        >>> x = Tensor([[1, 2, 3], [1, 2, 3], [7, 8, 9]], mstype.float64)
        >>> segment_ids = Tensor([0, 0, 2], mstype.int64)
        >>> op = ops.SegmentMean()
        >>> output = op(x, segment_ids)
        >>> print(output)
        [[1. 2. 3.]
         [0. 0. 0.]
         [7. 8. 9.]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize SegmentMean"""
        # max_length caps the element count of the inferred output shape.
        self.add_prim_attr("max_length", 1000000)
        self.init_prim_io_names(inputs=['input_x', 'segment_ids'], outputs=['output'])
class SegmentProd(Primitive):
    r"""
    Computes the product along segments of a tensor.

    Computes a tensor such that :math:`output_i = \prod_j data_j` where prod is over :math:`j` such that
    :math:`segment\_ids[j] == i`. If the prod is empty for a given segment ID :math:`i`, :math:`output[i] = 1`
    (see the second output row in the example below).

    .. warning::
        If the dtype of `input_x` is complex number, the gradient can not be calculated.

    Inputs:
        - **input_x** (Tensor) - The input tensor whose dtype is real number or complex number and whose rank is not
          less than 1.
        - **segment_ids** (Tensor) - A 1-D tensor whose dtype is int32 or int64. The size of tensor must be equal to
          the first dimension of the shape of `input_x`. Values must be sorted in ascending order and need not cover
          all values in the full range of valid values, but must be non-negative integers. Only constant values are
          allowed.

    Outputs:
        Tensor, whose dtype and the dimension of the shape is the same as `input_x`. The first dimension of the shape
        is equal to the value of the last element of `segment_ids` plus one, and the other dimensions are the same as
        those of `input_x`.

    Raises:
        TypeError: If `input_x` is not a Tensor.
        TypeError: If `segment_ids` is not a Tensor.
        TypeError: If the dtype of `input_x` is invalid.
        TypeError: If the dtype of `segment_ids` is invalid.
        ValueError: If the rank of `input_x` is less than 1.
        ValueError: If the rank of `segment_ids` is not equal to 1.
        ValueError: If the size of `segment_ids` is not equal to the first dimension of the shape of `input_x`.
        ValueError: If the values of `segment_ids` are negative.
        ValueError: If the values of `segment_ids` are not sorted in ascending order.

    Supported Platforms:
        ``Ascend`` ``CPU``

    Examples:
        >>> x = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], mstype.float64)
        >>> segment_ids = Tensor([0, 0, 2], mstype.int64)
        >>> op = ops.SegmentProd()
        >>> output = op(x, segment_ids)
        >>> print(output)
        [[ 4. 10. 18.]
         [ 1.  1.  1.]
         [ 7.  8.  9.]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize SegmentProd"""
        # max_length caps the element count of the inferred output shape.
        self.add_prim_attr("max_length", 1000000)
        self.init_prim_io_names(inputs=['input_x', 'segment_ids'], outputs=['output'])

View File

@ -46,6 +46,8 @@ from mindspore.ops.operations.array_ops import SegmentMax
from mindspore.ops.operations.array_ops import SegmentMin
from mindspore.ops.operations.array_ops import SegmentSum
from mindspore.ops.operations.array_ops import IdentityN
from mindspore.ops.operations.array_ops import SegmentMean
from mindspore.ops.operations.array_ops import SegmentProd
from mindspore.ops.operations.random_ops import NonDeterministicInts
from mindspore.ops.operations.random_ops import TruncatedNormal
from mindspore.ops.operations.other_ops import SampleDistortedBoundingBoxV2
@ -3265,6 +3267,19 @@ test_case_array_ops = [
Tensor([0, 0, 2], mstype.int64)],
'desc_bprop': [Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], mstype.int8)],
}),
('SegmentMean', {
'block': SegmentMean(),
'desc_inputs': [Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], mstype.int8),
Tensor([0, 0, 2], mstype.int64)],
'desc_bprop': [Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], mstype.int8)],
}),
('SegmentProd', {
'block': SegmentProd(),
'desc_inputs': [Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], mstype.int8),
Tensor([0, 0, 2], mstype.int64)],
'desc_bprop': [Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], mstype.int8)],
'skip': ['backward']
}),
]
test_case_image_ops = [