!33233 [feat][assistant][I48O6V,I48O66,I48O72]add SparseSegmentSqrtN, SparseSegmentSqrtNGrad, SparseSegmentSqrtNWithNumSegments
Merge pull request !33233 from 桂宁馨/SparseSegmentSqrtN
This commit is contained in:
commit
0d50b090b2
|
@ -0,0 +1,180 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/cpu/kernel/sparse_segment_sqrt_n_cpu_kernel.h"
|
||||
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
namespace {
|
||||
constexpr size_t kSparseSegmentSqrtNInputsNum = 3;
|
||||
constexpr size_t kSparseSegmentSqrtNOutputsNum = 1;
|
||||
|
||||
#define ADD_KERNEL(t1, t2, t3, t4) \
|
||||
KernelAttr() \
|
||||
.AddInputAttr(kNumberType##t1) \
|
||||
.AddInputAttr(kNumberType##t2) \
|
||||
.AddInputAttr(kNumberType##t3) \
|
||||
.AddOutputAttr(kNumberType##t4)
|
||||
} // namespace
|
||||
|
||||
void SparseSegmentSqrtNCpuKernelMod::CheckParam(const CNodePtr &kernel_node) {
|
||||
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
CHECK_KERNEL_INPUTS_NUM(input_num, kSparseSegmentSqrtNInputsNum, kernel_name_);
|
||||
size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(output_num, kSparseSegmentSqrtNOutputsNum, kernel_name_);
|
||||
}
|
||||
|
||||
void SparseSegmentSqrtNCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kIndex0);
|
||||
dtype1_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kIndex1);
|
||||
dtype2_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kIndex2);
|
||||
x_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kIndex0);
|
||||
indices_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kIndex1);
|
||||
segment_ids_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kIndex2);
|
||||
y_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, kIndex0);
|
||||
}
|
||||
|
||||
bool SparseSegmentSqrtNCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
                                            const std::vector<kernel::AddressPtr> &workspace,
                                            const std::vector<kernel::AddressPtr> &outputs) {
  // Dispatch to the LaunchKernel instantiation matching the cached dtypes:
  //   dtype_  -> T1 (value type of x / y)
  //   dtype1_ -> T2 (indices type; any non-int32 dtype is treated as int64,
  //              which matches the combinations registered in GetOpSupport)
  //   dtype2_ -> T3 (segment_ids type, same convention)
  const bool index_is_int32 = (dtype1_ == kNumberTypeInt32);
  const bool segment_is_int32 = (dtype2_ == kNumberTypeInt32);
  switch (dtype_) {
    case kNumberTypeFloat16:
      if (index_is_int32) {
        segment_is_int32 ? LaunchKernel<float16, int32_t, int32_t>(inputs, outputs)
                         : LaunchKernel<float16, int32_t, int64_t>(inputs, outputs);
      } else {
        segment_is_int32 ? LaunchKernel<float16, int64_t, int32_t>(inputs, outputs)
                         : LaunchKernel<float16, int64_t, int64_t>(inputs, outputs);
      }
      break;
    case kNumberTypeFloat32:
      if (index_is_int32) {
        segment_is_int32 ? LaunchKernel<float, int32_t, int32_t>(inputs, outputs)
                         : LaunchKernel<float, int32_t, int64_t>(inputs, outputs);
      } else {
        segment_is_int32 ? LaunchKernel<float, int64_t, int32_t>(inputs, outputs)
                         : LaunchKernel<float, int64_t, int64_t>(inputs, outputs);
      }
      break;
    case kNumberTypeFloat64:
      if (index_is_int32) {
        segment_is_int32 ? LaunchKernel<double, int32_t, int32_t>(inputs, outputs)
                         : LaunchKernel<double, int32_t, int64_t>(inputs, outputs);
      } else {
        segment_is_int32 ? LaunchKernel<double, int64_t, int32_t>(inputs, outputs)
                         : LaunchKernel<double, int64_t, int64_t>(inputs, outputs);
      }
      break;
    default:
      MS_EXCEPTION(TypeError) << "For '" << kernel_name_ << "', data type of x is " << TypeIdLabel(dtype_)
                              << " which is not supported.";
  }
  return true;
}
|
||||
|
||||
// Computes SparseSegmentSqrtN: for each segment, sums the rows of x selected
// by 'indices' and divides the sum by sqrt(number of rows in the segment).
// T1: value type of x / y; T2: type of 'indices'; T3: type of 'segment_ids'.
template <typename T1, typename T2, typename T3>
void SparseSegmentSqrtNCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
                                                  const std::vector<kernel::AddressPtr> &outputs) {
  // n: elements per row of x (product of all dims except dim 0).
  size_t n = std::accumulate(x_shape_.begin(), x_shape_.end(), kIndex1, std::multiplies<int>()) / x_shape_[kIndex0];
  // m: number of indices / segment ids to process.
  size_t m = std::accumulate(segment_ids_shape_.begin(), segment_ids_shape_.end(), kIndex1, std::multiplies<int>());
  // k: total number of output elements.
  size_t k = std::accumulate(y_shape_.begin(), y_shape_.end(), kIndex1, std::multiplies<int>());
  auto x_shape_0 = static_cast<T2>(x_shape_[kIndex0]);
  auto x_addr = reinterpret_cast<T1 *>(inputs[kIndex0]->addr);
  auto indices_addr = reinterpret_cast<T2 *>(inputs[kIndex1]->addr);
  auto segment_ids_addr = reinterpret_cast<T3 *>(inputs[kIndex2]->addr);
  auto y_addr = reinterpret_cast<T1 *>(outputs[kIndex0]->addr);

  // Zero-initialize the whole output buffer.
  for (size_t i = 0; i < k; i++) {
    y_addr[i] = (T1)0;
  }
  // NOTE(review): m == 0 would read segment_ids_addr[0] out of bounds here —
  // presumably upstream shape inference forbids empty segment_ids; confirm.
  if (segment_ids_addr[0] != 0) {
    MS_EXCEPTION(ValueError) << "For '" << kernel_name_
                             << "', indices in 'segment_ids' should be contiguous and start from 0.";
  }
  // segment_ids must be non-decreasing and must not skip any id.
  for (size_t i = 1; i < m; i++) {
    if (segment_ids_addr[i] < segment_ids_addr[i - 1]) {
      MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', segment_ids should be sorted.";
    }
    if (segment_ids_addr[i] - segment_ids_addr[i - 1] > 1) {
      MS_EXCEPTION(ValueError) << "For '" << kernel_name_
                               << "', indices in 'segment_ids' should be contiguous and start from 0.";
    }
  }
  // Every index must select an existing row of x.
  for (size_t i = 0; i < m; i++) {
    if (indices_addr[i] >= x_shape_0) {
      MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', indices out of range of x's first shape.";
    }
  }

  // Single pass: accumulate rows into the current segment's output row; when a
  // new segment starts, finish the previous one by dividing by sqrt(count).
  int oldindex = -1;  // id of the segment currently being accumulated (-1 = none yet)
  int countnum = 0;   // number of rows accumulated into the current segment
  for (size_t i = 0; i < m; i++) {
    if (oldindex == segment_ids_addr[i]) {
      countnum++;
    } else {
      // Segment boundary: normalize the completed segment, if any.
      if (countnum != 0) {
        for (size_t j = 0; j < n; j++) {
          y_addr[j + oldindex * n] /= (T1)(sqrt(countnum));
        }
      }
      countnum = 1;
      oldindex = segment_ids_addr[i];
      // Re-zero the new segment's output row before accumulating into it.
      for (size_t j = 0; j < n; j++) {
        y_addr[j + oldindex * n] = (T1)0;
      }
    }
    // Add the selected row of x into the current segment's output row.
    for (size_t j = 0; j < n; j++) {
      y_addr[j + oldindex * n] += x_addr[j + indices_addr[i] * n];
    }
  }
  // Flush the final segment, which has no boundary following it.
  if (countnum != 0) {
    for (size_t j = 0; j < n; j++) {
      y_addr[j + oldindex * n] /= (T1)(sqrt(countnum));
    }
  }
}
|
||||
|
||||
std::vector<KernelAttr> SparseSegmentSqrtNCpuKernelMod::GetOpSupport() {
|
||||
static std::vector<KernelAttr> kernel_attr_list = {
|
||||
ADD_KERNEL(Float16, Int32, Int32, Float16), ADD_KERNEL(Float16, Int32, Int64, Float16),
|
||||
ADD_KERNEL(Float16, Int64, Int32, Float16), ADD_KERNEL(Float16, Int64, Int64, Float16),
|
||||
ADD_KERNEL(Float32, Int32, Int32, Float32), ADD_KERNEL(Float32, Int32, Int64, Float32),
|
||||
ADD_KERNEL(Float32, Int64, Int32, Float32), ADD_KERNEL(Float32, Int64, Int64, Float16),
|
||||
ADD_KERNEL(Float64, Int32, Int32, Float64), ADD_KERNEL(Float64, Int32, Int64, Float64),
|
||||
ADD_KERNEL(Float64, Int64, Int32, Float64), ADD_KERNEL(Float64, Int64, Int64, Float64)};
|
||||
|
||||
return kernel_attr_list;
|
||||
}
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, SparseSegmentSqrtN, SparseSegmentSqrtNCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,60 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_SEGMENT_SQRT_N_CPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_SEGMENT_SQRT_N_CPU_KERNEL_H_
|
||||
|
||||
#include <functional>
|
||||
#include <numeric>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
// CPU kernel for the SparseSegmentSqrtN operator: sums rows of x selected by
// 'indices' per segment and divides each segment sum by sqrt(segment size).
class SparseSegmentSqrtNCpuKernelMod : public DeprecatedNativeCpuKernelMod {
 public:
  SparseSegmentSqrtNCpuKernelMod() = default;
  ~SparseSegmentSqrtNCpuKernelMod() override = default;

  // Caches input/output dtypes and device shapes from the kernel node.
  void InitKernel(const CNodePtr &kernel_node) override;
  // Dispatches to the LaunchKernel instantiation matching the cached dtypes.
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override;

  // T1: value type of x / y; T2: type of indices; T3: type of segment_ids.
  template <typename T1, typename T2, typename T3>
  void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);

 protected:
  // Lists the supported (x, indices, segment_ids) -> y dtype combinations.
  std::vector<KernelAttr> GetOpSupport() override;

 private:
  // Validates the kernel node's input/output counts.
  void CheckParam(const CNodePtr &kernel_node);
  ShapeVector x_shape_;            // shape of input x
  ShapeVector indices_shape_;      // shape of input indices
  ShapeVector segment_ids_shape_;  // shape of input segment_ids
  ShapeVector y_shape_;            // shape of output y
  TypeId dtype_{kTypeUnknown};     // dtype of x
  TypeId dtype1_{kTypeUnknown};    // dtype of indices
  TypeId dtype2_{kTypeUnknown};    // dtype of segment_ids
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_SEGMENT_SQRT_N_CPU_KERNEL_H_
|
|
@ -0,0 +1,131 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/cpu/kernel/sparse_segment_sqrt_n_grad_cpu_kernel.h"
|
||||
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
namespace {
|
||||
constexpr size_t kSparseSegmentSqrtNGradInputsNum = 4;
|
||||
constexpr size_t kSparseSegmentSqrtNGradOutputsNum = 1;
|
||||
|
||||
#define ADD_KERNEL(t1, t2, t3, t4, t5) \
|
||||
KernelAttr() \
|
||||
.AddInputAttr(kNumberType##t1) \
|
||||
.AddInputAttr(kNumberType##t2) \
|
||||
.AddInputAttr(kNumberType##t3) \
|
||||
.AddInputAttr(kNumberType##t4) \
|
||||
.AddOutputAttr(kNumberType##t5)
|
||||
} // namespace
|
||||
|
||||
void SparseSegmentSqrtNGradCpuKernelMod::CheckParam(const CNodePtr &kernel_node) {
|
||||
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
CHECK_KERNEL_INPUTS_NUM(input_num, kSparseSegmentSqrtNGradInputsNum, kernel_name_);
|
||||
size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(output_num, kSparseSegmentSqrtNGradOutputsNum, kernel_name_);
|
||||
}
|
||||
|
||||
void SparseSegmentSqrtNGradCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
x_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kIndex0);
|
||||
x_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kIndex0);
|
||||
indices_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kIndex1);
|
||||
segment_ids_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kIndex2);
|
||||
output_dim0_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kIndex3);
|
||||
y_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, kIndex0);
|
||||
}
|
||||
|
||||
bool SparseSegmentSqrtNGradCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
                                                const std::vector<kernel::AddressPtr> &workspace,
                                                const std::vector<kernel::AddressPtr> &outputs) {
  // Dispatch on the value dtype of x; index inputs are fixed to int32 by the
  // kernel attrs registered in GetOpSupport.
  switch (x_dtype_) {
    case kNumberTypeFloat16:
      LaunchKernel<float16>(inputs, outputs);
      break;
    case kNumberTypeFloat32:
      LaunchKernel<float>(inputs, outputs);
      break;
    case kNumberTypeFloat64:
      LaunchKernel<double>(inputs, outputs);
      break;
    default:
      MS_EXCEPTION(TypeError) << "For '" << kernel_name_ << "', data type of x is " << TypeIdLabel(x_dtype_)
                              << " which is not supported.";
  }
  return true;
}
|
||||
|
||||
// Backward pass of SparseSegmentSqrtN: scatters each segment's incoming
// gradient row (divided by sqrt(segment size)) back to the rows of the output
// named by 'indices'. T is the value type of the gradients.
template <typename T>
void SparseSegmentSqrtNGradCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
                                                      const std::vector<kernel::AddressPtr> &outputs) {
  // n: elements per row (product of all dims of x except dim 0).
  size_t n = std::accumulate(x_shape_.begin(), x_shape_.end(), kIndex1, std::multiplies<int>()) / x_shape_[kIndex0];
  // m: number of indices / segment ids.
  size_t m = std::accumulate(segment_ids_shape_.begin(), segment_ids_shape_.end(), kIndex1, std::multiplies<int>());
  // num_elements: total number of output elements.
  size_t num_elements = std::accumulate(y_shape_.begin(), y_shape_.end(), kIndex1, std::multiplies<int>());
  // k: scalar output_dim0 — upper bound for valid segment ids.
  int32_t k = *reinterpret_cast<int32_t *>(inputs[kIndex3]->addr);
  auto x_shape_0 = static_cast<int32_t>(x_shape_[kIndex0]);
  // x here is the incoming gradient, indexed by segment id.
  auto x_addr = reinterpret_cast<T *>(inputs[kIndex0]->addr);
  // Index inputs are int32-only, per the kernel attrs in GetOpSupport.
  auto indices_addr = reinterpret_cast<int32_t *>(inputs[kIndex1]->addr);
  auto segment_ids_addr = reinterpret_cast<int32_t *>(inputs[kIndex2]->addr);
  auto y_addr = reinterpret_cast<T *>(outputs[kIndex0]->addr);

  // Zero-initialize the whole output buffer.
  for (size_t i = 0; i < num_elements; i++) {
    y_addr[i] = (T)0;
  }
  // segment_ids must be non-decreasing.
  for (size_t i = 1; i < m; i++) {
    if (segment_ids_addr[i] < segment_ids_addr[i - 1]) {
      MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', segment_ids should be sorted.";
    }
  }
  // Range-check both index inputs.
  for (size_t i = 0; i < m; i++) {
    if (indices_addr[i] >= x_shape_0) {
      MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', indices out of range of x's first shape.";
    }
    if (segment_ids_addr[i] >= k) {
      MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', segment_ids out of range of output_dim0.";
    }
  }
  // NOTE(review): m == 0 would read segment_ids_addr[0] out of bounds below —
  // presumably upstream shape inference forbids empty segment_ids; confirm.
  int beginindex = segment_ids_addr[0];  // segment id currently being processed
  size_t countnum = 1;                   // number of rows in the current segment
  // When a segment ends, walk back over its 'countnum' positions and add the
  // segment's gradient row (scaled by 1/sqrt(countnum)) to each indexed output row.
  for (size_t i = 1; i < m; i++) {
    if (segment_ids_addr[i] != beginindex) {
      for (size_t j = 1; j <= countnum; j++) {
        for (size_t l = 0; l < n; l++) {
          y_addr[indices_addr[i - j] * n + l] += x_addr[beginindex * n + l] / (T)(sqrt(countnum));
        }
      }
      beginindex = segment_ids_addr[i];
      countnum = 1;
    } else {
      countnum++;
    }
  }

  // Flush the final segment; i == m so i - j walks back from the last element.
  int i = m;
  for (size_t j = 1; j <= countnum; j++) {
    for (size_t l = 0; l < n; l++) {
      y_addr[indices_addr[i - j] * n + l] += x_addr[beginindex * n + l] / (T)(sqrt(countnum));
    }
  }
}
|
||||
|
||||
std::vector<KernelAttr> SparseSegmentSqrtNGradCpuKernelMod::GetOpSupport() {
  // Supported signatures: (grad x, indices, segment_ids, output_dim0) -> y.
  // Index inputs are int32-only, matching the int32_t casts in LaunchKernel.
  static std::vector<KernelAttr> kernel_attr_list = {ADD_KERNEL(Float16, Int32, Int32, Int32, Float16),
                                                     ADD_KERNEL(Float32, Int32, Int32, Int32, Float32),
                                                     ADD_KERNEL(Float64, Int32, Int32, Int32, Float64)};

  return kernel_attr_list;
}
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, SparseSegmentSqrtNGrad, SparseSegmentSqrtNGradCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,59 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_SEGMENT_SQRT_N_GRAD_CPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_SEGMENT_SQRT_N_GRAD_CPU_KERNEL_H_
|
||||
|
||||
#include <functional>
|
||||
#include <numeric>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
// CPU kernel for the SparseSegmentSqrtNGrad operator: scatters per-segment
// gradient rows (scaled by 1/sqrt(segment size)) back to the indexed rows.
class SparseSegmentSqrtNGradCpuKernelMod : public DeprecatedNativeCpuKernelMod {
 public:
  SparseSegmentSqrtNGradCpuKernelMod() = default;
  ~SparseSegmentSqrtNGradCpuKernelMod() override = default;

  // Caches the x dtype and the device shapes from the kernel node.
  void InitKernel(const CNodePtr &kernel_node) override;
  // Dispatches to LaunchKernel<T> based on the cached x dtype.
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override;

  // T: value type of the incoming and outgoing gradients.
  template <typename T>
  void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);

 protected:
  // Lists the supported dtype combinations (index inputs are int32-only).
  std::vector<KernelAttr> GetOpSupport() override;

 private:
  // Validates the kernel node's input/output counts.
  void CheckParam(const CNodePtr &kernel_node);
  ShapeVector x_shape_;            // shape of the incoming gradient x
  ShapeVector indices_shape_;      // shape of input indices
  ShapeVector segment_ids_shape_;  // shape of input segment_ids
  ShapeVector output_dim0_shape_;  // shape of the scalar output_dim0 input
  ShapeVector y_shape_;            // shape of output y
  TypeId x_dtype_{kTypeUnknown};   // dtype of x
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_SEGMENT_SQRT_N_GRAD_CPU_KERNEL_H_
|
|
@ -0,0 +1,171 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/cpu/kernel/sparse_segment_sqrt_n_with_num_segments_cpu_kernel.h"
|
||||
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
namespace {
|
||||
constexpr size_t kSparseSegmentSqrtNWithNumSegmentsInputsNum = 4;
|
||||
constexpr size_t kSparseSegmentSqrtNWithNumSegmentsOutputsNum = 1;
|
||||
|
||||
#define ADD_KERNEL(t1, t2, t3, t4, t5) \
|
||||
KernelAttr() \
|
||||
.AddInputAttr(kNumberType##t1) \
|
||||
.AddInputAttr(kNumberType##t2) \
|
||||
.AddInputAttr(kNumberType##t3) \
|
||||
.AddInputAttr(kNumberType##t4) \
|
||||
.AddOutputAttr(kNumberType##t5)
|
||||
} // namespace
|
||||
|
||||
void SparseSegmentSqrtNWithNumSegmentsCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
  // Input/output count checks are performed inline here (this kernel has no
  // separate CheckParam helper).
  size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
  CHECK_KERNEL_INPUTS_NUM(input_num, kSparseSegmentSqrtNWithNumSegmentsInputsNum, kernel_name_);
  size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
  CHECK_KERNEL_OUTPUTS_NUM(output_num, kSparseSegmentSqrtNWithNumSegmentsOutputsNum, kernel_name_);
  // xdtype_ is the value dtype of x; dtype1_ is the indices dtype — LaunchKernel
  // casts segment_ids and num_segments with the same T2, so all three integer
  // inputs are assumed to share it (enforced by the attrs in GetOpSupport).
  xdtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kIndex0);
  dtype1_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kIndex1);
  // Cache device shapes for the four inputs and the single output.
  x_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kIndex0);
  indices_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kIndex1);
  segment_ids_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kIndex2);
  num_segments_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kIndex3);
  y_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, kIndex0);
}
|
||||
|
||||
bool SparseSegmentSqrtNWithNumSegmentsCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
                                                           const std::vector<kernel::AddressPtr> &,
                                                           const std::vector<kernel::AddressPtr> &outputs) {
  // Reject unsupported dtypes up front (x first, then the integer inputs —
  // the same order the original nested dispatch reported them in), then
  // dispatch to the LaunchKernel<T1, T2> instantiation.
  if (xdtype_ != kNumberTypeFloat16 && xdtype_ != kNumberTypeFloat32 && xdtype_ != kNumberTypeFloat64) {
    MS_EXCEPTION(TypeError) << "For '" << kernel_name_ << "', data type of x is " << TypeIdLabel(xdtype_)
                            << ", which is not supported.";
  }
  if (dtype1_ != kNumberTypeInt32 && dtype1_ != kNumberTypeInt64) {
    MS_EXCEPTION(TypeError) << "For '" << kernel_name_
                            << "', data type of indices, segment_ids and num_segments is " << TypeIdLabel(dtype1_)
                            << ", which is not supported.";
  }
  if (dtype1_ == kNumberTypeInt32) {
    if (xdtype_ == kNumberTypeFloat16) {
      LaunchKernel<float16, int32_t>(inputs, outputs);
    } else if (xdtype_ == kNumberTypeFloat32) {
      LaunchKernel<float, int32_t>(inputs, outputs);
    } else {
      LaunchKernel<double, int32_t>(inputs, outputs);
    }
  } else {
    if (xdtype_ == kNumberTypeFloat16) {
      LaunchKernel<float16, int64_t>(inputs, outputs);
    } else if (xdtype_ == kNumberTypeFloat32) {
      LaunchKernel<float, int64_t>(inputs, outputs);
    } else {
      LaunchKernel<double, int64_t>(inputs, outputs);
    }
  }
  return true;
}
|
||||
|
||||
// Computes SparseSegmentSqrtN with an explicit num_segments: sums the rows of
// x selected by 'indices' per segment and divides each sum by sqrt(segment
// size). T1: value type of x / y; T2: shared type of indices, segment_ids and
// num_segments. Unlike the plain SparseSegmentSqrtN kernel, segment ids are
// not required to start at 0 or be contiguous — gaps simply stay zero.
template <typename T1, typename T2>
void SparseSegmentSqrtNWithNumSegmentsCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
                                                                 const std::vector<kernel::AddressPtr> &outputs) {
  // n: elements per row of x (product of all dims except dim 0).
  size_t n = std::accumulate(x_shape_.begin(), x_shape_.end(), kIndex1, std::multiplies<int>()) / x_shape_[kIndex0];
  // m: number of indices / segment ids.
  size_t m = std::accumulate(segment_ids_shape_.begin(), segment_ids_shape_.end(), kIndex1, std::multiplies<int>());
  // k: total number of output elements.
  size_t k = std::accumulate(y_shape_.begin(), y_shape_.end(), kIndex1, std::multiplies<int>());
  auto x_shape_0 = static_cast<T2>(x_shape_[kIndex0]);
  auto x_addr = reinterpret_cast<T1 *>(inputs[kIndex0]->addr);
  auto indices_addr = reinterpret_cast<T2 *>(inputs[kIndex1]->addr);
  auto segment_ids_addr = reinterpret_cast<T2 *>(inputs[kIndex2]->addr);
  auto num_segments_addr = reinterpret_cast<T2 *>(inputs[kIndex3]->addr);
  auto y_addr = reinterpret_cast<T1 *>(outputs[kIndex0]->addr);

  // Zero-initialize the whole output buffer.
  for (size_t i = 0; i < k; i++) {
    y_addr[i] = (T1)0;
  }
  // segment_ids must be non-decreasing.
  for (size_t i = 1; i < m; i++) {
    if (segment_ids_addr[i] < segment_ids_addr[i - 1]) {
      MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', segment_ids should be sorted.";
    }
  }
  // The largest (last) segment id must fit below num_segments.
  // NOTE(review): m == 0 would read segment_ids_addr[m - 1] out of bounds —
  // presumably upstream shape inference forbids empty segment_ids; confirm.
  if (segment_ids_addr[m - 1] >= num_segments_addr[kIndex0]) {
    MS_EXCEPTION(ValueError) << "For '" << kernel_name_
                             << "', num_segments must bigger than the last number of segment_ids.";
  }
  // Every index must select an existing row of x.
  for (size_t i = 0; i < m; i++) {
    if (indices_addr[i] >= x_shape_0) {
      MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', indices out of range of x's first shape.";
    }
  }

  // Single pass: accumulate rows into the current segment's output row; when a
  // new segment starts, finish the previous one by dividing by sqrt(count).
  int oldindex = -1;  // id of the segment currently being accumulated (-1 = none yet)
  int countnum = 0;   // number of rows accumulated into the current segment
  for (size_t i = 0; i < m; i++) {
    if (oldindex == segment_ids_addr[i]) {
      countnum++;
    } else {
      // Segment boundary: normalize the completed segment, if any.
      if (countnum != 0) {
        for (size_t j = 0; j < n; j++) {
          y_addr[j + oldindex * n] /= (T1)(sqrt(countnum));
        }
      }
      countnum = 1;
      oldindex = segment_ids_addr[i];
      // Re-zero the new segment's output row before accumulating into it.
      for (size_t j = 0; j < n; j++) {
        y_addr[j + oldindex * n] = (T1)0;
      }
    }
    // Add the selected row of x into the current segment's output row.
    for (size_t j = 0; j < n; j++) {
      y_addr[j + oldindex * n] += x_addr[j + indices_addr[i] * n];
    }
  }
  // Flush the final segment, which has no boundary following it.
  if (countnum != 0) {
    for (size_t j = 0; j < n; j++) {
      y_addr[j + oldindex * n] /= (T1)(sqrt(countnum));
    }
  }
}
|
||||
|
||||
std::vector<KernelAttr> SparseSegmentSqrtNWithNumSegmentsCpuKernelMod::GetOpSupport() {
  // Supported signatures: (x, indices, segment_ids, num_segments) -> y.
  // All three integer inputs must share one dtype (all int32 or all int64),
  // matching the single T2 used to cast them in LaunchKernel.
  static std::vector<KernelAttr> kernel_attr_list = {
    ADD_KERNEL(Float16, Int32, Int32, Int32, Float16), ADD_KERNEL(Float16, Int64, Int64, Int64, Float16),
    ADD_KERNEL(Float32, Int32, Int32, Int32, Float32), ADD_KERNEL(Float32, Int64, Int64, Int64, Float32),
    ADD_KERNEL(Float64, Int32, Int32, Int32, Float64), ADD_KERNEL(Float64, Int64, Int64, Int64, Float64)};

  return kernel_attr_list;
}
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, SparseSegmentSqrtNWithNumSegments,
|
||||
SparseSegmentSqrtNWithNumSegmentsCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,59 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_SEGMENT_SQRT_N_WITH_NUM_SGEMENTS_CPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_SEGMENT_SQRT_N_WITH_NUM_SGEMENTS_CPU_KERNEL_H_
|
||||
|
||||
#include <functional>
|
||||
#include <numeric>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
// CPU kernel for the SparseSegmentSqrtNWithNumSegments operator: like
// SparseSegmentSqrtN, but the number of output segments is given explicitly
// and segment ids may leave gaps (gap rows stay zero).
class SparseSegmentSqrtNWithNumSegmentsCpuKernelMod : public DeprecatedNativeCpuKernelMod {
 public:
  SparseSegmentSqrtNWithNumSegmentsCpuKernelMod() = default;
  ~SparseSegmentSqrtNWithNumSegmentsCpuKernelMod() override = default;

  // Validates input/output counts and caches dtypes and device shapes.
  void InitKernel(const CNodePtr &kernel_node) override;
  // Dispatches to LaunchKernel<T1, T2> based on the cached dtypes.
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override;

  // T1: value type of x / y; T2: shared type of indices, segment_ids and
  // num_segments.
  template <typename T1, typename T2>
  void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);

 protected:
  // Lists the supported dtype combinations.
  std::vector<KernelAttr> GetOpSupport() override;

 private:
  ShapeVector x_shape_;             // shape of input x
  ShapeVector indices_shape_;       // shape of input indices
  ShapeVector segment_ids_shape_;   // shape of input segment_ids
  ShapeVector num_segments_shape_;  // shape of input num_segments
  ShapeVector y_shape_;             // shape of output y
  TypeId xdtype_{kTypeUnknown};     // dtype of x
  TypeId dtype1_{kTypeUnknown};     // shared dtype of the integer inputs
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_SEGMENT_SQRT_N_WITH_NUM_SGEMENTS_CPU_KERNEL_H_
|
|
@ -120,6 +120,9 @@ constexpr auto kCross = "Cross";
|
|||
constexpr auto kEditDistance = "EditDistance";
|
||||
constexpr auto kNextAfter = "NextAfter";
|
||||
constexpr auto kSparseSegmentMean = "SparseSegmentMean";
|
||||
constexpr auto kSparseSegmentSqrtN = "SparseSegmentSqrtN";
|
||||
constexpr auto kSparseSegmentSqrtNGrad = "SparseSegmentSqrtNGrad";
|
||||
constexpr auto kSparseSegmentSqrtNWithNumSegments = "SparseSegmentSqrtNWithNumSegments";
|
||||
constexpr auto kTridiagonalMatMul = "TridiagonalMatMul";
|
||||
constexpr auto kFFTWithSize = "FFTWithSize";
|
||||
|
||||
|
@ -1085,6 +1088,10 @@ GVAR_DEF(PrimitivePtr, kPrimBucketize, std::make_shared<Primitive>("Bucketize"))
|
|||
GVAR_DEF(PrimitivePtr, kPrimEinsum, std::make_shared<Primitive>("Einsum"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimEinsumGrad, std::make_shared<Primitive>("EinsumGrad"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimSparseSegmentMean, std::make_shared<Primitive>(kSparseSegmentMean));
|
||||
GVAR_DEF(PrimitivePtr, kPrimSparseSegmentSqrtN, std::make_shared<Primitive>("SparseSegmentSqrtN"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimSparseSegmentSqrtNGrad, std::make_shared<Primitive>("SparseSegmentSqrtNGrad"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimSparseSegmentSqrtNWithNumSegments,
|
||||
std::make_shared<Primitive>("SparseSegmentSqrtNWithNumSegments"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimTrace, std::make_shared<Primitive>("Trace"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimTraceGrad, std::make_shared<Primitive>("TraceGrad"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimTridiagonalMatMul, std::make_shared<Primitive>(kTridiagonalMatMul));
|
||||
|
|
|
@ -0,0 +1,102 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/grad/sparse_segment_sqrt_n_grad.h"
|
||||
#include "abstract/dshape.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "utils/tensor_construct_utils.h"
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr SparseSegmentSqrtNGradInferShape(const PrimitivePtr &prim,
                                                    const std::vector<AbstractBasePtr> &input_args) {
  // Infers the output shape of SparseSegmentSqrtNGrad: same shape as x except
  // dimension 0, which is taken from the scalar input `output_dim0` when its
  // value is known at compile time; otherwise a fully dynamic shape is reported.
  MS_EXCEPTION_IF_NULL(prim);
  auto prim_name = prim->name();
  auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
  auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
  auto segment_ids_shape =
    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
  auto output_dim0_shape =
    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape];
  if (x_shape.size() < kInputIndex1) {
    MS_EXCEPTION(ValueError) << "For '" << prim_name << "', tensor x's rank less than 1.";
  }
  if (output_dim0_shape.size() != kInputIndex0) {
    MS_EXCEPTION(ValueError) << "For '" << prim_name << "', tensor outputdim0 should be a scalar.";
  }
  if (indices_shape[kInputIndex0] != segment_ids_shape[kInputIndex0]) {
    MS_EXCEPTION(ValueError) << "For '" << prim_name << "', tensor indices & segment_ids's ranks mismatch.";
  }
  if (!input_args[kInputIndex3]->BuildValue()->isa<AnyValue>() &&
      !input_args[kInputIndex3]->BuildValue()->isa<None>()) {
    auto output_dim0_value = input_args[kInputIndex3]->cast<abstract::AbstractTensorPtr>();
    MS_EXCEPTION_IF_NULL(output_dim0_value);
    auto output_dim0_value_ptr = output_dim0_value->BuildValue();
    MS_EXCEPTION_IF_NULL(output_dim0_value_ptr);
    auto output_dim0_value_ptr_tensor =
      CheckAndConvertUtils::CheckTensorIntValue("output_dim0", output_dim0_value_ptr, prim_name);
    // Keep the value signed: CheckTensorIntValue yields int64_t, and storing a
    // negative output_dim0 into a size_t would wrap to a huge positive value and
    // silently bypass the positivity check below.
    int64_t dim_zero = output_dim0_value_ptr_tensor[kInputIndex0];
    if (dim_zero <= 0) {
      MS_EXCEPTION(ValueError) << "Input output_dim0 must > 0!";
    }
    ShapeVector y_shape = x_shape;
    y_shape[kInputIndex0] = dim_zero;
    return std::make_shared<abstract::Shape>(y_shape);
  }
  // output_dim0 unknown at compile time: report a fully dynamic shape (-2).
  std::vector<int64_t> output_shape = {-2};
  std::vector<int64_t> min_shape = {1};
  std::vector<int64_t> max_shape = {1};
  return std::make_shared<abstract::Shape>(output_shape, min_shape, max_shape);
}
|
||||
|
||||
TypePtr SparseSegmentSqrtNGradInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  // Validates input dtypes and propagates x's dtype as the output dtype.
  MS_EXCEPTION_IF_NULL(prim);
  const auto &prim_name = prim->name();
  // x must be floating point.
  (void)CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[kInputIndex0]->BuildType(),
                                                   {kFloat16, kFloat32, kFloat64}, prim_name);
  // indices, segment_ids and output_dim0 must all be int32 and agree with each other.
  std::map<std::string, TypePtr> int_args;
  (void)int_args.emplace("indices", input_args[kInputIndex1]->BuildType());
  (void)int_args.emplace("segment_ids", input_args[kInputIndex2]->BuildType());
  (void)int_args.emplace("output_dim0", input_args[kInputIndex3]->BuildType());
  (void)CheckAndConvertUtils::CheckTensorTypeSame(int_args, {kInt32}, prim_name);
  return input_args[kInputIndex0]->BuildType();
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_OPERATOR_IMPL(SparseSegmentSqrtNGrad, BaseOperator);
|
||||
AbstractBasePtr SparseSegmentSqrtNGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &prim,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto prim_name = prim->name();
|
||||
const int64_t input_num = kInputIndex4;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
|
||||
auto types = SparseSegmentSqrtNGradInferType(prim, input_args);
|
||||
auto shapes = SparseSegmentSqrtNGradInferShape(prim, input_args);
|
||||
return abstract::MakeAbstract(shapes, types);
|
||||
}
|
||||
REGISTER_HOST_DEPENDS(kNameSparseSegmentSqrtNGrad, {3});
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(SparseSegmentSqrtNGrad, prim::kPrimSparseSegmentSqrtNGrad, SparseSegmentSqrtNGradInfer,
|
||||
nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,47 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_SPARSE_SEGMENT_SQRT_N_GRAD_H_
|
||||
#define MINDSPORE_CORE_OPS_SPARSE_SEGMENT_SQRT_N_GRAD_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameSparseSegmentSqrtNGrad = "SparseSegmentSqrtNGrad";
|
||||
/// \brief Operator declaration for the gradient of SparseSegmentSqrtN.
/// Takes x, indices, segment_ids and the scalar output_dim0 (dimension 0 of the
/// forward input) and produces a single output y.
class MIND_API SparseSegmentSqrtNGrad : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(SparseSegmentSqrtNGrad);
  /// \brief Constructor; registers the operator's input/output names.
  SparseSegmentSqrtNGrad() : BaseOperator(kNameSparseSegmentSqrtNGrad) {
    InitIOName({"x", "indices", "segment_ids", "output_dim0"}, {"y"});
  }
};
|
||||
|
||||
abstract::AbstractBasePtr SparseSegmentSqrtNGradInfer(const abstract::AnalysisEnginePtr &,
|
||||
const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
using PrimSparseSegmentSqrtNGradPtr = std::shared_ptr<SparseSegmentSqrtNGrad>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_SPARSE_SEGMENT_SQRT_N_GRAD_H_
|
|
@ -0,0 +1,104 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <set>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
|
||||
#include "ops/sparse_segment_sqrt_n.h"
|
||||
#include "abstract/dshape.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "utils/tensor_construct_utils.h"
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr SparseSegmentSqrtNInferShape(const PrimitivePtr &prim,
                                                const std::vector<AbstractBasePtr> &input_args) {
  // Infers the output shape of SparseSegmentSqrtN: same shape as x except
  // dimension 0, which is last(segment_ids) + 1 when segment_ids is a
  // compile-time constant (segment_ids is expected to be sorted ascending).
  MS_EXCEPTION_IF_NULL(prim);
  auto prim_name = prim->name();
  auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
  auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
  auto segment_ids_shape =
    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
  (void)CheckAndConvertUtils::CheckInteger("indices_shape", SizeToLong(indices_shape.size()), kEqual, kInputIndex1,
                                           prim_name);
  (void)CheckAndConvertUtils::CheckInteger("segment_ids_shape", SizeToLong(segment_ids_shape.size()), kEqual,
                                           kInputIndex1, prim_name);
  if (x_shape.size() < kInputIndex1) {
    MS_EXCEPTION(ValueError) << "For '" << prim_name << "', x's rank less than 1.";
  }
  if (indices_shape[kInputIndex0] != segment_ids_shape[kInputIndex0]) {
    MS_EXCEPTION(ValueError) << "For '" << prim_name << "', ranks of indices and segment_ids mismatch.";
  }
  if (!input_args[kInputIndex2]->BuildValue()->isa<AnyValue>() &&
      !input_args[kInputIndex2]->BuildValue()->isa<None>()) {
    auto segment_ids_value_ptr = input_args[kInputIndex2]->BuildValue();
    MS_EXCEPTION_IF_NULL(segment_ids_value_ptr);
    auto segment_ids_value_ptr_tensor =
      CheckAndConvertUtils::CheckTensorIntValue("segment_ids", segment_ids_value_ptr, prim_name);
    // Guard .back(): calling it on an empty value vector is undefined behavior.
    if (segment_ids_value_ptr_tensor.empty()) {
      MS_EXCEPTION(ValueError) << "For '" << prim_name << "', segment_ids must not be empty.";
    }
    // Keep the arithmetic signed: CheckTensorIntValue yields int64_t, and a
    // negative segment id stored into a size_t would wrap to a huge positive
    // value and evade the non-negativity check below.
    int64_t dim_zero = segment_ids_value_ptr_tensor.back() + 1;
    if (dim_zero < 1) {
      MS_EXCEPTION(ValueError) << "For '" << prim_name << "', segment_ids must >= 0!";
    }
    ShapeVector y_shape = x_shape;
    y_shape[kInputIndex0] = dim_zero;
    return std::make_shared<abstract::Shape>(y_shape);
  }
  // segment_ids unknown at compile time: report a fully dynamic shape (-2).
  std::vector<int64_t> output_shape = {-2};
  std::vector<int64_t> min_shape = {1};
  std::vector<int64_t> max_shape = {1};
  return std::make_shared<abstract::Shape>(output_shape, min_shape, max_shape);
}
|
||||
|
||||
TypePtr SparseSegmentSqrtNInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  // Validates dtypes: x must be floating point; indices and segment_ids may each
  // independently be int32 or int64. Returns x's dtype as the output dtype.
  MS_EXCEPTION_IF_NULL(prim);
  auto x_type = input_args[kInputIndex0]->BuildType();
  auto indices_type = input_args[kInputIndex1]->BuildType();
  auto segment_ids_type = input_args[kInputIndex2]->BuildType();
  const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
  const std::set<TypePtr> common_valid_types = {kInt32, kInt64};
  // (void) casts: return values are intentionally discarded, matching the
  // convention used by the sibling infer functions in this change set.
  (void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_types, prim->name());
  (void)CheckAndConvertUtils::CheckTensorTypeValid("indices", indices_type, common_valid_types, prim->name());
  (void)CheckAndConvertUtils::CheckTensorTypeValid("segment_ids", segment_ids_type, common_valid_types, prim->name());
  return input_args[kInputIndex0]->BuildType();
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_OPERATOR_IMPL(SparseSegmentSqrtN, BaseOperator);
|
||||
AbstractBasePtr SparseSegmentSqrtNInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &prim,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto prim_name = prim->name();
|
||||
const int64_t input_num = kInputIndex3;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
|
||||
auto types = SparseSegmentSqrtNInferType(prim, input_args);
|
||||
auto shapes = SparseSegmentSqrtNInferShape(prim, input_args);
|
||||
return abstract::MakeAbstract(shapes, types);
|
||||
}
|
||||
REGISTER_HOST_DEPENDS(kNameSparseSegmentSqrtN, {2});
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(SparseSegmentSqrtN, prim::kPrimSparseSegmentSqrtN, SparseSegmentSqrtNInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,42 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_SPARSE_SEGMENT_SQRT_N_H_
|
||||
#define MINDSPORE_CORE_OPS_SPARSE_SEGMENT_SQRT_N_H_
|
||||
#include <set>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameSparseSegmentSqrtN = "SparseSegmentSqrtN";
|
||||
/// \brief Operator declaration for SparseSegmentSqrtN.
/// Takes x, indices and segment_ids and produces a single output y.
/// NOTE(review): semantics presumably match TensorFlow's SparseSegmentSqrtN
/// (segment sums of gathered rows scaled by 1/sqrt(segment size)) — confirm
/// against the kernel implementation.
class MIND_API SparseSegmentSqrtN : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(SparseSegmentSqrtN);
  /// \brief Constructor; registers the operator's input/output names.
  SparseSegmentSqrtN() : BaseOperator(kNameSparseSegmentSqrtN) { InitIOName({"x", "indices", "segment_ids"}, {"y"}); }
};
|
||||
|
||||
abstract::AbstractBasePtr SparseSegmentSqrtNInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
using PrimSparseSegmentSqrtNPtr = std::shared_ptr<SparseSegmentSqrtN>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_SPARSE_SEGMENT_SQRT_N_H_
|
|
@ -0,0 +1,120 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <set>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <algorithm>
|
||||
|
||||
#include "ops/sparse_segment_sqrt_n_with_num_segments.h"
|
||||
#include "abstract/dshape.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "utils/tensor_construct_utils.h"
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr SparseSegmentSqrtNWithNumSegmentsInferShape(const PrimitivePtr &prim,
                                                               const std::vector<AbstractBasePtr> &input_args) {
  // Infers the output shape: same shape as x except dimension 0, which equals the
  // value of num_segments when it is a compile-time constant.
  MS_EXCEPTION_IF_NULL(prim);
  auto prim_name = prim->name();
  auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
  auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
  auto segment_ids_shape =
    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
  auto num_segments_shape =
    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape];
  // SizeToLong: CheckInteger takes int64_t; passing size_t directly relied on an
  // implicit signedness conversion (the sibling SparseSegmentSqrtN infer already
  // uses the explicit conversion).
  (void)CheckAndConvertUtils::CheckInteger("indices_shape", SizeToLong(indices_shape.size()), kEqual, kInputIndex1,
                                           prim->name());
  (void)CheckAndConvertUtils::CheckInteger("segment_ids_shape", SizeToLong(segment_ids_shape.size()), kEqual,
                                           kInputIndex1, prim->name());
  if (x_shape.size() < kInputIndex1) {
    MS_EXCEPTION(ValueError) << "For " << prim_name << ", rank of x cannot less than 1.";
  }
  if (indices_shape[kInputIndex0] != segment_ids_shape[kInputIndex0]) {
    MS_EXCEPTION(ValueError) << "For " << prim_name << ", rank of indices and segment_ids mismatch.";
  }
  if (num_segments_shape.size() > kInputIndex1) {
    MS_EXCEPTION(ValueError) << "For " << prim_name << ", num_segments should be at most 1-D.";
  }
  if (num_segments_shape.size() == kInputIndex1) {
    if (num_segments_shape[kInputIndex0] != kInputIndex1) {
      MS_EXCEPTION(ValueError) << "For " << prim_name << ", the num element of num_segments should be 1.";
    }
  }
  if (!input_args[kInputIndex3]->BuildValue()->isa<AnyValue>() &&
      !input_args[kInputIndex3]->BuildValue()->isa<None>()) {
    auto num_segments_value = input_args[kInputIndex3]->cast<abstract::AbstractTensorPtr>();
    MS_EXCEPTION_IF_NULL(num_segments_value);
    auto num_segments_value_ptr = num_segments_value->BuildValue();
    MS_EXCEPTION_IF_NULL(num_segments_value_ptr);
    auto num_segments_value_ptr_tensor =
      CheckAndConvertUtils::CheckTensorIntValue("num_segments", num_segments_value_ptr, prim->name());
    // Keep the value signed: a negative num_segments stored into a size_t would
    // wrap to a huge positive value and silently pass the positivity check below.
    int64_t dim_zero = num_segments_value_ptr_tensor.back();
    if (dim_zero < 1) {
      // The message now matches the actual check (positivity); the relation to the
      // last segment id is validated by the kernel at run time.
      MS_EXCEPTION(ValueError) << "For " << prim_name << ", num_segments must be bigger than 0.";
    }
    ShapeVector y_shape = x_shape;
    y_shape[kInputIndex0] = dim_zero;
    return std::make_shared<abstract::Shape>(y_shape);
  }
  // num_segments unknown at compile time: report a fully dynamic shape (-2).
  std::vector<int64_t> output_shape = {-2};
  std::vector<int64_t> min_shape = {1};
  std::vector<int64_t> max_shape = {1};
  return std::make_shared<abstract::Shape>(output_shape, min_shape, max_shape);
}
|
||||
|
||||
TypePtr SparseSegmentSqrtNWithNumSegmentsInferType(const PrimitivePtr &prim,
                                                   const std::vector<AbstractBasePtr> &input_args) {
  // Validates dtypes: x must be floating point; indices, segment_ids and
  // num_segments must share one integral dtype (int32 or int64).
  MS_EXCEPTION_IF_NULL(prim);
  const auto &prim_name = prim->name();
  const std::set<TypePtr> x_valid_types = {kFloat16, kFloat32, kFloat64};
  (void)CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[kInputIndex0]->BuildType(), x_valid_types,
                                                   prim_name);
  std::map<std::string, TypePtr> index_types;
  (void)index_types.emplace("indices", input_args[kInputIndex1]->BuildType());
  (void)index_types.emplace("segment_ids", input_args[kInputIndex2]->BuildType());
  (void)index_types.emplace("num_segments", input_args[kInputIndex3]->BuildType());
  (void)CheckAndConvertUtils::CheckTensorTypeSame(index_types, {kInt32, kInt64}, prim_name);
  // Output keeps x's dtype.
  return input_args[kInputIndex0]->BuildType();
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_OPERATOR_IMPL(SparseSegmentSqrtNWithNumSegments, BaseOperator);
|
||||
AbstractBasePtr SparseSegmentSqrtNWithNumSegmentsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &prim,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto prim_name = prim->name();
|
||||
const int64_t input_num = kInputIndex4;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
|
||||
auto types = SparseSegmentSqrtNWithNumSegmentsInferType(prim, input_args);
|
||||
auto shapes = SparseSegmentSqrtNWithNumSegmentsInferShape(prim, input_args);
|
||||
return abstract::MakeAbstract(shapes, types);
|
||||
}
|
||||
REGISTER_HOST_DEPENDS(kNameSparseSegmentSqrtNWithNumSegments, {3});
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(SparseSegmentSqrtNWithNumSegments, prim::kPrimSparseSegmentSqrtNWithNumSegments,
|
||||
SparseSegmentSqrtNWithNumSegmentsInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,45 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_SPARSE_SEGMENT_SQRT_N_WITH_NUM_SEGMENTS_H_
|
||||
#define MINDSPORE_CORE_OPS_SPARSE_SEGMENT_SQRT_N_WITH_NUM_SEGMENTS_H_
|
||||
#include <set>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameSparseSegmentSqrtNWithNumSegments = "SparseSegmentSqrtNWithNumSegments";
|
||||
/// \brief Operator declaration for SparseSegmentSqrtN with an explicit segment
/// count. Takes x, indices, segment_ids and num_segments; produces output y.
class MIND_API SparseSegmentSqrtNWithNumSegments : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(SparseSegmentSqrtNWithNumSegments);
  /// \brief Constructor; registers the operator's input/output names.
  SparseSegmentSqrtNWithNumSegments() : BaseOperator(kNameSparseSegmentSqrtNWithNumSegments) {
    InitIOName({"x", "indices", "segment_ids", "num_segments"}, {"y"});
  }
};
|
||||
|
||||
abstract::AbstractBasePtr SparseSegmentSqrtNWithNumSegmentsInfer(
|
||||
const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
using PrimSparseSegmentSqrtNWithNumSegmentsPtr = std::shared_ptr<SparseSegmentSqrtNWithNumSegments>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_SPARSE_SEGMENT_SQRT_N_WITH_NUM_SEGMENTS_H_
|
|
@ -16,6 +16,13 @@
|
|||
"""bprop primitives"""
|
||||
from mindspore.ops.operations.sparse_ops import CSRSparseMatrixToSparseTensor
|
||||
from mindspore.ops.operations.sparse_ops import SparseTensorToCSRSparseMatrix
|
||||
from mindspore.ops.operations.sparse_ops import SparseSegmentSqrtN
|
||||
from mindspore.ops.operations.sparse_ops import SparseSegmentSqrtNWithNumSegments
|
||||
from mindspore.common import dtype as mstype
|
||||
from .. import functional as F
|
||||
from .. import operations as P
|
||||
from ..composite.multitype_ops.zeros_like_impl import zeros_like
|
||||
from ..operations import _grad_ops as G
|
||||
from .._grad.grad_base import bprop_getters
|
||||
|
||||
|
||||
|
@ -45,3 +52,36 @@ def get_bprop_csr_sparse_matrix_to_sparse_tensor(self):
|
|||
return dx_all
|
||||
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(SparseSegmentSqrtN)
def get_bprop_sparse_segment_sqrt_n(self):
    """Grad definition for `SparseSegmentSqrtN` operation."""
    grad_op = G.SparseSegmentSqrtNGrad()
    get_shape = P.Shape()

    def bprop(x, indices, segment_ids, out, dout):
        # The grad kernel needs dimension 0 of x and int32 index tensors.
        dim0 = F.scalar_to_tensor(get_shape(x)[0], mstype.int32)
        idx32 = F.cast(indices, mstype.int32)
        seg32 = F.cast(segment_ids, mstype.int32)
        dx = grad_op(dout, idx32, seg32, dim0)
        return dx, zeros_like(idx32), zeros_like(seg32)

    return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(SparseSegmentSqrtNWithNumSegments)
def get_bprop_sparse_segment_sqrt_n_with_num_segments(self):
    """Grad definition for `SparseSegmentSqrtNWithNumSegments` operation."""
    grad_op = G.SparseSegmentSqrtNGrad()
    get_shape = P.Shape()

    def bprop(x, indices, segment_ids, num_segments, out, dout):
        # The grad kernel needs dimension 0 of x and int32 index tensors;
        # num_segments itself gets a zero gradient.
        dim0 = F.scalar_to_tensor(get_shape(x)[0], mstype.int32)
        idx32 = F.cast(indices, mstype.int32)
        seg32 = F.cast(segment_ids, mstype.int32)
        dx = grad_op(dout, idx32, seg32, dim0)
        return dx, zeros_like(idx32), zeros_like(seg32), zeros_like(num_segments)

    return bprop
|
||||
|
|
|
@ -268,6 +268,9 @@ from .segment_mean import _segment_mean_aicpu
|
|||
from .segment_min import _segment_min_aicpu
|
||||
from .segment_prod import _segment_prod_aicpu
|
||||
from .segment_sum import _segment_sum_aicpu
|
||||
from .sparse_segment_sqrt_n import _sparse_segment_sqrt_n_aicpu
|
||||
from .sparse_segment_sqrt_n_grad import _sparse_segment_sqrt_n_grad_aicpu
|
||||
from .sparse_segment_sqrt_n_with_num_segments import _sparse_segment_sqrt_n_with_num_segments_aicpu
|
||||
from .scatter_nd_max import _scatter_nd_max_aicpu
|
||||
from .conj import _conj_aicpu
|
||||
from .ctc_loss_v2 import _ctc_loss_v2_aicpu
|
||||
|
|
|
@ -0,0 +1,43 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""SparseSegmentSqrtN op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||
|
||||
# AiCPU registration for SparseSegmentSqrtN.
# Supported dtype combinations: x/y in {float16, float32, float64};
# indices and segment_ids independently in {int32, int64}.
sparse_segment_sqrt_n_op_info = AiCPURegOp("SparseSegmentSqrtN") \
    .fusion_type("OPAQUE") \
    .input(0, "x", "required") \
    .input(1, "indices", "required") \
    .input(2, "segment_ids", "required") \
    .output(0, "y", "required") \
    .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.I32_Default, DataType.F16_Default) \
    .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.I64_Default, DataType.F16_Default) \
    .dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.I32_Default, DataType.F16_Default) \
    .dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.I64_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I64_Default, DataType.F32_Default) \
    .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.I32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.I64_Default, DataType.F32_Default) \
    .dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.I32_Default, DataType.F64_Default) \
    .dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.I64_Default, DataType.F64_Default) \
    .dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.I32_Default, DataType.F64_Default) \
    .dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.I64_Default, DataType.F64_Default) \
    .get_op_info()


@op_info_register(sparse_segment_sqrt_n_op_info)
def _sparse_segment_sqrt_n_aicpu():
    """SparseSegmentSqrtN AiCPU register"""
    return
|
|
@ -0,0 +1,38 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""SparseSegmentSqrtNGrad op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||
|
||||
# AiCPU registration for SparseSegmentSqrtNGrad.
# Supported dtype combinations: x/y in {float16, float32, float64};
# indices, segment_ids and output_dim0 are int32 only (narrower than the
# forward op, which also registers int64 index formats).
sparse_segment_sqrt_n_grad_op_info = AiCPURegOp("SparseSegmentSqrtNGrad") \
    .fusion_type("OPAQUE") \
    .input(0, "x", "required") \
    .input(1, "indices", "required") \
    .input(2, "segment_ids", "required") \
    .input(3, "output_dim0", "required") \
    .output(0, "y", "required") \
    .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.I32_Default,
                  DataType.I32_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default,
                  DataType.I32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.I32_Default,
                  DataType.I32_Default, DataType.F64_Default) \
    .get_op_info()


@op_info_register(sparse_segment_sqrt_n_grad_op_info)
def _sparse_segment_sqrt_n_grad_aicpu():
    """SparseSegmentSqrtNGrad AiCPU register"""
    return
|
|
@ -0,0 +1,44 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""SparseSegmentSqrtNWithNumSegments op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||
|
||||
# AiCPU registration for SparseSegmentSqrtNWithNumSegments.
# Supported dtype combinations: x/y in {float16, float32, float64};
# indices, segment_ids and num_segments share one integral dtype
# (all-int32 or all-int64) — mixed int32/int64 is not registered here.
sparse_segment_sqrt_n_with_num_segments_op_info = AiCPURegOp("SparseSegmentSqrtNWithNumSegments") \
    .fusion_type("OPAQUE") \
    .input(0, "x", "required") \
    .input(1, "indices", "required") \
    .input(2, "segment_ids", "required") \
    .input(3, "num_segments", "required") \
    .output(0, "y", "required") \
    .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.I32_Default,
                  DataType.I32_Default, DataType.F16_Default) \
    .dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.I64_Default,
                  DataType.I64_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default,
                  DataType.I32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.I64_Default,
                  DataType.I64_Default, DataType.F32_Default) \
    .dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.I32_Default,
                  DataType.I32_Default, DataType.F64_Default) \
    .dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.I64_Default,
                  DataType.I64_Default, DataType.F64_Default) \
    .get_op_info()


@op_info_register(sparse_segment_sqrt_n_with_num_segments_op_info)
def _sparse_segment_sqrt_n_with_num_segments_aicpu():
    """SparseSegmentSqrtNWithNumSegments AiCPU register"""
    return
|
|
@ -3332,6 +3332,46 @@ class MedianGrad(Primitive):
|
|||
self.init_prim_io_names(inputs=['y_grad', 'x', 'y', 'indices'], outputs=['x_grad'])
|
||||
|
||||
|
||||
class SparseSegmentSqrtNGrad(Primitive):
    """
    Computes gradients for SparseSegmentSqrtN operation.

    Inputs:
        - **x** (Tensor) - A tensor.
        - **indices** (Tensor) - Indices is a 1-D tensor with int32 data type.
          Has same rank as segment_ids. The shape should be :math:`(N,)`.
        - **segment_ids** (Tensor) - Segment_ids is a 1-D tensor with int32 data type.
          Values should be sorted and can be repeated. The shape should be :math:`(N,)`.
        - **output_dim0** (Tensor) - Output_dim0 is a 0-D tensor with int32 data type.
          Dimension 0 of `x` passed to SparseSegmentSqrtN op.

    Outputs:
        A Tensor. Has the same type as `x` .
        Has same shape as `x`, except for dimension 0 which is the value of `output_dim0`.

    Raises:
        TypeError: If `x` or `indices` or `segment_ids` or `output_dim0` is not a tensor.
        TypeError: If the dtype of `x` is not any of the following data types: {float16, float32, float64}.
        TypeError: If the dtype of `indices` is not int32.
        TypeError: If the dtype of `segment_ids` is not int32.
        TypeError: If the dtype of `output_dim0` is not int32.
        ValueError: If dimension size of `x` less than 1.
        ValueError: If rank of `indices` or `segment_ids` is not 1.
        ValueError: If dimension size of `output_dim0` is not 0.
        ValueError: If shape[0] of `indices` is not corresponding to shape[0] of `segment_ids`.
        ValueError: If indices in `segment_ids` are not contiguous or do not start from 0.
        ValueError: If `segment_ids` is not sorted.
        ValueError: If `indices` is out of range of x's first shape.

    Supported Platforms:
        ``Ascend`` ``CPU``
    """

    @prim_attr_register
    def __init__(self):
        """Initialize SparseSegmentSqrtNGrad"""
        self.init_prim_io_names(inputs=['x', 'indices', 'segment_ids', 'output_dim0'], outputs=['y'])
|
||||
|
||||
|
||||
class GridSampler2DGrad(Primitive):
|
||||
"""
|
||||
Computes gradients for GridSampler2D operation.
|
||||
|
|
|
@ -588,6 +588,114 @@ class SparseConcat(Primitive):
|
|||
validator.check_value_type("concat_dim", concat_dim, [int], self.name)
|
||||
|
||||
|
||||
class SparseSegmentSqrtN(Primitive):
    """
    Computes the sum along sparse segments of a tensor divided by the sqrt of N,
    where N is the size of the segment being reduced.

    Inputs:
        - **x** (Tensor) - A tensor.
        - **indices** (Tensor) - A 1-D tensor of type int32 or int64, with the
          same rank as `segment_ids` and shape :math:`(N,)`.
        - **segment_ids** (Tensor) - A 1-D tensor of type int32 or int64, with
          shape :math:`(N,)`. Values should be sorted and can be repeated.

    Outputs:
        A Tensor. Has the same type as `x` .
        Has same shape as `x`, except for dimension 0 which is the number of segments.

    Raises:
        TypeError: If `x` or `indices` or `segment_ids` is not a tensor.
        TypeError: If the dtype of `x` is not any of the following data types: {float16, float32, float64}.
        TypeError: If the dtype of `indices` is not int32 or int64.
        TypeError: If the dtype of `segment_ids` is not int32 or int64.
        ValueError: If dimension size of `x` less than 1.
        ValueError: If any of `indices` and `segment_ids` is not a 1-D tensor.
        ValueError: If shape[0] of `indices` is not corresponding to shape[0] of `segment_ids`.
        ValueError: If indices in `segment_ids` are not contiguous or do not start from 0.
        ValueError: If `segment_ids` is not sorted.
        ValueError: If `indices` is out of range of x's first shape.

    Supported Platforms:
        ``Ascend`` ``CPU``

    Examples:
        >>> x = Tensor(np.array([[1,2,3,4],[5,6,7,8],[9,10,11,12]]).astype(np.float32))
        >>> indices = Tensor(np.array([0,1,2]).astype(np.int32))
        >>> segment_ids = Tensor(np.array([0,1,2]).astype(np.int32))
        >>> sparse_segment_sqrt_n = SparseSegmentSqrtN()
        >>> output = sparse_segment_sqrt_n(x, indices, segment_ids)
        >>> print(output)
        [[ 1. 2. 3. 4.]
         [ 5. 6. 7. 8.]
         [ 9. 10. 11. 12.]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize SparseSegmentSqrtN"""
        self.init_prim_io_names(inputs=['x', 'indices', 'segment_ids'], outputs=['y'])
|
||||
|
||||
|
||||
class SparseSegmentSqrtNWithNumSegments(Primitive):
    """
    Computes the sum along sparse segments of a tensor divided by the sqrt of N.
    N is the size of the segment being reduced.
    Like SparseSegmentSqrtN, but allows missing ids in segment_ids.
    If an id is missing, the output tensor at that position will be zeroed.

    Inputs:
        - **x** (Tensor) - A Tensor.
        - **indices** (Tensor) - 1-D Tensor. Must be one of the following types: int32, int64.
          Has same rank as segment_ids. The shape should be :math:`(N,)`.
        - **segment_ids** (Tensor) - Segment_ids: 1-D Tensor. Must be one of the following types: int32, int64.
          Values should be sorted and can be repeated. The shape should be :math:`(N,)`.
        - **num_segments** (Tensor) - Num_segments should equal the number of distinct segment_ids.

    Outputs:
        A Tensor. Has the same type as `x` .
        Has same shape as `x`, except for dimension 0 which is the value of `num_segments`.

    Raises:
        TypeError: If `x` or `indices` or `segment_ids` or `num_segments` is not a tensor.
        TypeError: If the dtype of `x` is not any of the following data types: {float16, float32, float64}.
        TypeError: If the dtype of `indices` and `segment_ids` and `num_segments` is not int32 or int64.
        TypeError: If dtype of `segment_ids` and `indices` mismatch.
        TypeError: If dtype of `num_segments` and `indices` mismatch.
        ValueError: If dimension size of `x` less than 1.
        ValueError: If any of `indices` and `segment_ids` is not a 1-D tensor.
        ValueError: If rank of `num_segments` is bigger than 1.
        ValueError: If numelements of `num_segments` is not 1.
        ValueError: If the first dimension of `indices` is not equal to the first dimension of `segment_ids`.
        ValueError: If `segment_ids` is not sorted.
        ValueError: If the last number of `segment_ids` is bigger than or equal to `num_segments`.
        ValueError: If `indices` is out of range of x's first shape.

    Supported Platforms:
        ``Ascend`` ``CPU``

    Examples:
        >>> x = Tensor([[0, 1, 0, 0], [0, 1, 1, 0], [1, 0, 1, 0]], dtype=ms.float16)
        >>> indices = Tensor([0, 2, 1], dtype=ms.int32)
        >>> segment_ids = Tensor([0, 1, 2], dtype=ms.int32)
        >>> num_segments = Tensor([4], dtype=ms.int32)
        >>> sparse_segment_sqrt_n_with_num_segments = SparseSegmentSqrtNWithNumSegments()
        >>> output = sparse_segment_sqrt_n_with_num_segments(x, indices, segment_ids, num_segments)
        >>> print(output)
        [[0. 1. 0. 0.]
         [1. 0. 1. 0.]
         [0. 1. 1. 0.]
         [0. 0. 0. 0.]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize SparseSegmentSqrtNWithNumSegments"""
        # Bug fix: the fourth input was registered as the misspelled
        # 'num_segemnts'; it must be 'num_segments' to match the docstring,
        # the AiCPU op_info registration, and the kernel's input name.
        self.init_prim_io_names(
            inputs=['x', 'indices', 'segment_ids', 'num_segments'], outputs=['y'])
|
||||
|
||||
|
||||
class SparseMatrixNNZ(Primitive):
|
||||
r"""
|
||||
Count number of the non-zero elements in sparse matrix or sparse matrixs.
|
||||
|
|
|
@ -125,6 +125,8 @@ from mindspore.ops.operations.sparse_ops import SparseMatrixTranspose
|
|||
from mindspore.ops.operations.sparse_ops import CSRSparseMatrixToSparseTensor
|
||||
from mindspore.ops.operations.sparse_ops import SparseTensorToCSRSparseMatrix
|
||||
from mindspore.ops.operations.sparse_ops import SparseSparseMinimum
|
||||
from mindspore.ops.operations.sparse_ops import SparseSegmentSqrtN
|
||||
from mindspore.ops.operations.sparse_ops import SparseSegmentSqrtNWithNumSegments
|
||||
from mindspore.ops.operations.other_ops import BlackmanWindow
|
||||
from mindspore.ops.operations.nn_ops import SparseApplyCenteredRMSProp
|
||||
from mindspore.ops.operations.nn_ops import SparseApplyProximalGradientDescent
|
||||
|
@ -2278,6 +2280,19 @@ test_case_math_ops = [
|
|||
'block': P.Erfinv(),
|
||||
'desc_inputs': [Tensor(np.array([0.1, 0.1, 0.1]).astype(np.float16))],
|
||||
'desc_bprop': [Tensor(np.array([1, 1, 1]).astype(np.float16))]}),
|
||||
('SparseSegmentSqrtN', {
|
||||
'block': SparseSegmentSqrtN(),
|
||||
'desc_inputs': [Tensor(np.array([[1, 2, 4], [2, 4, 5], [2, 2, 6]]).astype(np.float32)),
|
||||
Tensor(np.array([0, 1]).astype(np.int32)),
|
||||
Tensor(np.array([0, 1]).astype(np.int32))],
|
||||
'desc_bprop': [Tensor(np.array([[1, 2, 4], [2, 4, 5], [2, 2, 6]]).astype(np.float32))]}),
|
||||
('SparseSegmentSqrtNWithNumSegments', {
|
||||
'block': SparseSegmentSqrtNWithNumSegments(),
|
||||
'desc_inputs': [Tensor(np.array([[1, 2, 4], [2, 4, 5], [2, 2, 6]]).astype(np.float32)),
|
||||
Tensor(np.array([0, 1]).astype(np.int32)),
|
||||
Tensor(np.array([0, 1]).astype(np.int32)),
|
||||
Tensor(np.array([3]).astype(np.int32))],
|
||||
'desc_bprop': [Tensor(np.array([[1, 2, 4], [2, 4, 5], [2, 2, 6]]).astype(np.float32))]}),
|
||||
('IndexAdd', {
|
||||
'block': IndexAdd(1),
|
||||
'desc_inputs': (Tensor(np.array([0, 1, 2]).astype(np.int32)),
|
||||
|
|
Loading…
Reference in New Issue