forked from mindspore-Ecosystem/mindspore
!33227 [assistant][ops] Add SegmentMean, SegmentProd
Merge pull request !33227 from 王乐/SegmentMean
commit 5bd1964bf7
@@ -36,6 +36,8 @@ constexpr auto kFractionalAvgPoolGradOpName = "FractionalAvgPoolGrad";
constexpr auto kSegmentMaxOpName = "SegmentMax";
constexpr auto kSegmentMinOpName = "SegmentMin";
constexpr auto kSegmentSumOpName = "SegmentSum";
constexpr auto kSegmentMeanOpName = "SegmentMean";
constexpr auto kSegmentProdOpName = "SegmentProd";
constexpr auto kConcatOpName = "Concat";
constexpr auto kListDiffOpName = "ListDiff";
constexpr auto kUniqueOpName = "Unique";

@@ -847,6 +849,8 @@ const std::set<std::string> kComputeDepend = {kUniqueOpName,
                                              kSegmentMaxOpName,
                                              kSegmentMinOpName,
                                              kSegmentSumOpName,
                                              kSegmentMeanOpName,
                                              kSegmentProdOpName,
                                              kNonZeroOpName};

const std::set<std::string> k3DFormatSet = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D,
@@ -0,0 +1,310 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "plugin/device/cpu/kernel/segment_mean_cpu_kernel.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"

namespace {
const size_t kSegmentsThreshold = 2 * 1024;
const size_t kDataSizeThreshold = 2 * 1024;

#define SEGMENTMEAN_COMPUTE_CASE(DTYPE, TYPE1, TYPE2)   \
  case (DTYPE): {                                       \
    ret = LaunchKernel<TYPE1, TYPE2>(inputs, outputs);  \
    break;                                              \
  }

#define SEGMENTMEAN_COMPUTE_CASE_CP(DTYPE, TYPE1, TYPE2)       \
  case (DTYPE): {                                              \
    ret = LaunchKernelComplex<TYPE1, TYPE2>(inputs, outputs);  \
    break;                                                     \
  }
}  // namespace

namespace mindspore {
namespace kernel {
template <typename T>
T ComplexDiv(T sum, size_t num) {
  if (num != 0) {
    T res;
    auto real = sum.real();
    auto imag = sum.imag();
    res.real(real / num);
    res.imag(imag / num);
    return res;
  } else {
    MS_EXCEPTION(ValueError) << "For 'SegmentMean', divisor cannot be 0.";
  }
}
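// Note: std::complex<T> has no operator/ taking an integral divisor, hence the
// componentwise helper above. For example (hypothetical values, not from this
// patch), ComplexDiv(std::complex<float>(6.0f, -3.0f), 3) returns
// std::complex<float>(2.0f, -1.0f).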

void SegmentMeanCPUKernelMod::InitKernel(const CNodePtr &kernel_node) {
  input_x_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
  segment_ids_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
  output_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
  input_x_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
  segment_ids_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 1);
  output_dtype_ = AnfAlgo::GetOutputDeviceDataType(kernel_node, 0);
  input_x_num_ = CPUKernelUtils::CalcElementNum(input_x_shape_);
  segment_ids_num_ = CPUKernelUtils::CalcElementNum(segment_ids_shape_);
  output_num_ = CPUKernelUtils::CalcElementNum(output_shape_);
}

std::vector<KernelAttr> SegmentMeanCPUKernelMod::GetOpSupport() {
  static std::vector<KernelAttr> support_list = {
    KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
    KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
    KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
    KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
    KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16),
    KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
    KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64),
    KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
    KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
    KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
    KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
    KernelAttr().AddInputAttr(kNumberTypeComplex64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeComplex64),
    KernelAttr()
      .AddInputAttr(kNumberTypeComplex128)
      .AddInputAttr(kNumberTypeInt32)
      .AddOutputAttr(kNumberTypeComplex128),
    KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
    KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
    KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
    KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
    KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16),
    KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
    KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64),
    KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
    KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
    KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
    KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
    KernelAttr().AddInputAttr(kNumberTypeComplex64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeComplex64),
    KernelAttr()
      .AddInputAttr(kNumberTypeComplex128)
      .AddInputAttr(kNumberTypeInt64)
      .AddOutputAttr(kNumberTypeComplex128)};
  return support_list;
}

bool SegmentMeanCPUKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
                                     const std::vector<kernel::AddressPtr> &,
                                     const std::vector<kernel::AddressPtr> &outputs) {
  bool ret = true;
  switch (segment_ids_dtype_) {
    case kNumberTypeInt32: {
      switch (input_x_dtype_) {
        SEGMENTMEAN_COMPUTE_CASE_CP(kNumberTypeComplex64, std::complex<float>, int32_t)
        SEGMENTMEAN_COMPUTE_CASE_CP(kNumberTypeComplex128, std::complex<double>, int32_t)
        SEGMENTMEAN_COMPUTE_CASE(kNumberTypeInt8, int8_t, int32_t)
        SEGMENTMEAN_COMPUTE_CASE(kNumberTypeInt16, int16_t, int32_t)
        SEGMENTMEAN_COMPUTE_CASE(kNumberTypeInt32, int32_t, int32_t)
        SEGMENTMEAN_COMPUTE_CASE(kNumberTypeInt64, int64_t, int32_t)
        SEGMENTMEAN_COMPUTE_CASE(kNumberTypeUInt8, uint8_t, int32_t)
        SEGMENTMEAN_COMPUTE_CASE(kNumberTypeUInt16, uint16_t, int32_t)
        SEGMENTMEAN_COMPUTE_CASE(kNumberTypeUInt32, uint32_t, int32_t)
        SEGMENTMEAN_COMPUTE_CASE(kNumberTypeUInt64, uint64_t, int32_t)
        SEGMENTMEAN_COMPUTE_CASE(kNumberTypeFloat16, float16, int32_t)
        SEGMENTMEAN_COMPUTE_CASE(kNumberTypeFloat32, float, int32_t)
        SEGMENTMEAN_COMPUTE_CASE(kNumberTypeFloat64, double, int32_t)
        default:
          MS_EXCEPTION(TypeError) << "For 'SegmentMean', unsupported input_x data type: " << input_x_dtype_;
      }
      break;
    }
    case kNumberTypeInt64: {
      switch (input_x_dtype_) {
        SEGMENTMEAN_COMPUTE_CASE_CP(kNumberTypeComplex64, std::complex<float>, int64_t)
        SEGMENTMEAN_COMPUTE_CASE_CP(kNumberTypeComplex128, std::complex<double>, int64_t)
        SEGMENTMEAN_COMPUTE_CASE(kNumberTypeInt8, int8_t, int64_t)
        SEGMENTMEAN_COMPUTE_CASE(kNumberTypeInt16, int16_t, int64_t)
        SEGMENTMEAN_COMPUTE_CASE(kNumberTypeInt32, int32_t, int64_t)
        SEGMENTMEAN_COMPUTE_CASE(kNumberTypeInt64, int64_t, int64_t)
        SEGMENTMEAN_COMPUTE_CASE(kNumberTypeUInt8, uint8_t, int64_t)
        SEGMENTMEAN_COMPUTE_CASE(kNumberTypeUInt16, uint16_t, int64_t)
        SEGMENTMEAN_COMPUTE_CASE(kNumberTypeUInt32, uint32_t, int64_t)
        SEGMENTMEAN_COMPUTE_CASE(kNumberTypeUInt64, uint64_t, int64_t)
        SEGMENTMEAN_COMPUTE_CASE(kNumberTypeFloat16, float16, int64_t)
        SEGMENTMEAN_COMPUTE_CASE(kNumberTypeFloat32, float, int64_t)
        SEGMENTMEAN_COMPUTE_CASE(kNumberTypeFloat64, double, int64_t)
        default:
          MS_EXCEPTION(TypeError) << "For 'SegmentMean', unsupported input_x data type: " << input_x_dtype_;
      }
      break;
    }
    default:
      MS_EXCEPTION(TypeError) << "For 'SegmentMean', unsupported segment_ids data type: " << segment_ids_dtype_;
  }
  return ret;
}
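// For reference, each SEGMENTMEAN_COMPUTE_CASE(DTYPE, TYPE1, TYPE2) in the
// switch above expands, per the macro definition, to
//   case (DTYPE): {
//     ret = LaunchKernel<TYPE1, TYPE2>(inputs, outputs);
//     break;
//   }
// so e.g. SEGMENTMEAN_COMPUTE_CASE(kNumberTypeFloat32, float, int32_t)
// dispatches to LaunchKernel<float, int32_t>.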

template <typename T1, typename T2>
bool SegmentMeanCPUKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
                                           const std::vector<kernel::AddressPtr> &outputs) {
  auto input_x_data_addr = reinterpret_cast<T1 *>(inputs[0]->addr);
  auto segment_ids_data_addr = reinterpret_cast<T2 *>(inputs[1]->addr);
  auto output_data_addr = reinterpret_cast<T1 *>(outputs[0]->addr);
  std::vector<int64_t> segments;
  int64_t seg_tmp = 1;
  for (size_t i = 0; i < segment_ids_num_ - 1; ++i) {
    if (segment_ids_data_addr[i] == segment_ids_data_addr[i + 1]) {
      seg_tmp++;
    } else {
      segments.push_back(seg_tmp);
      seg_tmp = 1;
    }
    const size_t last_loc = 2;
    if (i == segment_ids_num_ - last_loc) {
      segments.push_back(seg_tmp);
    }
  }
  if (segment_ids_num_ == 1) {
    segments.push_back(seg_tmp);
  }
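  // At this point `segments` run-length encodes the sorted segment_ids input:
  // e.g. ids {0, 0, 1, 3, 3, 3} yield segments {2, 1, 3} (hypothetical values
  // for illustration). Ids that never occur leave their output rows at the
  // zero initialization below.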
  for (size_t i = 0; i < output_num_; ++i) {
    output_data_addr[i] = static_cast<T1>(0);
  }
  const size_t num_compare_per = input_x_num_ / input_x_shape_[0];
  const size_t num_segments = segments.size();
  if (num_segments < kSegmentsThreshold) {
    for (size_t i = 0; i < num_segments; ++i) {
      const size_t count = segments[i];
      size_t count_no = 0;
      for (size_t j = 0; j < i; ++j) {
        count_no += segments[j];
      }
      size_t input_addr_base = count_no * num_compare_per;
      auto task = [&](size_t start, size_t end) {
        for (size_t j = start; j < end; ++j) {
          size_t mean_init_addr = input_addr_base + j;
          T1 sum_value = input_x_data_addr[mean_init_addr];
          for (size_t k = 1; k < count; ++k) {
            size_t tmp_addr = mean_init_addr + k * num_compare_per;
            sum_value += input_x_data_addr[tmp_addr];
          }
          output_data_addr[segment_ids_data_addr[count_no] * num_compare_per + j] = sum_value / static_cast<T1>(count);
        }
      };
      if (num_compare_per < kDataSizeThreshold) {
        task(0, num_compare_per);
      } else {
        CPUKernelUtils::ParallelFor(task, num_compare_per);
      }
    }
  } else {
    auto task = [&](size_t start, size_t end) {
      for (size_t i = start; i < end; ++i) {
        const size_t count = segments[i];
        size_t count_no = 0;
        for (size_t j = 0; j < i; ++j) {
          count_no += segments[j];
        }
        size_t input_addr_base = count_no * num_compare_per;
        for (size_t j = 0; j < num_compare_per; ++j) {
          size_t mean_init_addr = input_addr_base + j;
          T1 sum_value = input_x_data_addr[mean_init_addr];
          for (size_t k = 1; k < count; ++k) {
            size_t tmp_addr = mean_init_addr + k * num_compare_per;
            sum_value += input_x_data_addr[tmp_addr];
          }
          output_data_addr[segment_ids_data_addr[count_no] * num_compare_per + j] = sum_value / static_cast<T1>(count);
        }
      }
    };
    CPUKernelUtils::ParallelFor(task, num_segments);
  }
  return true;
}

template <typename T1, typename T2>
bool SegmentMeanCPUKernelMod::LaunchKernelComplex(const std::vector<kernel::AddressPtr> &inputs,
                                                  const std::vector<kernel::AddressPtr> &outputs) {
  auto input_x_data_addr = reinterpret_cast<T1 *>(inputs[0]->addr);
  auto segment_ids_data_addr = reinterpret_cast<T2 *>(inputs[1]->addr);
  auto output_data_addr = reinterpret_cast<T1 *>(outputs[0]->addr);
  std::vector<int64_t> segments;
  int64_t seg_tmp = 1;
  for (size_t i = 0; i < segment_ids_num_ - 1; ++i) {
    if (segment_ids_data_addr[i] == segment_ids_data_addr[i + 1]) {
      seg_tmp++;
    } else {
      segments.push_back(seg_tmp);
      seg_tmp = 1;
    }
    const size_t last_loc = 2;
    if (i == segment_ids_num_ - last_loc) {
      segments.push_back(seg_tmp);
    }
  }
  if (segment_ids_num_ == 1) {
    segments.push_back(seg_tmp);
  }
  for (size_t i = 0; i < output_num_; ++i) {
    output_data_addr[i] = static_cast<T1>(0);
  }
  const size_t num_compare_per = input_x_num_ / input_x_shape_[0];
  const size_t num_segments = segments.size();
  if (num_segments < kSegmentsThreshold) {
    for (size_t i = 0; i < num_segments; ++i) {
      const size_t count = segments[i];
      size_t count_no = 0;
      for (size_t j = 0; j < i; ++j) {
        count_no += segments[j];
      }
      size_t input_addr_base = count_no * num_compare_per;
      auto task = [&](size_t start, size_t end) {
        for (size_t j = start; j < end; ++j) {
          size_t mean_init_addr = input_addr_base + j;
          T1 sum_value = input_x_data_addr[mean_init_addr];
          for (size_t k = 1; k < count; ++k) {
            size_t tmp_addr = mean_init_addr + k * num_compare_per;
            sum_value += input_x_data_addr[tmp_addr];
          }
          output_data_addr[segment_ids_data_addr[count_no] * num_compare_per + j] = ComplexDiv(sum_value, count);
        }
      };
      if (num_compare_per < kDataSizeThreshold) {
        task(0, num_compare_per);
      } else {
        CPUKernelUtils::ParallelFor(task, num_compare_per);
      }
    }
  } else {
    auto task = [&](size_t start, size_t end) {
      for (size_t i = start; i < end; ++i) {
        const size_t count = segments[i];
        size_t count_no = 0;
        for (size_t j = 0; j < i; ++j) {
          count_no += segments[j];
        }
        size_t input_addr_base = count_no * num_compare_per;
        for (size_t j = 0; j < num_compare_per; ++j) {
          size_t mean_init_addr = input_addr_base + j;
          T1 sum_value = input_x_data_addr[mean_init_addr];
          for (size_t k = 1; k < count; ++k) {
            size_t tmp_addr = mean_init_addr + k * num_compare_per;
            sum_value += input_x_data_addr[tmp_addr];
          }
          output_data_addr[segment_ids_data_addr[count_no] * num_compare_per + j] = ComplexDiv(sum_value, count);
        }
      }
    };
    CPUKernelUtils::ParallelFor(task, num_segments);
  }
  return true;
}

MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, SegmentMean, SegmentMeanCPUKernelMod);
}  // namespace kernel
}  // namespace mindspore
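For intuition, here is a minimal single-threaded reference of the same row-wise semantics, a hypothetical sketch assuming sorted, nonnegative ids; it is not part of this patch:

#include <cstdint>
#include <vector>

// Naive SegmentMean over rows: out[id] is the mean of rows i with ids[i] == id.
// Rows whose id never occurs stay zero, matching the kernel's initialization.
std::vector<std::vector<float>> NaiveSegmentMean(const std::vector<std::vector<float>> &x,
                                                 const std::vector<int32_t> &ids) {
  size_t cols = x.empty() ? 0 : x[0].size();
  size_t out_rows = ids.empty() ? 0 : static_cast<size_t>(ids.back()) + 1;
  std::vector<std::vector<float>> out(out_rows, std::vector<float>(cols, 0.0f));
  std::vector<size_t> counts(out_rows, 0);
  for (size_t i = 0; i < ids.size(); ++i) {
    for (size_t j = 0; j < cols; ++j) {
      out[ids[i]][j] += x[i][j];
    }
    ++counts[ids[i]];
  }
  for (size_t r = 0; r < out_rows; ++r) {
    for (size_t j = 0; j < cols && counts[r] > 0; ++j) {
      out[r][j] /= counts[r];
    }
  }
  return out;
}
// Example: x = {{1, 2}, {3, 4}, {5, 6}} with ids = {0, 0, 2}
// gives {{2, 3}, {0, 0}, {5, 6}}.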

@@ -0,0 +1,61 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SEGMENT_MEAN_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SEGMENT_MEAN_CPU_KERNEL_H_
#include <vector>
#include <memory>
#include <string>
#include <complex>
#include <utility>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"

namespace mindspore {
namespace kernel {
class SegmentMeanCPUKernelMod : public DeprecatedNativeCpuKernelMod {
 public:
  SegmentMeanCPUKernelMod() = default;
  ~SegmentMeanCPUKernelMod() override = default;

  void InitKernel(const CNodePtr &kernel_node) override;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override;

  template <typename T1, typename T2>
  bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);

  template <typename T1, typename T2>
  bool LaunchKernelComplex(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);

 protected:
  std::vector<KernelAttr> GetOpSupport() override;

 private:
  std::vector<size_t> input_x_shape_;
  std::vector<size_t> segment_ids_shape_;
  std::vector<size_t> output_shape_;
  size_t input_x_num_;
  size_t segment_ids_num_;
  size_t output_num_;
  TypeId input_x_dtype_{kTypeUnknown};
  TypeId segment_ids_dtype_{kTypeUnknown};
  TypeId output_dtype_{kTypeUnknown};
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SEGMENT_MEAN_CPU_KERNEL_H_
@@ -0,0 +1,309 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "plugin/device/cpu/kernel/segment_prod_cpu_kernel.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"

namespace {
const size_t kSegmentsThreshold = 2 * 1024;
const size_t kDataSizeThreshold = 2 * 1024;

template <typename T>
T ComputeProd(const T num_1, const T num_2) {
  T res;
  auto a = num_1.real();
  auto b = num_1.imag();
  auto x = num_2.real();
  auto y = num_2.imag();
  auto real_res = a * x - b * y;
  auto imag_res = b * x + a * y;
  res.real(real_res);
  res.imag(imag_res);
  return res;
}
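// Note: ComputeProd is just complex multiplication written out componentwise,
// (a + b*i) * (x + y*i) = (a*x - b*y) + (b*x + a*y)*i.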

#define SEGMENTPROD_COMPUTE_CASE(DTYPE, TYPE1, TYPE2)   \
  case (DTYPE): {                                       \
    ret = LaunchKernel<TYPE1, TYPE2>(inputs, outputs);  \
    break;                                              \
  }

#define SEGMENTPROD_COMPUTE_CASE_CP(DTYPE, TYPE1, TYPE2)   \
  case (DTYPE): {                                          \
    ret = LaunchKernel_CP<TYPE1, TYPE2>(inputs, outputs);  \
    break;                                                 \
  }
}  // namespace

namespace mindspore {
namespace kernel {
void SegmentProdCPUKernelMod::InitKernel(const CNodePtr &kernel_node) {
  input_x_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
  segment_ids_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
  output_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
  input_x_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
  segment_ids_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 1);
  output_dtype_ = AnfAlgo::GetOutputDeviceDataType(kernel_node, 0);
  input_x_num_ = CPUKernelUtils::CalcElementNum(input_x_shape_);
  segment_ids_num_ = CPUKernelUtils::CalcElementNum(segment_ids_shape_);
  output_num_ = CPUKernelUtils::CalcElementNum(output_shape_);
}

std::vector<KernelAttr> SegmentProdCPUKernelMod::GetOpSupport() {
  static std::vector<KernelAttr> support_list = {
    KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
    KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
    KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
    KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
    KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16),
    KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
    KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64),
    KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
    KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
    KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
    KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
    KernelAttr().AddInputAttr(kNumberTypeComplex64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeComplex64),
    KernelAttr()
      .AddInputAttr(kNumberTypeComplex128)
      .AddInputAttr(kNumberTypeInt32)
      .AddOutputAttr(kNumberTypeComplex128),
    KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
    KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
    KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
    KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
    KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16),
    KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
    KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64),
    KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
    KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
    KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
    KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
    KernelAttr().AddInputAttr(kNumberTypeComplex64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeComplex64),
    KernelAttr()
      .AddInputAttr(kNumberTypeComplex128)
      .AddInputAttr(kNumberTypeInt64)
      .AddOutputAttr(kNumberTypeComplex128)};
  return support_list;
}

bool SegmentProdCPUKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
                                     const std::vector<kernel::AddressPtr> &,
                                     const std::vector<kernel::AddressPtr> &outputs) {
  bool ret = true;
  switch (segment_ids_dtype_) {
    case kNumberTypeInt32: {
      switch (input_x_dtype_) {
        SEGMENTPROD_COMPUTE_CASE_CP(kNumberTypeComplex64, std::complex<float>, int32_t)
        SEGMENTPROD_COMPUTE_CASE_CP(kNumberTypeComplex128, std::complex<double>, int32_t)
        SEGMENTPROD_COMPUTE_CASE(kNumberTypeInt8, int8_t, int32_t)
        SEGMENTPROD_COMPUTE_CASE(kNumberTypeInt16, int16_t, int32_t)
        SEGMENTPROD_COMPUTE_CASE(kNumberTypeInt32, int32_t, int32_t)
        SEGMENTPROD_COMPUTE_CASE(kNumberTypeInt64, int64_t, int32_t)
        SEGMENTPROD_COMPUTE_CASE(kNumberTypeUInt8, uint8_t, int32_t)
        SEGMENTPROD_COMPUTE_CASE(kNumberTypeUInt16, uint16_t, int32_t)
        SEGMENTPROD_COMPUTE_CASE(kNumberTypeUInt32, uint32_t, int32_t)
        SEGMENTPROD_COMPUTE_CASE(kNumberTypeUInt64, uint64_t, int32_t)
        SEGMENTPROD_COMPUTE_CASE(kNumberTypeFloat16, float16, int32_t)
        SEGMENTPROD_COMPUTE_CASE(kNumberTypeFloat32, float, int32_t)
        SEGMENTPROD_COMPUTE_CASE(kNumberTypeFloat64, double, int32_t)
        default:
          MS_EXCEPTION(TypeError) << "For 'SegmentProd', unsupported input_x data type: " << input_x_dtype_;
      }
      break;
    }
    case kNumberTypeInt64: {
      switch (input_x_dtype_) {
        SEGMENTPROD_COMPUTE_CASE_CP(kNumberTypeComplex64, std::complex<float>, int64_t)
        SEGMENTPROD_COMPUTE_CASE_CP(kNumberTypeComplex128, std::complex<double>, int64_t)
        SEGMENTPROD_COMPUTE_CASE(kNumberTypeInt8, int8_t, int64_t)
        SEGMENTPROD_COMPUTE_CASE(kNumberTypeInt16, int16_t, int64_t)
        SEGMENTPROD_COMPUTE_CASE(kNumberTypeInt32, int32_t, int64_t)
        SEGMENTPROD_COMPUTE_CASE(kNumberTypeInt64, int64_t, int64_t)
        SEGMENTPROD_COMPUTE_CASE(kNumberTypeUInt8, uint8_t, int64_t)
        SEGMENTPROD_COMPUTE_CASE(kNumberTypeUInt16, uint16_t, int64_t)
        SEGMENTPROD_COMPUTE_CASE(kNumberTypeUInt32, uint32_t, int64_t)
        SEGMENTPROD_COMPUTE_CASE(kNumberTypeUInt64, uint64_t, int64_t)
        SEGMENTPROD_COMPUTE_CASE(kNumberTypeFloat16, float16, int64_t)
        SEGMENTPROD_COMPUTE_CASE(kNumberTypeFloat32, float, int64_t)
        SEGMENTPROD_COMPUTE_CASE(kNumberTypeFloat64, double, int64_t)
        default:
          MS_EXCEPTION(TypeError) << "For 'SegmentProd', unsupported input_x data type: " << input_x_dtype_;
      }
      break;
    }
    default:
      MS_EXCEPTION(TypeError) << "For 'SegmentProd', unsupported segment_ids data type: " << segment_ids_dtype_;
  }
  return ret;
}

template <typename T1, typename T2>
bool SegmentProdCPUKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
                                           const std::vector<kernel::AddressPtr> &outputs) {
  auto input_x_data_addr = reinterpret_cast<T1 *>(inputs[0]->addr);
  auto segment_ids_data_addr = reinterpret_cast<T2 *>(inputs[1]->addr);
  auto output_data_addr = reinterpret_cast<T1 *>(outputs[0]->addr);
  std::vector<int64_t> segments;
  int64_t seg_tmp = 1;
  for (size_t i = 0; i < segment_ids_num_ - 1; ++i) {
    if (segment_ids_data_addr[i] == segment_ids_data_addr[i + 1]) {
      seg_tmp++;
    } else {
      segments.push_back(seg_tmp);
      seg_tmp = 1;
    }
    const size_t last_loc = 2;
    if (i == segment_ids_num_ - last_loc) {
      segments.push_back(seg_tmp);
    }
  }
  if (segment_ids_num_ == 1) {
    segments.push_back(seg_tmp);
  }
  for (size_t i = 0; i < output_num_; ++i) {
    output_data_addr[i] = static_cast<T1>(1);
  }
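  // The output is initialized to 1 (the multiplicative identity) rather than
  // 0, so segment ids that never occur produce rows of ones.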
  const size_t num_compare_per = input_x_num_ / input_x_shape_[0];
  const size_t num_segments = segments.size();
  if (num_segments < kSegmentsThreshold) {
    for (size_t i = 0; i < num_segments; ++i) {
      const size_t count = segments[i];
      size_t count_no = 0;
      for (size_t j = 0; j < i; ++j) {
        count_no += segments[j];
      }
      size_t input_addr_base = count_no * num_compare_per;
      auto task = [&](size_t start, size_t end) {
        for (size_t j = start; j < end; ++j) {
          size_t prod_init_addr = input_addr_base + j;
          T1 prod_value = input_x_data_addr[prod_init_addr];
          for (size_t k = 1; k < count; ++k) {
            size_t tmp_addr = prod_init_addr + k * num_compare_per;
            prod_value *= input_x_data_addr[tmp_addr];
          }
          output_data_addr[segment_ids_data_addr[count_no] * num_compare_per + j] = prod_value;
        }
      };
      if (num_compare_per < kDataSizeThreshold) {
        task(0, num_compare_per);
      } else {
        CPUKernelUtils::ParallelFor(task, num_compare_per);
      }
    }
  } else {
    auto task = [&](size_t start, size_t end) {
      for (size_t i = start; i < end; ++i) {
        const size_t count = segments[i];
        size_t count_no = 0;
        for (size_t j = 0; j < i; ++j) {
          count_no += segments[j];
        }
        size_t input_addr_base = count_no * num_compare_per;
        for (size_t j = 0; j < num_compare_per; ++j) {
          size_t prod_init_addr = input_addr_base + j;
          T1 prod_value = input_x_data_addr[prod_init_addr];
          for (size_t k = 1; k < count; ++k) {
            size_t tmp_addr = prod_init_addr + k * num_compare_per;
            prod_value *= input_x_data_addr[tmp_addr];
          }
          output_data_addr[segment_ids_data_addr[count_no] * num_compare_per + j] = prod_value;
        }
      }
    };
    CPUKernelUtils::ParallelFor(task, num_segments);
  }
  return true;
}

template <typename T1, typename T2>
bool SegmentProdCPUKernelMod::LaunchKernel_CP(const std::vector<kernel::AddressPtr> &inputs,
                                              const std::vector<kernel::AddressPtr> &outputs) {
  auto input_x_data_addr = reinterpret_cast<T1 *>(inputs[0]->addr);
  auto segment_ids_data_addr = reinterpret_cast<T2 *>(inputs[1]->addr);
  auto output_data_addr = reinterpret_cast<T1 *>(outputs[0]->addr);
  std::vector<int64_t> segments;
  int64_t seg_tmp = 1;
  for (size_t i = 0; i < segment_ids_num_ - 1; ++i) {
    if (segment_ids_data_addr[i] == segment_ids_data_addr[i + 1]) {
      seg_tmp++;
    } else {
      segments.push_back(seg_tmp);
      seg_tmp = 1;
    }
    const size_t last_loc = 2;
    if (i == segment_ids_num_ - last_loc) {
      segments.push_back(seg_tmp);
    }
  }
  if (segment_ids_num_ == 1) {
    segments.push_back(seg_tmp);
  }
  for (size_t i = 0; i < output_num_; ++i) {
    output_data_addr[i] = static_cast<T1>(1);
  }
  const size_t num_compare_per = input_x_num_ / input_x_shape_[0];
  const size_t num_segments = segments.size();
  if (num_segments < kSegmentsThreshold) {
    for (size_t i = 0; i < num_segments; ++i) {
      const size_t count = segments[i];
      size_t count_no = 0;
      for (size_t j = 0; j < i; ++j) {
        count_no += segments[j];
      }
      size_t input_addr_base = count_no * num_compare_per;
      auto task = [&](size_t start, size_t end) {
        for (size_t j = start; j < end; ++j) {
          size_t prod_init_addr = input_addr_base + j;
          T1 prod_value = input_x_data_addr[prod_init_addr];
          for (size_t k = 1; k < count; ++k) {
            size_t tmp_addr = prod_init_addr + k * num_compare_per;
            prod_value = ComputeProd(prod_value, input_x_data_addr[tmp_addr]);
          }
          output_data_addr[segment_ids_data_addr[count_no] * num_compare_per + j] = prod_value;
        }
      };
      if (num_compare_per < kDataSizeThreshold) {
        task(0, num_compare_per);
      } else {
        CPUKernelUtils::ParallelFor(task, num_compare_per);
      }
    }
  } else {
    auto task = [&](size_t start, size_t end) {
      for (size_t i = start; i < end; ++i) {
        const size_t count = segments[i];
        size_t count_no = 0;
        for (size_t j = 0; j < i; ++j) {
          count_no += segments[j];
        }
        size_t input_addr_base = count_no * num_compare_per;
        for (size_t j = 0; j < num_compare_per; ++j) {
          size_t prod_init_addr = input_addr_base + j;
          T1 prod_value = input_x_data_addr[prod_init_addr];
          for (size_t k = 1; k < count; ++k) {
            size_t tmp_addr = prod_init_addr + k * num_compare_per;
            prod_value = ComputeProd(prod_value, input_x_data_addr[tmp_addr]);
          }
          output_data_addr[segment_ids_data_addr[count_no] * num_compare_per + j] = prod_value;
        }
      }
    };
    CPUKernelUtils::ParallelFor(task, num_segments);
  }
  return true;
}

MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, SegmentProd, SegmentProdCPUKernelMod);
}  // namespace kernel
}  // namespace mindspore
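Worked example for the product kernel (hypothetical values, not from this patch): input_x = [[1, 2], [3, 4], [5, 6]] with segment_ids = [0, 0, 2] produces [[3, 8], [1, 1], [5, 6]]; the row for the absent segment id 1 keeps its all-ones initialization.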

@@ -0,0 +1,61 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SEGMENT_PROD_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SEGMENT_PROD_CPU_KERNEL_H_
#include <vector>
#include <memory>
#include <string>
#include <complex>
#include <utility>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"

namespace mindspore {
namespace kernel {
class SegmentProdCPUKernelMod : public DeprecatedNativeCpuKernelMod {
 public:
  SegmentProdCPUKernelMod() = default;
  ~SegmentProdCPUKernelMod() override = default;

  void InitKernel(const CNodePtr &kernel_node) override;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override;

  template <typename T1, typename T2>
  bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);

  template <typename T1, typename T2>
  bool LaunchKernel_CP(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);

 protected:
  std::vector<KernelAttr> GetOpSupport() override;

 private:
  std::vector<size_t> input_x_shape_;
  std::vector<size_t> segment_ids_shape_;
  std::vector<size_t> output_shape_;
  size_t input_x_num_;
  size_t segment_ids_num_;
  size_t output_num_;
  TypeId input_x_dtype_{kTypeUnknown};
  TypeId segment_ids_dtype_{kTypeUnknown};
  TypeId output_dtype_{kTypeUnknown};
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SEGMENT_PROD_CPU_KERNEL_H_
@@ -93,10 +93,14 @@ PrimShapeDependMap &GetHostDependsMap() {
  static const auto &kBlackmanWindow = prim::kPrimBlackmanWindow->name();
  static const auto &kExpand = prim::kPrimExpand->name();
  static const auto &kSspaddmm = prim::kPrimSspaddmm->name();
  static const auto &kSegmentMean = prim::kPrimSegmentMean->name();
  static const auto &kSegmentProd = prim::kPrimSegmentProd->name();
  // Common host depends.
  static PrimShapeDependMap host_depends{{kSegmentMax, ShapeSet{1}},
                                         {kSegmentMin, ShapeSet{1}},
                                         {kSegmentSum, ShapeSet{1}},
                                         {kSegmentMean, ShapeSet{1}},
                                         {kSegmentProd, ShapeSet{1}},
                                         {kUnsortedSegmentSum, ShapeSet{2}},
                                         {kFractionalAvgPoolGrad, ShapeSet{0}},
                                         {kUnsortedSegmentMin, ShapeSet{2}},
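(ShapeSet{1} marks input index 1, the segment_ids tensor, as a host-value dependency: shape inference reads its last element to size the output's first dimension, which is also why the new ops are listed in kComputeDepend above.)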

@@ -113,6 +113,8 @@ constexpr auto kFillDiagonal = "FillDiagonal";
constexpr auto kSegmentMax = "SegmentMax";
constexpr auto kSegmentSum = "SegmentSum";
constexpr auto kSegmentMin = "SegmentMin";
constexpr auto kSegmentMean = "SegmentMean";
constexpr auto kSegmentProd = "SegmentProd";
constexpr auto kDynamicShape = "DynamicShape";
constexpr auto kTensorShape = "TensorShape";
constexpr auto kCheckNumerics = "CheckNumerics";

@@ -481,6 +483,8 @@ GVAR_DEF(PrimitivePtr, kPrimMeshgrid, std::make_shared<Primitive>(kMeshgrid));
GVAR_DEF(PrimitivePtr, kPrimSegmentMax, std::make_shared<Primitive>(kSegmentMax));
GVAR_DEF(PrimitivePtr, kPrimSegmentMin, std::make_shared<Primitive>(kSegmentMin));
GVAR_DEF(PrimitivePtr, kPrimSegmentSum, std::make_shared<Primitive>(kSegmentSum));
GVAR_DEF(PrimitivePtr, kPrimSegmentMean, std::make_shared<Primitive>(kSegmentMean));
GVAR_DEF(PrimitivePtr, kPrimSegmentProd, std::make_shared<Primitive>(kSegmentProd));

// image
GVAR_DEF(PrimitivePtr, kPrimCropAndResizeGradBoxes, std::make_shared<Primitive>(kCropAndResizeGradBoxes));
@@ -0,0 +1,144 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ops/segment_mean.h"
#include <algorithm>
#include <set>
#include "ops/op_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "utils/check_convert_utils.h"
#include "mindapi/src/helper.h"

namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr SegmentMeanInferShape(const PrimitivePtr &primitive,
                                         const std::vector<AbstractBasePtr> &input_args) {
  auto max_length_ptr = primitive->GetAttr("max_length");
  MS_EXCEPTION_IF_NULL(max_length_ptr);
  int64_t max_length = GetValue<int64_t>(max_length_ptr);
  auto x_shape_ptr = input_args[0]->BuildShape();
  MS_EXCEPTION_IF_NULL(x_shape_ptr);
  auto segment_ids_shape_ptr = input_args[1]->BuildShape();
  MS_EXCEPTION_IF_NULL(segment_ids_shape_ptr);
  auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(x_shape_ptr)[kShape];
  auto prim_name = primitive->name();
  if (x_shape.size() == 0) {
    MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the rank of input_x must not be less than 1, but got 0.";
  }
  auto segment_ids_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(segment_ids_shape_ptr)[kShape];
  ShapeVector out_shape(x_shape);
  auto segment_ids_ptr = input_args[1]->BuildValue();
  MS_EXCEPTION_IF_NULL(segment_ids_ptr);
  if (!segment_ids_ptr->isa<AnyValue>() && !segment_ids_ptr->isa<None>()) {
    if (segment_ids_shape.size() != 1) {
      MS_EXCEPTION(ValueError) << "For '" << prim_name << "', segment_ids must be a 1D tensor, but got a "
                               << segment_ids_shape.size() << "D tensor.";
    }
    if (segment_ids_shape[0] != x_shape[0]) {
      MS_EXCEPTION(ValueError) << "For '" << prim_name
                               << "', the amount of data for segment_ids must be equal to the first dimension of the "
                                  "shape of input_x, but got "
                               << segment_ids_shape[0] << " and " << x_shape[0] << ".";
    }
    auto segment_ids_tensor = segment_ids_ptr->cast<tensor::TensorPtr>();
    MS_EXCEPTION_IF_NULL(segment_ids_tensor);
    auto data_size = segment_ids_tensor->DataSize();
    auto segment_ids_type_id = segment_ids_tensor->data_type();
    if (segment_ids_type_id == kNumberTypeInt64) {
      int64_t *segment_ids_data = reinterpret_cast<int64_t *>(segment_ids_tensor->data_c());
      if (segment_ids_data[0] < 0) {
        MS_EXCEPTION(ValueError) << "For '" << prim_name
                                 << "', the values of segment_ids must be nonnegative, but got "
                                 << segment_ids_data[0] << ".";
      }
      for (size_t i = 0; i < data_size - 1; ++i) {
        if (segment_ids_data[i] > segment_ids_data[i + 1]) {
          MS_EXCEPTION(ValueError) << "For '" << prim_name
                                   << "', segment_ids must be a tensor with element values sorted in ascending order.";
        }
      }
      out_shape[0] = static_cast<int64_t>(segment_ids_data[data_size - 1] + 1);
    } else if (segment_ids_type_id == kNumberTypeInt32) {
      int32_t *segment_ids_data = reinterpret_cast<int32_t *>(segment_ids_tensor->data_c());
      if (segment_ids_data[0] < 0) {
        MS_EXCEPTION(ValueError) << "For '" << prim_name
                                 << "', the values of segment_ids must be nonnegative, but got "
                                 << segment_ids_data[0] << ".";
      }
      for (size_t i = 0; i < data_size - 1; ++i) {
        if (segment_ids_data[i] > segment_ids_data[i + 1]) {
          MS_EXCEPTION(ValueError) << "For '" << prim_name
                                   << "', segment_ids must be a tensor with element values sorted in ascending order.";
        }
      }
      out_shape[0] = static_cast<int64_t>(segment_ids_data[data_size - 1] + 1);
    }
    int64_t length = 1;
    for (size_t i = 0; i < out_shape.size(); ++i) {
      length *= out_shape[i];
    }
    if (length > max_length) {
      MS_EXCEPTION(ValueError) << "For '" << prim_name
                               << "', the number of elements of output must be less than max length: " << max_length
                               << ", but got " << length
                               << ". The shape of output should be reduced or max_length should be increased.";
    }
    return std::make_shared<abstract::Shape>(out_shape);
  } else {
    int64_t length = 1;
    for (size_t i = 1; i < x_shape.size(); ++i) {
      length *= x_shape[i];
    }
    const int64_t max_shape_value = max_length / length;
    ShapeVector min_shape(x_shape);
    ShapeVector max_shape(x_shape);
    out_shape[0] = abstract::Shape::SHP_ANY;
    min_shape[0] = 1;
    max_shape[0] = max_shape_value;
    return std::make_shared<abstract::Shape>(out_shape, min_shape, max_shape);
  }
}
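// Shape inference example (hypothetical values): segment_ids = [0, 0, 3] gives
// out_shape[0] = segment_ids.back() + 1 = 4; when the value of segment_ids is
// not yet known, the first dimension is left dynamic and bounded via max_length.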

TypePtr SegmentMeanInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  auto prim_name = primitive->name();
  TypePtr x_type = input_args[0]->BuildType();
  TypePtr segment_ids_type = input_args[1]->BuildType();
  const std::set<TypePtr> x_valid_types = {kFloat16, kFloat32, kFloat64, kInt8,   kInt16,  kComplex128, kInt32,
                                           kInt64,   kUInt8,   kUInt16,  kUInt32, kUInt64, kComplex64};
  const std::set<TypePtr> segment_ids_valid_types = {kInt32, kInt64};
  (void)CheckAndConvertUtils::CheckTensorTypeValid("x_type", x_type, x_valid_types, prim_name);
  (void)CheckAndConvertUtils::CheckTensorTypeValid("segment_ids_type", segment_ids_type, segment_ids_valid_types,
                                                   prim_name);
  return x_type;
}
}  // namespace

MIND_API_BASE_IMPL(SegmentMean, PrimitiveC, BaseOperator);
AbstractBasePtr SegmentMeanInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                 const std::vector<AbstractBasePtr> &input_args) {
  for (const auto &item : input_args) {
    MS_EXCEPTION_IF_NULL(item);
  }
  const int64_t input_num = 2;
  CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
  auto type = SegmentMeanInferType(primitive, input_args);
  auto shape = SegmentMeanInferShape(primitive, input_args);
  return abstract::MakeAbstract(shape, type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(SegmentMean, prim::kPrimSegmentMean, SegmentMeanInfer, nullptr, true);
}  // namespace ops
}  // namespace mindspore
@@ -0,0 +1,39 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CORE_OPS_SEGMENT_MEAN_H_
#define MINDSPORE_CORE_OPS_SEGMENT_MEAN_H_
#include <vector>
#include <memory>

#include "ops/base_operator.h"
#include "mindapi/base/types.h"

namespace mindspore {
namespace ops {
constexpr auto kNameSegmentMean = "SegmentMean";
class MIND_API SegmentMean : public BaseOperator {
 public:
  SegmentMean() : BaseOperator(kNameSegmentMean) { InitIOName({"input_x", "segment_ids"}, {"output"}); }
  MIND_API_BASE_MEMBER(SegmentMean);
};
abstract::AbstractBasePtr SegmentMeanInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                           const std::vector<abstract::AbstractBasePtr> &input_args);
using PrimSegmentMeanPtr = std::shared_ptr<SegmentMean>;
}  // namespace ops
}  // namespace mindspore

#endif  // MINDSPORE_CORE_OPS_SEGMENT_MEAN_H_
@@ -0,0 +1,144 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ops/segment_prod.h"
#include <algorithm>
#include <set>
#include "ops/op_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "utils/check_convert_utils.h"
#include "mindapi/src/helper.h"

namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr SegmentProdInferShape(const PrimitivePtr &primitive,
                                         const std::vector<AbstractBasePtr> &input_args) {
  auto max_length_ptr = primitive->GetAttr("max_length");
  MS_EXCEPTION_IF_NULL(max_length_ptr);
  int64_t max_length = GetValue<int64_t>(max_length_ptr);
  auto x_shape_ptr = input_args[0]->BuildShape();
  MS_EXCEPTION_IF_NULL(x_shape_ptr);
  auto segment_ids_shape_ptr = input_args[1]->BuildShape();
  MS_EXCEPTION_IF_NULL(segment_ids_shape_ptr);
  auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(x_shape_ptr)[kShape];
  auto prim_name = primitive->name();
  if (x_shape.size() == 0) {
    MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the rank of input_x must not be less than 1, but got 0.";
  }
  auto segment_ids_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(segment_ids_shape_ptr)[kShape];
  ShapeVector out_shape(x_shape);
  auto segment_ids_ptr = input_args[1]->BuildValue();
  MS_EXCEPTION_IF_NULL(segment_ids_ptr);
  if (!segment_ids_ptr->isa<AnyValue>() && !segment_ids_ptr->isa<None>()) {
    if (segment_ids_shape.size() != 1) {
      MS_EXCEPTION(ValueError) << "For '" << prim_name << "', segment_ids must be a 1D tensor, but got a "
                               << segment_ids_shape.size() << "D tensor.";
    }
    if (segment_ids_shape[0] != x_shape[0]) {
      MS_EXCEPTION(ValueError) << "For '" << prim_name
                               << "', the amount of data for segment_ids must be equal to the first dimension of the "
                                  "shape of input_x, but got "
                               << segment_ids_shape[0] << " and " << x_shape[0] << ".";
    }
    auto segment_ids_tensor = segment_ids_ptr->cast<tensor::TensorPtr>();
    MS_EXCEPTION_IF_NULL(segment_ids_tensor);
    auto data_size = segment_ids_tensor->DataSize();
    auto segment_ids_type_id = segment_ids_tensor->data_type();
    if (segment_ids_type_id == kNumberTypeInt64) {
      int64_t *segment_ids_data = reinterpret_cast<int64_t *>(segment_ids_tensor->data_c());
      if (segment_ids_data[0] < 0) {
        MS_EXCEPTION(ValueError) << "For '" << prim_name
                                 << "', the values of segment_ids must be nonnegative, but got "
                                 << segment_ids_data[0] << ".";
      }
      for (size_t i = 0; i < data_size - 1; ++i) {
        if (segment_ids_data[i] > segment_ids_data[i + 1]) {
          MS_EXCEPTION(ValueError) << "For '" << prim_name
                                   << "', segment_ids must be a tensor with element values sorted in ascending order.";
        }
      }
      out_shape[0] = static_cast<int64_t>(segment_ids_data[data_size - 1] + 1);
    } else if (segment_ids_type_id == kNumberTypeInt32) {
      int32_t *segment_ids_data = reinterpret_cast<int32_t *>(segment_ids_tensor->data_c());
      if (segment_ids_data[0] < 0) {
        MS_EXCEPTION(ValueError) << "For '" << prim_name
                                 << "', the values of segment_ids must be nonnegative, but got "
                                 << segment_ids_data[0] << ".";
      }
      for (size_t i = 0; i < data_size - 1; ++i) {
        if (segment_ids_data[i] > segment_ids_data[i + 1]) {
          MS_EXCEPTION(ValueError) << "For '" << prim_name
                                   << "', segment_ids must be a tensor with element values sorted in ascending order.";
        }
      }
      out_shape[0] = static_cast<int64_t>(segment_ids_data[data_size - 1] + 1);
    }
    int64_t length = 1;
    for (size_t i = 0; i < out_shape.size(); ++i) {
      length *= out_shape[i];
    }
    if (length > max_length) {
      MS_EXCEPTION(ValueError) << "For '" << prim_name
                               << "', the number of elements of output must be less than max length: " << max_length
                               << ", but got " << length
                               << ". The shape of output should be reduced or max_length should be increased.";
    }
    return std::make_shared<abstract::Shape>(out_shape);
  } else {
    int64_t length = 1;
    for (size_t i = 1; i < x_shape.size(); ++i) {
      length *= x_shape[i];
    }
    const int64_t max_shape_value = max_length / length;
    ShapeVector min_shape(x_shape);
    ShapeVector max_shape(x_shape);
    out_shape[0] = abstract::Shape::SHP_ANY;
    min_shape[0] = 1;
    max_shape[0] = max_shape_value;
    return std::make_shared<abstract::Shape>(out_shape, min_shape, max_shape);
  }
}

TypePtr SegmentProdInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  auto prim_name = primitive->name();
  TypePtr x_type = input_args[0]->BuildType();
  TypePtr segment_ids_type = input_args[1]->BuildType();
  const std::set<TypePtr> x_valid_types = {kFloat16, kFloat32, kFloat64, kInt8,   kInt16,  kComplex128, kInt32,
                                           kInt64,   kUInt8,   kUInt16,  kUInt32, kUInt64, kComplex64};
  const std::set<TypePtr> segment_ids_valid_types = {kInt32, kInt64};
  (void)CheckAndConvertUtils::CheckTensorTypeValid("x_type", x_type, x_valid_types, prim_name);
  (void)CheckAndConvertUtils::CheckTensorTypeValid("segment_ids_type", segment_ids_type, segment_ids_valid_types,
                                                   prim_name);
  return x_type;
}
}  // namespace

MIND_API_BASE_IMPL(SegmentProd, PrimitiveC, BaseOperator);
AbstractBasePtr SegmentProdInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                 const std::vector<AbstractBasePtr> &input_args) {
  for (const auto &item : input_args) {
    MS_EXCEPTION_IF_NULL(item);
  }
  const int64_t input_num = 2;
  CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
  auto type = SegmentProdInferType(primitive, input_args);
  auto shape = SegmentProdInferShape(primitive, input_args);
  return abstract::MakeAbstract(shape, type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(SegmentProd, prim::kPrimSegmentProd, SegmentProdInfer, nullptr, true);
}  // namespace ops
}  // namespace mindspore
@@ -0,0 +1,39 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CORE_OPS_SEGMENT_PROD_H_
#define MINDSPORE_CORE_OPS_SEGMENT_PROD_H_
#include <vector>
#include <memory>

#include "ops/base_operator.h"
#include "mindapi/base/types.h"

namespace mindspore {
namespace ops {
constexpr auto kNameSegmentProd = "SegmentProd";
class MIND_API SegmentProd : public BaseOperator {
 public:
  SegmentProd() : BaseOperator(kNameSegmentProd) { InitIOName({"input_x", "segment_ids"}, {"output"}); }
  MIND_API_BASE_MEMBER(SegmentProd);
};
abstract::AbstractBasePtr SegmentProdInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                           const std::vector<abstract::AbstractBasePtr> &input_args);
using PrimSegmentProdPtr = std::shared_ptr<SegmentProd>;
}  // namespace ops
}  // namespace mindspore

#endif  // MINDSPORE_CORE_OPS_SEGMENT_PROD_H_
@@ -32,6 +32,7 @@ from ..operations.array_ops import SegmentMax
from ..operations.array_ops import SegmentMin
from ..operations.array_ops import SegmentSum
from ..operations.array_ops import Expand
from ..operations.array_ops import SegmentMean
from .. import functional as F
from .. import operations as P
from .._utils.utils import is_shape_unknown
@@ -394,3 +395,29 @@ def get_bprop_expand(self):
        return dx, dshape

    return bprop


@bprop_getters.register(SegmentMean)
def get_bprop_segment_mean(self):
    """Generate bprop for SegmentMean"""
    rank = P.Rank()
    shape = P.Shape()
    fill = P.Fill()
    divide = P.Div()
    segment_sum = SegmentSum()
    gather = P.Gather()
    cast = P.Cast()

    def bprop(input_x, segment_ids, output, dout):
        input_x_type = F.dtype(input_x)
        input_x = cast(input_x, mstype.float32)
        dout = cast(dout, mstype.float32)
        dout_type = F.dtype(dout)
        input_rank = rank(input_x)
        ones_shape = shape(segment_ids)
        ones_shape = ones_shape + (1,) * (input_rank - 1)
        ones = fill(dout_type, ones_shape, 1)
        scaled_grad = divide(dout, segment_sum(ones, segment_ids))
        return cast(gather(scaled_grad, segment_ids, 0), input_x_type), zeros_like(segment_ids)

    return bprop
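The scaling in this bprop follows from linearity of the mean: for $o_j = \frac{1}{n_j} \sum_{i:\, s_i = j} x_i$ with counts $n_j = |\{i : s_i = j\}|$, each partial derivative is $\partial o_j / \partial x_i = 1/n_j$, so the input gradient is $dx_i = dout_{s_i} / n_{s_i}$. Here segment_sum(ones, segment_ids) computes the counts $n_j$ and gather maps the scaled gradient back onto the input rows.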

@@ -206,7 +206,9 @@ from .adjust_saturation import _adjust_saturation_aicpu
from .grid_sampler_2d import _grid_sampler_2d_aicpu
from .grid_sampler_2d_grad import _grid_sampler_2d_grad_aicpu
from .segment_max import _segment_max_aicpu
from .segment_mean import _segment_mean_aicpu
from .segment_min import _segment_min_aicpu
from .segment_prod import _segment_prod_aicpu
from .segment_sum import _segment_sum_aicpu
from .scatter_nd_max import _scatter_nd_max_aicpu
from .conj import _conj_aicpu
@@ -0,0 +1,56 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""SegmentMean op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType

segment_mean_op_info = AiCPURegOp("SegmentMean") \
    .fusion_type("OPAQUE") \
    .input(0, "input_x", "required") \
    .input(1, "segment_ids", "required") \
    .output(0, "output", "required") \
    .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default) \
    .dtype_format(DataType.I16_Default, DataType.I32_Default, DataType.I16_Default) \
    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.I64_Default) \
    .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default) \
    .dtype_format(DataType.U16_Default, DataType.I32_Default, DataType.U16_Default) \
    .dtype_format(DataType.U32_Default, DataType.I32_Default, DataType.U32_Default) \
    .dtype_format(DataType.U64_Default, DataType.I32_Default, DataType.U64_Default) \
    .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.F64_Default) \
    .dtype_format(DataType.C64_Default, DataType.I32_Default, DataType.C64_Default) \
    .dtype_format(DataType.C128_Default, DataType.I32_Default, DataType.C128_Default) \
    .dtype_format(DataType.I8_Default, DataType.I64_Default, DataType.I8_Default) \
    .dtype_format(DataType.I16_Default, DataType.I64_Default, DataType.I16_Default) \
    .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default) \
    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
    .dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.U8_Default) \
    .dtype_format(DataType.U16_Default, DataType.I64_Default, DataType.U16_Default) \
    .dtype_format(DataType.U32_Default, DataType.I64_Default, DataType.U32_Default) \
    .dtype_format(DataType.U64_Default, DataType.I64_Default, DataType.U64_Default) \
    .dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \
    .dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.F64_Default) \
    .dtype_format(DataType.C64_Default, DataType.I64_Default, DataType.C64_Default) \
    .dtype_format(DataType.C128_Default, DataType.I64_Default, DataType.C128_Default) \
    .get_op_info()


@op_info_register(segment_mean_op_info)
def _segment_mean_aicpu():
    """SegmentMean AiCPU register"""
    return
@@ -0,0 +1,56 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""SegmentProd op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType

segment_prod_op_info = AiCPURegOp("SegmentProd") \
    .fusion_type("OPAQUE") \
    .input(0, "input_x", "required") \
    .input(1, "segment_ids", "required") \
    .output(0, "output", "required") \
    .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default) \
    .dtype_format(DataType.I16_Default, DataType.I32_Default, DataType.I16_Default) \
    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.I64_Default) \
    .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default) \
    .dtype_format(DataType.U16_Default, DataType.I32_Default, DataType.U16_Default) \
    .dtype_format(DataType.U32_Default, DataType.I32_Default, DataType.U32_Default) \
    .dtype_format(DataType.U64_Default, DataType.I32_Default, DataType.U64_Default) \
    .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.F64_Default) \
    .dtype_format(DataType.C64_Default, DataType.I32_Default, DataType.C64_Default) \
    .dtype_format(DataType.C128_Default, DataType.I32_Default, DataType.C128_Default) \
    .dtype_format(DataType.I8_Default, DataType.I64_Default, DataType.I8_Default) \
    .dtype_format(DataType.I16_Default, DataType.I64_Default, DataType.I16_Default) \
    .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default) \
    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
    .dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.U8_Default) \
    .dtype_format(DataType.U16_Default, DataType.I64_Default, DataType.U16_Default) \
    .dtype_format(DataType.U32_Default, DataType.I64_Default, DataType.U32_Default) \
    .dtype_format(DataType.U64_Default, DataType.I64_Default, DataType.U64_Default) \
    .dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \
    .dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.F64_Default) \
    .dtype_format(DataType.C64_Default, DataType.I64_Default, DataType.C64_Default) \
    .dtype_format(DataType.C128_Default, DataType.I64_Default, DataType.C128_Default) \
    .get_op_info()


@op_info_register(segment_prod_op_info)
def _segment_prod_aicpu():
    """SegmentProd AiCPU register"""
    return
@@ -7835,3 +7835,111 @@ class FillDiagonal(Primitive):
        self.fill_value = fill_value
        validator.check_value_type('wrap', wrap, [bool], self.name)
        self.init_prim_io_names(inputs=['input_x'], outputs=['y'])


class SegmentMean(Primitive):
    r"""
    Computes the mean along segments of a tensor.

    Computes a tensor such that :math:`output_i = \frac{\sum_j data_j}{N_i}`, where the sum is over all :math:`j`
    such that :math:`segment\_ids[j] == i` and :math:`N_i` is the number of such elements. If a given segment
    ID :math:`i` has no elements, :math:`output[i] = 0`.

    .. warning::
        If the dtype of `input_x` is a complex number, the gradient cannot be calculated.

    Inputs:
        - **input_x** (Tensor) - The input tensor whose dtype is real number or complex number and whose rank is
          not less than 1.
        - **segment_ids** (Tensor) - A 1-D tensor whose dtype is int32 or int64. The size of the tensor must be
          equal to the first dimension of the shape of `input_x`. Values must be non-negative integers sorted in
          ascending order, and need not cover all values in the full range of valid values. Only constant values
          are allowed.

    Outputs:
        Tensor whose dtype and rank are the same as those of `input_x`. The first dimension of the shape is equal
        to the value of the last element of `segment_ids` plus one, and the other dimensions are the same as
        those of `input_x`.

    Raises:
        TypeError: If `input_x` is not a Tensor.
        TypeError: If `segment_ids` is not a Tensor.
        TypeError: If the dtype of `input_x` is invalid.
        TypeError: If the dtype of `segment_ids` is invalid.
        ValueError: If the rank of `input_x` is less than 1.
        ValueError: If the rank of `segment_ids` is not equal to 1.
        ValueError: If the size of `segment_ids` is not equal to the first dimension of the shape of `input_x`.
        ValueError: If the values of `segment_ids` are negative.
        ValueError: If the values of `segment_ids` are not sorted in ascending order.

    Supported Platforms:
        ``Ascend`` ``CPU``

    Examples:
        >>> x = Tensor([[1, 2, 3], [1, 2, 3], [7, 8, 9]], mstype.float64)
        >>> segment_ids = Tensor([0, 0, 2], mstype.int64)
        >>> op = ops.SegmentMean()
        >>> output = op(x, segment_ids)
        >>> print(output)
        [[1. 2. 3.]
         [0. 0. 0.]
         [7. 8. 9.]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize SegmentMean"""
        self.add_prim_attr("max_length", 1000000)
        self.init_prim_io_names(inputs=['input_x', 'segment_ids'], outputs=['output'])
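As a cross-check of the documented semantics, a NumPy sketch (illustrative only; the helper is hypothetical and not part of the patch):

import numpy as np

def segment_mean_np(data, segment_ids):
    # Empty segments yield 0, per the docstring above.
    num_segments = segment_ids[-1] + 1
    out = np.zeros((num_segments,) + data.shape[1:])
    for i in range(num_segments):
        rows = data[segment_ids == i]
        if rows.size:
            out[i] = rows.mean(axis=0)
    return out

print(segment_mean_np(np.array([[1., 2., 3.], [1., 2., 3.], [7., 8., 9.]]),
                      np.array([0, 0, 2])))
# [[1. 2. 3.]
#  [0. 0. 0.]
#  [7. 8. 9.]]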
class SegmentProd(Primitive):
    r"""
    Computes the product along segments of a tensor.

    Computes a tensor such that :math:`output_i = \prod_j data_j`, where the product is over all :math:`j` such
    that :math:`segment\_ids[j] == i`. If the product is empty for a given segment ID :math:`i`,
    :math:`output[i] = 1` (the empty product).

    .. warning::
        If the dtype of `input_x` is a complex number, the gradient cannot be calculated.

    Inputs:
        - **input_x** (Tensor) - The input tensor whose dtype is real number or complex number and whose rank is
          not less than 1.
        - **segment_ids** (Tensor) - A 1-D tensor whose dtype is int32 or int64. The size of the tensor must be
          equal to the first dimension of the shape of `input_x`. Values must be non-negative integers sorted in
          ascending order, and need not cover all values in the full range of valid values. Only constant values
          are allowed.

    Outputs:
        Tensor whose dtype and rank are the same as those of `input_x`. The first dimension of the shape is equal
        to the value of the last element of `segment_ids` plus one, and the other dimensions are the same as
        those of `input_x`.

    Raises:
        TypeError: If `input_x` is not a Tensor.
        TypeError: If `segment_ids` is not a Tensor.
        TypeError: If the dtype of `input_x` is invalid.
        TypeError: If the dtype of `segment_ids` is invalid.
        ValueError: If the rank of `input_x` is less than 1.
        ValueError: If the rank of `segment_ids` is not equal to 1.
        ValueError: If the size of `segment_ids` is not equal to the first dimension of the shape of `input_x`.
        ValueError: If the values of `segment_ids` are negative.
        ValueError: If the values of `segment_ids` are not sorted in ascending order.

    Supported Platforms:
        ``Ascend`` ``CPU``

    Examples:
        >>> x = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], mstype.float64)
        >>> segment_ids = Tensor([0, 0, 2], mstype.int64)
        >>> op = ops.SegmentProd()
        >>> output = op(x, segment_ids)
        >>> print(output)
        [[ 4. 10. 18.]
         [ 1.  1.  1.]
         [ 7.  8.  9.]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize SegmentProd"""
        self.add_prim_attr("max_length", 1000000)
        self.init_prim_io_names(inputs=['input_x', 'segment_ids'], outputs=['output'])
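The same style of cross-check for SegmentProd; note that an empty segment yields 1, the empty product (sketch only, not part of the patch):

import numpy as np

def segment_prod_np(data, segment_ids):
    num_segments = segment_ids[-1] + 1
    out = np.ones((num_segments,) + data.shape[1:])
    for i in range(num_segments):
        # prod over an empty slice is 1, matching the documented behavior.
        out[i] = data[segment_ids == i].prod(axis=0)
    return out

print(segment_prod_np(np.array([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]),
                      np.array([0, 0, 2])))
# [[ 4. 10. 18.]
#  [ 1.  1.  1.]
#  [ 7.  8.  9.]]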
@@ -46,6 +46,8 @@ from mindspore.ops.operations.array_ops import SegmentMax
from mindspore.ops.operations.array_ops import SegmentMin
from mindspore.ops.operations.array_ops import SegmentSum
from mindspore.ops.operations.array_ops import IdentityN
from mindspore.ops.operations.array_ops import SegmentMean
from mindspore.ops.operations.array_ops import SegmentProd
from mindspore.ops.operations.random_ops import NonDeterministicInts
from mindspore.ops.operations.random_ops import TruncatedNormal
from mindspore.ops.operations.other_ops import SampleDistortedBoundingBoxV2
@@ -3265,6 +3267,19 @@ test_case_array_ops = [
                        Tensor([0, 0, 2], mstype.int64)],
        'desc_bprop': [Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], mstype.int8)],
    }),
    ('SegmentMean', {
        'block': SegmentMean(),
        'desc_inputs': [Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], mstype.int8),
                        Tensor([0, 0, 2], mstype.int64)],
        'desc_bprop': [Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], mstype.int8)],
    }),
    ('SegmentProd', {
        'block': SegmentProd(),
        'desc_inputs': [Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], mstype.int8),
                        Tensor([0, 0, 2], mstype.int64)],
        'desc_bprop': [Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], mstype.int8)],
        'skip': ['backward']
    }),
]

test_case_image_ops = [