!33634 [MS][LITE]add scatter_elements and scatter_add_with_axis cpu kernel

Merge pull request !33634 from mengyuanli/scatter_element
This commit is contained in:
i-robot 2022-04-28 02:17:51 +00:00 committed by Gitee
commit cb1dc9f402
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
11 changed files with 983 additions and 1 deletions

View File

@ -0,0 +1,397 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/cpu/kernel/scatter_elements_cpu_kernel.h"
#include <algorithm>
#include <functional>
#include <limits>
#include <numeric>
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
namespace mindspore::kernel {
constexpr size_t kScatterElementsInputsNum = 3;
constexpr size_t kScatterElementsOutputsNum = 1;
namespace {
// Reduction functor: accumulates the update value into the destination element.
template <class T>
struct ReductionAdd {
  void operator()(T *dst, const T &src) const { *dst = *dst + src; }
};
template <class T>
struct ReductionAssignment {
void operator()(T *a, const T &b) const { (*a) = b; }
};
} // namespace
// Re-validates shapes/attributes for the current input shapes and precomputes
// the strides and bookkeeping used by Scatter(). Returns false (with an error
// log) on any inconsistency between input_x, indices and updates.
bool ScatterElementsCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
                                         const std::vector<KernelTensorPtr> &inputs,
                                         const std::vector<KernelTensorPtr> &outputs,
                                         const std::map<uint32_t, tensor::TensorPtr> &others) {
  MS_ERROR_IF_NULL_W_RET_VAL(base_operator, false);
  if (!NativeCpuKernelMod::Resize(base_operator, inputs, outputs, others)) {
    MS_LOG(ERROR) << "For '" << kernel_name_ << "', resize failed.";
    return false;
  }
  kernel_name_ = base_operator->name();
  auto input_shape = inputs[kIndex0]->GetShapeVector();
  input_dims_ = input_shape.size();
  if (input_dims_ < 1) {
    MS_LOG(ERROR) << "For '" << kernel_name_
                  << "', the dimension of 'input_x' should be greater than or equal to 1, but got " << input_dims_
                  << ".";
    return false;
  }
  // indices and updates must agree element-for-element.
  indices_shape_ = inputs[kIndex1]->GetShapeVector();
  auto update_shape = inputs[kIndex2]->GetShapeVector();
  if (indices_shape_ != update_shape) {
    MS_LOG(ERROR) << "For '" << kernel_name_
                  << "', the shape of 'indice' and the shape of 'update' should be same, but got "
                  << "indice shape: " << indices_shape_ << "; "
                  << "update shape: " << update_shape << ".";
    return false;
  }
  if (input_dims_ != indices_shape_.size()) {
    MS_LOG(ERROR) << "For '" << kernel_name_
                  << "', the dimension of 'input_x', 'indice' and 'update' should be same, but got "
                  << "input_x dims: " << input_dims_ << "; "
                  << "indice dims: " << indices_shape_.size() << "; "
                  << "update dims: " << update_shape.size() << ".";
    return false;
  }
  // Normalize a negative 'axis' attribute into [0, input_dims_).
  if (base_operator->HasAttr(kAttrAxis)) {
    axis_ = GetValue<int64_t>(base_operator->GetAttr(kAttrAxis));
    if (axis_ < 0) {
      axis_ += input_dims_;
    }
  }
  if (axis_ >= static_cast<int64_t>(input_dims_) || axis_ < 0) {
    MS_LOG(ERROR) << "For '" << kernel_name_
                  << "', the 'axis' should be less than input dims and greater than or equal 0, but got " << axis_
                  << ", while input dims is: " << input_dims_;
    return false;
  }
  // On every dimension except the scatter axis, indices must fit inside input_x.
  for (size_t i = 0; i < input_dims_; ++i) {
    if (axis_ != static_cast<int64_t>(i) && input_shape[i] < indices_shape_[i]) {
      MS_LOG(ERROR) << "For '" << kernel_name_
                    << "', the indices dims should be less than input dims, but got indice dim is: "
                    << indices_shape_[i] << " at axis: " << i << ", while input dim is:" << input_shape[i];
      return false;
    }
  }
  input_axis_size_ = SizeToInt(input_shape[axis_]);
  input_size_ = std::accumulate(input_shape.begin(), input_shape.end(), size_t(1), std::multiplies<size_t>());
  indices_total_num_ =
    std::accumulate(indices_shape_.begin(), indices_shape_.end(), size_t(1), std::multiplies<size_t>());
  adjusted_indices_.resize(indices_total_num_);
  // Row-major strides of input_x/output.
  output_dim_stride_.resize(input_dims_);
  output_dim_stride_.back() = 1;
  // Cast BEFORE subtracting: 'input_dims_ - 2' wraps around in size_t when
  // input_dims_ == 1, and casting that huge value to int is
  // implementation-defined.
  for (int i = static_cast<int>(input_dims_) - 2; i >= 0; --i) {
    output_dim_stride_[i] = input_shape[i + 1] * output_dim_stride_[i + 1];
  }
  // Multi-dimensional cursor over indices_shape_ (assign also sets the size,
  // so a separate resize() is unnecessary).
  output_dim_index_.assign(input_dims_, 0);
  return true;
}
template <typename S>
bool ScatterElementsCpuKernelMod::AdjustIndices(S *in_indices) {
for (size_t i = 0; i < indices_total_num_; i++) {
auto index = in_indices[i];
if (index >= input_axis_size_ || index < -input_axis_size_) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', index: " << index << " is expected to be within bounds ["
<< -input_axis_size_ << ", " << input_axis_size_ << ")";
return false;
}
if (index < 0) {
index += input_axis_size_;
}
adjusted_indices_[i] = index;
}
return true;
}
// Linearizes the current multi-dimensional cursor into a flat output offset,
// with the coordinate on the scatter axis replaced by `index`.
// NOTE(review): "Outout" is a typo for "Output", but the name is declared in
// the header, so renaming must be done in both files together.
size_t ScatterElementsCpuKernelMod::ComputeOutoutOffset(const int64_t &index) {
  size_t offset = 0;
  for (size_t dim = 0; dim < input_dims_; ++dim) {
    const size_t coord =
      (static_cast<int64_t>(dim) == axis_) ? static_cast<size_t>(index) : output_dim_index_[dim];
    offset += coord * output_dim_stride_[dim];
  }
  return offset;
}
// Odometer-style increment of the cursor over indices_shape_: the last
// dimension ticks fastest, and a dimension that reaches its extent resets to
// zero and carries into the next-slower dimension.
void ScatterElementsCpuKernelMod::UpdateOutputDimIndex() {
  int dim = static_cast<int>(input_dims_) - 1;
  while (dim >= 0) {
    auto next = ++output_dim_index_[dim];
    if (static_cast<int64_t>(next) < indices_shape_[dim]) {
      return;
    }
    output_dim_index_[dim] = 0;
    --dim;
  }
}
template <typename T, typename ReductionT>
bool ScatterElementsCpuKernelMod::Scatter(const ReductionT &reduction_func, T *output, const T *updates) {
for (size_t i = 0; i < indices_total_num_;) {
auto index = adjusted_indices_[i];
auto output_offset = ComputeOutoutOffset(index);
reduction_func(output + output_offset, *(updates + i));
if (++i == indices_total_num_) {
break;
}
UpdateOutputDimIndex();
}
return true;
}
template <typename T, typename S, typename ReductionT>
bool ScatterElementsCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kScatterElementsInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kScatterElementsOutputsNum, kernel_name_);
auto *input = reinterpret_cast<T *>(inputs[kIndex0]->addr);
auto *indices = reinterpret_cast<S *>(inputs[kIndex1]->addr);
auto *updates = reinterpret_cast<T *>(inputs[kIndex2]->addr);
auto *output = reinterpret_cast<T *>(outputs[kIndex0]->addr);
auto bufferSize = outputs[kIndex0]->size;
auto ret = memcpy_s(output, bufferSize, input, input_size_ * sizeof(T));
if (ret != EOK) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', memory copy failed. Error no: " << ret;
return false;
}
if (!AdjustIndices(indices)) {
return false;
}
ReductionT reduction_func;
return Scatter(reduction_func, output, updates);
}
// Static dispatch table: maps the public kernel name to the list of supported
// (KernelAttr, launch function) pairs. ScatterElements overwrites the target
// element (ReductionAssignment) while ScatterAddWithAxis accumulates into it
// (ReductionAdd). Each op supports int32 and int64 indices combined with the
// int8/uint8/int32/int64/float16/float32/float64 data types.
std::map<std::string, std::vector<std::pair<KernelAttr, ScatterElementsCpuKernelMod::ScatterElementsLaunchFunc>>>
  ScatterElementsCpuKernelMod::func_map_ = {
    {kScatterElements,
     {{KernelAttr()
         .AddInputAttr(kNumberTypeInt8)
         .AddInputAttr(kNumberTypeInt32)
         .AddInputAttr(kNumberTypeInt8)
         .AddOutputAttr(kNumberTypeInt8),
       &ScatterElementsCpuKernelMod::LaunchKernel<int8_t, int32_t, ReductionAssignment<int8_t>>},
      {KernelAttr()
         .AddInputAttr(kNumberTypeUInt8)
         .AddInputAttr(kNumberTypeInt32)
         .AddInputAttr(kNumberTypeUInt8)
         .AddOutputAttr(kNumberTypeUInt8),
       &ScatterElementsCpuKernelMod::LaunchKernel<uint8_t, int32_t, ReductionAssignment<uint8_t>>},
      {KernelAttr()
         .AddInputAttr(kNumberTypeInt32)
         .AddInputAttr(kNumberTypeInt32)
         .AddInputAttr(kNumberTypeInt32)
         .AddOutputAttr(kNumberTypeInt32),
       &ScatterElementsCpuKernelMod::LaunchKernel<int32_t, int32_t, ReductionAssignment<int32_t>>},
      {KernelAttr()
         .AddInputAttr(kNumberTypeFloat16)
         .AddInputAttr(kNumberTypeInt32)
         .AddInputAttr(kNumberTypeFloat16)
         .AddOutputAttr(kNumberTypeFloat16),
       &ScatterElementsCpuKernelMod::LaunchKernel<float16, int32_t, ReductionAssignment<float16>>},
      {KernelAttr()
         .AddInputAttr(kNumberTypeFloat32)
         .AddInputAttr(kNumberTypeInt32)
         .AddInputAttr(kNumberTypeFloat32)
         .AddOutputAttr(kNumberTypeFloat32),
       &ScatterElementsCpuKernelMod::LaunchKernel<float, int32_t, ReductionAssignment<float>>},
      {KernelAttr()
         .AddInputAttr(kNumberTypeFloat64)
         .AddInputAttr(kNumberTypeInt32)
         .AddInputAttr(kNumberTypeFloat64)
         .AddOutputAttr(kNumberTypeFloat64),
       &ScatterElementsCpuKernelMod::LaunchKernel<double, int32_t, ReductionAssignment<double>>},
      {KernelAttr()
         .AddInputAttr(kNumberTypeInt64)
         .AddInputAttr(kNumberTypeInt32)
         .AddInputAttr(kNumberTypeInt64)
         .AddOutputAttr(kNumberTypeInt64),
       &ScatterElementsCpuKernelMod::LaunchKernel<int64_t, int32_t, ReductionAssignment<int64_t>>},
      // int64 indices variants below.
      {KernelAttr()
         .AddInputAttr(kNumberTypeInt8)
         .AddInputAttr(kNumberTypeInt64)
         .AddInputAttr(kNumberTypeInt8)
         .AddOutputAttr(kNumberTypeInt8),
       &ScatterElementsCpuKernelMod::LaunchKernel<int8_t, int64_t, ReductionAssignment<int8_t>>},
      {KernelAttr()
         .AddInputAttr(kNumberTypeUInt8)
         .AddInputAttr(kNumberTypeInt64)
         .AddInputAttr(kNumberTypeUInt8)
         .AddOutputAttr(kNumberTypeUInt8),
       &ScatterElementsCpuKernelMod::LaunchKernel<uint8_t, int64_t, ReductionAssignment<uint8_t>>},
      {KernelAttr()
         .AddInputAttr(kNumberTypeInt32)
         .AddInputAttr(kNumberTypeInt64)
         .AddInputAttr(kNumberTypeInt32)
         .AddOutputAttr(kNumberTypeInt32),
       &ScatterElementsCpuKernelMod::LaunchKernel<int32_t, int64_t, ReductionAssignment<int32_t>>},
      {KernelAttr()
         .AddInputAttr(kNumberTypeFloat16)
         .AddInputAttr(kNumberTypeInt64)
         .AddInputAttr(kNumberTypeFloat16)
         .AddOutputAttr(kNumberTypeFloat16),
       &ScatterElementsCpuKernelMod::LaunchKernel<float16, int64_t, ReductionAssignment<float16>>},
      {KernelAttr()
         .AddInputAttr(kNumberTypeFloat32)
         .AddInputAttr(kNumberTypeInt64)
         .AddInputAttr(kNumberTypeFloat32)
         .AddOutputAttr(kNumberTypeFloat32),
       &ScatterElementsCpuKernelMod::LaunchKernel<float, int64_t, ReductionAssignment<float>>},
      {KernelAttr()
         .AddInputAttr(kNumberTypeFloat64)
         .AddInputAttr(kNumberTypeInt64)
         .AddInputAttr(kNumberTypeFloat64)
         .AddOutputAttr(kNumberTypeFloat64),
       &ScatterElementsCpuKernelMod::LaunchKernel<double, int64_t, ReductionAssignment<double>>},
      {KernelAttr()
         .AddInputAttr(kNumberTypeInt64)
         .AddInputAttr(kNumberTypeInt64)
         .AddInputAttr(kNumberTypeInt64)
         .AddOutputAttr(kNumberTypeInt64),
       &ScatterElementsCpuKernelMod::LaunchKernel<int64_t, int64_t, ReductionAssignment<int64_t>>}}},
    {kScatterAddWithAxis,
     {{KernelAttr()
         .AddInputAttr(kNumberTypeInt8)
         .AddInputAttr(kNumberTypeInt32)
         .AddInputAttr(kNumberTypeInt8)
         .AddOutputAttr(kNumberTypeInt8),
       &ScatterElementsCpuKernelMod::LaunchKernel<int8_t, int32_t, ReductionAdd<int8_t>>},
      {KernelAttr()
         .AddInputAttr(kNumberTypeUInt8)
         .AddInputAttr(kNumberTypeInt32)
         .AddInputAttr(kNumberTypeUInt8)
         .AddOutputAttr(kNumberTypeUInt8),
       &ScatterElementsCpuKernelMod::LaunchKernel<uint8_t, int32_t, ReductionAdd<uint8_t>>},
      {KernelAttr()
         .AddInputAttr(kNumberTypeInt32)
         .AddInputAttr(kNumberTypeInt32)
         .AddInputAttr(kNumberTypeInt32)
         .AddOutputAttr(kNumberTypeInt32),
       &ScatterElementsCpuKernelMod::LaunchKernel<int32_t, int32_t, ReductionAdd<int32_t>>},
      {KernelAttr()
         .AddInputAttr(kNumberTypeFloat16)
         .AddInputAttr(kNumberTypeInt32)
         .AddInputAttr(kNumberTypeFloat16)
         .AddOutputAttr(kNumberTypeFloat16),
       &ScatterElementsCpuKernelMod::LaunchKernel<float16, int32_t, ReductionAdd<float16>>},
      {KernelAttr()
         .AddInputAttr(kNumberTypeFloat32)
         .AddInputAttr(kNumberTypeInt32)
         .AddInputAttr(kNumberTypeFloat32)
         .AddOutputAttr(kNumberTypeFloat32),
       &ScatterElementsCpuKernelMod::LaunchKernel<float, int32_t, ReductionAdd<float>>},
      {KernelAttr()
         .AddInputAttr(kNumberTypeFloat64)
         .AddInputAttr(kNumberTypeInt32)
         .AddInputAttr(kNumberTypeFloat64)
         .AddOutputAttr(kNumberTypeFloat64),
       &ScatterElementsCpuKernelMod::LaunchKernel<double, int32_t, ReductionAdd<double>>},
      {KernelAttr()
         .AddInputAttr(kNumberTypeInt64)
         .AddInputAttr(kNumberTypeInt32)
         .AddInputAttr(kNumberTypeInt64)
         .AddOutputAttr(kNumberTypeInt64),
       &ScatterElementsCpuKernelMod::LaunchKernel<int64_t, int32_t, ReductionAdd<int64_t>>},
      // int64 indices variants below.
      {KernelAttr()
         .AddInputAttr(kNumberTypeInt8)
         .AddInputAttr(kNumberTypeInt64)
         .AddInputAttr(kNumberTypeInt8)
         .AddOutputAttr(kNumberTypeInt8),
       &ScatterElementsCpuKernelMod::LaunchKernel<int8_t, int64_t, ReductionAdd<int8_t>>},
      {KernelAttr()
         .AddInputAttr(kNumberTypeUInt8)
         .AddInputAttr(kNumberTypeInt64)
         .AddInputAttr(kNumberTypeUInt8)
         .AddOutputAttr(kNumberTypeUInt8),
       &ScatterElementsCpuKernelMod::LaunchKernel<uint8_t, int64_t, ReductionAdd<uint8_t>>},
      {KernelAttr()
         .AddInputAttr(kNumberTypeInt32)
         .AddInputAttr(kNumberTypeInt64)
         .AddInputAttr(kNumberTypeInt32)
         .AddOutputAttr(kNumberTypeInt32),
       &ScatterElementsCpuKernelMod::LaunchKernel<int32_t, int64_t, ReductionAdd<int32_t>>},
      {KernelAttr()
         .AddInputAttr(kNumberTypeFloat16)
         .AddInputAttr(kNumberTypeInt64)
         .AddInputAttr(kNumberTypeFloat16)
         .AddOutputAttr(kNumberTypeFloat16),
       &ScatterElementsCpuKernelMod::LaunchKernel<float16, int64_t, ReductionAdd<float16>>},
      {KernelAttr()
         .AddInputAttr(kNumberTypeFloat32)
         .AddInputAttr(kNumberTypeInt64)
         .AddInputAttr(kNumberTypeFloat32)
         .AddOutputAttr(kNumberTypeFloat32),
       &ScatterElementsCpuKernelMod::LaunchKernel<float, int64_t, ReductionAdd<float>>},
      {KernelAttr()
         .AddInputAttr(kNumberTypeFloat64)
         .AddInputAttr(kNumberTypeInt64)
         .AddInputAttr(kNumberTypeFloat64)
         .AddOutputAttr(kNumberTypeFloat64),
       &ScatterElementsCpuKernelMod::LaunchKernel<double, int64_t, ReductionAdd<double>>},
      {KernelAttr()
         .AddInputAttr(kNumberTypeInt64)
         .AddInputAttr(kNumberTypeInt64)
         .AddInputAttr(kNumberTypeInt64)
         .AddOutputAttr(kNumberTypeInt64),
       &ScatterElementsCpuKernelMod::LaunchKernel<int64_t, int64_t, ReductionAdd<int64_t>>}}}};
// One-time initialization: checks the op name against the registered kernel
// type and binds the typed launch function that matches the tensor dtypes.
bool ScatterElementsCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                       const std::vector<KernelTensorPtr> &outputs) {
  MS_ERROR_IF_NULL_W_RET_VAL(base_operator, false);
  kernel_name_ = base_operator->name();
  if (kernel_name_ != kernel_type_) {
    MS_LOG(ERROR) << "Need to be " << kernel_type_ << " but got kernel name as " << kernel_name_;
    return false;
  }
  auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
  auto pair = MatchKernelAttr(kernel_attr, GetOpSupport());
  if (!pair.first) {
    MS_LOG(ERROR) << "'" << kernel_name_ << "' does not support this kernel data type: " << kernel_attr;
    return false;
  }
  // Use at() instead of operator[]: an unexpected name must never insert an
  // empty entry into the shared static func_map_ (kernel_name_ == kernel_type_
  // is already guaranteed above, and GetOpSupport() throws if it is unknown).
  kernel_func_ = func_map_.at(kernel_name_)[pair.second].second;
  return true;
}
std::vector<KernelAttr> ScatterElementsCpuKernelMod::GetOpSupport() {
auto iter = func_map_.find(kernel_type_);
if (iter == func_map_.end()) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "' cpu does not support " << kernel_type_;
}
std::vector<KernelAttr> support_list;
(void)std::transform(iter->second.begin(), iter->second.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, ScatterElementsLaunchFunc> &pair) { return pair.first; });
return support_list;
}
// Register one factory creator per public op name. Both ops share this kernel
// mod; the constructor argument selects the reduction in func_map_.
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, ScatterElements,
                                 []() { return std::make_shared<ScatterElementsCpuKernelMod>(kScatterElements); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, ScatterAddWithAxis,
                                 []() { return std::make_shared<ScatterElementsCpuKernelMod>(kScatterAddWithAxis); });
} // namespace mindspore::kernel

View File

@ -0,0 +1,88 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SCATTER_ELEMENTS_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SCATTER_ELEMENTS_CPU_KERNEL_H_
#include <vector>
#include <string>
#include <memory>
#include <map>
#include <utility>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore::kernel {
constexpr auto kUnKnown = "UnKnown";
constexpr auto kScatterElements = "ScatterElements";
constexpr auto kScatterAddWithAxis = "ScatterAddWithAxis";
// CPU kernel implementing both ScatterElements (overwrite semantics) and
// ScatterAddWithAxis (accumulate semantics). The reduction is selected via the
// kernel_type passed at construction; the typed launch function is resolved in
// Init() from a static dispatch table.
class ScatterElementsCpuKernelMod : public NativeCpuKernelMod {
 public:
  ScatterElementsCpuKernelMod() = default;
  // kernel_type is kScatterElements or kScatterAddWithAxis.
  explicit ScatterElementsCpuKernelMod(const std::string &kernel_type) : kernel_type_(kernel_type) {}
  ~ScatterElementsCpuKernelMod() override = default;
  // Validates the op name and binds the dtype-specific launch function.
  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;
  // Re-validates shapes and recomputes strides for the current input shapes.
  bool Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
              const std::vector<KernelTensorPtr> &outputs,
              const std::map<uint32_t, tensor::TensorPtr> &others = std::map<uint32_t, tensor::TensorPtr>()) override;
  // Dispatches to the launch function chosen in Init().
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override {
    return kernel_func_(this, inputs, outputs);
  }

 protected:
  std::vector<KernelAttr> GetOpSupport() override;
  // Copies input to output, then scatters updates with reduction ReductionT.
  template <typename T, typename S, typename ReductionT>
  bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
  // Applies the reduction for every index/update pair.
  template <typename T, typename ReductionT>
  bool Scatter(const ReductionT &reduction_func, T *output, const T *updates);
  // Wraps negative indices; fails on out-of-range indices.
  template <typename S>
  bool AdjustIndices(S *in_indices);
  // Flat output offset for the current cursor with `index` on the axis.
  size_t ComputeOutoutOffset(const int64_t &index);
  // Advances the row-major cursor over indices_shape_ by one position.
  void UpdateOutputDimIndex();

 private:
  using ScatterElementsLaunchFunc = std::function<bool(
    ScatterElementsCpuKernelMod *, const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
  // kernel name -> supported (attr, launch) pairs, shared by both ops.
  static std::map<std::string, std::vector<std::pair<KernelAttr, ScatterElementsLaunchFunc>>> func_map_;
  ScatterElementsLaunchFunc kernel_func_;
  std::string kernel_type_{kUnKnown};
  int input_axis_size_{0};              // extent of input_x along axis_
  size_t input_size_{1};                // total element count of input_x
  size_t indices_total_num_{1};         // total element count of indices/updates
  size_t input_dims_{0};                // rank of input_x
  int64_t axis_{0};                     // normalized scatter axis, in [0, input_dims_)
  std::vector<int64_t> indices_shape_{};
  std::vector<size_t> output_dim_stride_{};  // row-major strides of input_x/output
  std::vector<size_t> output_dim_index_{};   // multi-dimensional cursor over indices
  std::vector<int64_t> adjusted_indices_{};  // indices normalized into [0, axis size)
  std::string kernel_name_;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SCATTER_ELEMENTS_CPU_KERNEL_H_

View File

@ -355,6 +355,7 @@ GVAR_DEF(PrimitivePtr, kPrimScatterNdMin, std::make_shared<Primitive>("ScatterNd
GVAR_DEF(PrimitivePtr, kPrimScatterNdMul, std::make_shared<Primitive>("ScatterNdMul"));
GVAR_DEF(PrimitivePtr, kPrimScatterUpdate, std::make_shared<Primitive>("ScatterUpdate"));
GVAR_DEF(PrimitivePtr, kPrimScatterElements, std::make_shared<Primitive>("ScatterElements"));
GVAR_DEF(PrimitivePtr, kPrimScatterAddWithAxis, std::make_shared<Primitive>("ScatterAddWithAxis"));
GVAR_DEF(PrimitivePtr, kPrimTensorScatterUpdate, std::make_shared<Primitive>("TensorScatterUpdate"));
GVAR_DEF(PrimitivePtr, kPrimTensorScatterAdd, std::make_shared<Primitive>("TensorScatterAdd"));
GVAR_DEF(PrimitivePtr, kPrimTensorScatterSub, std::make_shared<Primitive>("TensorScatterSub"));

View File

@ -0,0 +1,87 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/scatter_add_with_axis.h"
#include <map>
#include <set>
#include <string>
#include "abstract/ops/primitive_infer_map.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
// Infers the output shape of ScatterAddWithAxis: the output always has the
// shape of 'input_x'. Validates that indices/updates shapes match and that all
// three inputs have rank >= 1; dynamic shapes skip validation.
abstract::ShapePtr ScatterAddWithAxisInferShape(const PrimitivePtr &primitive,
                                                const std::vector<AbstractBasePtr> &input_args) {
  auto prim_name = primitive->name();
  auto input_x_shape_ptr = input_args[kInputIndex0]->BuildShape();
  MS_EXCEPTION_IF_NULL(input_x_shape_ptr);
  auto indices_shape_ptr = input_args[kInputIndex1]->BuildShape();
  MS_EXCEPTION_IF_NULL(indices_shape_ptr);
  auto updates_shape_ptr = input_args[kInputIndex2]->BuildShape();
  MS_EXCEPTION_IF_NULL(updates_shape_ptr);
  // Dynamic shapes cannot be validated at compile time; defer to runtime.
  if (input_x_shape_ptr->IsDynamic() || indices_shape_ptr->IsDynamic() || updates_shape_ptr->IsDynamic()) {
    return input_args[kInputIndex0]->BuildShape()->cast<abstract::ShapePtr>();
  }
  auto input_x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_x_shape_ptr)[kShape];
  auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(indices_shape_ptr)[kShape];
  auto updates_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(updates_shape_ptr)[kShape];
  if (input_x_shape.size() < 1 || indices_shape.size() < 1 || updates_shape.size() < 1) {
    // Message fixed: the check enforces rank >= 1, not rank > 1.
    MS_EXCEPTION(ValueError) << "For " << prim_name << ", 'input_x_shape', 'indices_shape' and "
                             << "'updates_shape' dims should be greater than or equal to 1. but got input_x_shape:"
                             << input_x_shape << ", indices_shape:" << indices_shape
                             << ", updates_shape: " << updates_shape << ".";
  }
  if (updates_shape != indices_shape) {
    MS_EXCEPTION(ValueError) << "For " << prim_name << ", "
                             << "'updates_shape' should be as same as 'indices_shape' but got indices_shape: "
                             << indices_shape << ", updates_shape: " << updates_shape << ".";
  }
  return input_args[kInputIndex0]->BuildShape()->cast<abstract::ShapePtr>();
}
// Infers the output dtype: indices must be int32/int64, and 'input_x' and
// 'updates' must share one of the common numeric types (bool included).
TypePtr ScatterAddWithAxisInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  auto prim_name = primitive->name();
  auto indices_type_ptr = input_args[kInputIndex1]->BuildType();
  std::set<TypePtr> valid_index_types = {kInt32, kInt64};
  (void)CheckAndConvertUtils::CheckTensorTypeValid("indices type", indices_type_ptr, valid_index_types, prim_name);
  std::map<std::string, TypePtr> type_dict = {{"input_x", input_args[kInputIndex0]->BuildType()},
                                              {"updates", input_args[kInputIndex2]->BuildType()}};
  std::set<TypePtr> check_list(common_valid_types);
  (void)check_list.insert(kBool);
  return CheckAndConvertUtils::CheckTensorTypeSame(type_dict, check_list, prim_name);
}
} // namespace
MIND_API_OPERATOR_IMPL(ScatterAddWithAxis, BaseOperator);
// Combined infer entry: validates argument count, then builds the output
// abstract from the inferred type and shape.
AbstractBasePtr ScatterAddWithAxisInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                        const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  constexpr int64_t kInputNum = 3;
  (void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputNum, primitive->name());
  auto out_type = ScatterAddWithAxisInferType(primitive, input_args);
  auto out_shape = ScatterAddWithAxisInferShape(primitive, input_args);
  return abstract::MakeAbstract(out_shape, out_type);
}
// Initialize the primitive by storing the 'axis' attribute.
void ScatterAddWithAxis::Init(const int64_t axis) { this->set_axis(axis); }
// Store 'axis' as a primitive attribute.
void ScatterAddWithAxis::set_axis(const int64_t axis) { (void)AddAttr(kAxis, api::MakeValue(axis)); }
// Read the 'axis' attribute back from the primitive.
int64_t ScatterAddWithAxis::get_axis() const { return GetValue<int64_t>(GetAttr(kAxis)); }
REGISTER_PRIMITIVE_EVAL_IMPL(ScatterAddWithAxis, prim::kPrimScatterAddWithAxis, ScatterAddWithAxisInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,52 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_SCATTER_ADD_WITH_AXIS_H_
#define MINDSPORE_CORE_OPS_SCATTER_ADD_WITH_AXIS_H_
#include <vector>
#include <memory>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameScatterAddWithAxis = "ScatterAddWithAxis";
/// \brief Updates tensor values by using input indices and value.
/// Refer to Python API @ref mindspore.ops.ScatterAddWithAxis for more details.
class MIND_API ScatterAddWithAxis : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(ScatterAddWithAxis);
  /// \brief Constructor.
  ScatterAddWithAxis() : BaseOperator(kNameScatterAddWithAxis) {
    InitIOName({"input_x", "indices", "update"}, {"output"});
  }
  /// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.ScatterAddWithAxis for the inputs.
  ///
  /// \param[in] axis The axis along which updates are accumulated. Default: 0.
  void Init(const int64_t axis = 0);
  /// \brief Set axis.
  ///
  /// \param[in] axis The value stored in the primitive's 'axis' attribute.
  void set_axis(const int64_t axis);
  /// \brief Get axis.
  ///
  /// \return axis.
  int64_t get_axis() const;
};
abstract::AbstractBasePtr ScatterAddWithAxisInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using kPrimScatterAddWithAxisPtr = std::shared_ptr<ScatterAddWithAxis>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_SCATTER_ADD_WITH_AXIS_H_

View File

@ -0,0 +1,88 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/scatter_elements.h"
#include <map>
#include <set>
#include <string>
#include "abstract/ops/primitive_infer_map.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
// Infers the output shape of ScatterElements: the output always has the shape
// of 'input_x'. Validates that indices/updates shapes match and that all three
// inputs have rank >= 1; dynamic shapes skip validation.
abstract::ShapePtr ScatterElementsInferShape(const PrimitivePtr &primitive,
                                             const std::vector<AbstractBasePtr> &input_args) {
  auto prim_name = primitive->name();
  auto input_x_shape_ptr = input_args[kInputIndex0]->BuildShape();
  MS_EXCEPTION_IF_NULL(input_x_shape_ptr);
  auto indices_shape_ptr = input_args[kInputIndex1]->BuildShape();
  MS_EXCEPTION_IF_NULL(indices_shape_ptr);
  auto updates_shape_ptr = input_args[kInputIndex2]->BuildShape();
  MS_EXCEPTION_IF_NULL(updates_shape_ptr);
  // Dynamic shapes cannot be validated at compile time; defer to runtime.
  if (input_x_shape_ptr->IsDynamic() || indices_shape_ptr->IsDynamic() || updates_shape_ptr->IsDynamic()) {
    return input_args[kInputIndex0]->BuildShape()->cast<abstract::ShapePtr>();
  }
  auto input_x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_x_shape_ptr)[kShape];
  auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(indices_shape_ptr)[kShape];
  auto updates_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(updates_shape_ptr)[kShape];
  if (input_x_shape.size() < 1 || indices_shape.size() < 1 || updates_shape.size() < 1) {
    // Message fixed: the check enforces rank >= 1, not rank > 1.
    MS_EXCEPTION(ValueError) << "For " << prim_name << ", 'input_x_shape', 'indices_shape' and "
                             << "'updates_shape' dims should be greater than or equal to 1. but got input_x_shape:"
                             << input_x_shape << ", indices_shape:" << indices_shape
                             << ", updates_shape: " << updates_shape << ".";
  }
  if (updates_shape != indices_shape) {
    MS_EXCEPTION(ValueError) << "For " << prim_name << ", "
                             << "'updates_shape' should be as same as 'indices_shape' but got indices_shape: "
                             << indices_shape << ", updates_shape: " << updates_shape << ".";
  }
  return input_args[kInputIndex0]->BuildShape()->cast<abstract::ShapePtr>();
}
// Infers the output dtype: indices must be int32/int64, and 'input_x' and
// 'updates' must share one of the common numeric types (bool included).
TypePtr ScatterElementsInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  auto prim_name = primitive->name();
  auto indices_type_ptr = input_args[kInputIndex1]->BuildType();
  std::set<TypePtr> valid_index_types = {kInt32, kInt64};
  (void)CheckAndConvertUtils::CheckTensorTypeValid("indices type", indices_type_ptr, valid_index_types, prim_name);
  std::map<std::string, TypePtr> type_dict = {{"input_x", input_args[kInputIndex0]->BuildType()},
                                              {"updates", input_args[kInputIndex2]->BuildType()}};
  std::set<TypePtr> check_list(common_valid_types);
  (void)check_list.insert(kBool);
  return CheckAndConvertUtils::CheckTensorTypeSame(type_dict, check_list, prim_name);
}
} // namespace
MIND_API_OPERATOR_IMPL(ScatterElements, BaseOperator);
// Combined infer entry: validates argument count, then builds the output
// abstract from the inferred type and shape.
AbstractBasePtr ScatterElementsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                     const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  constexpr int64_t kInputNum = 3;
  (void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputNum, primitive->name());
  auto out_type = ScatterElementsInferType(primitive, input_args);
  auto out_shape = ScatterElementsInferShape(primitive, input_args);
  return abstract::MakeAbstract(out_shape, out_type);
}
// Initialize the primitive by storing the 'axis' attribute.
void ScatterElements::Init(const int64_t axis) { this->set_axis(axis); }
// Store 'axis' as a primitive attribute.
void ScatterElements::set_axis(const int64_t axis) { (void)AddAttr(kAxis, api::MakeValue(axis)); }
// Read the 'axis' attribute back from the primitive.
int64_t ScatterElements::get_axis() const { return GetValue<int64_t>(GetAttr(kAxis)); }
REGISTER_PRIMITIVE_EVAL_IMPL(ScatterElements, prim::kPrimScatterElements, ScatterElementsInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,50 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_SCATTER_ELEMENTS_H_
#define MINDSPORE_CORE_OPS_SCATTER_ELEMENTS_H_
#include <vector>
#include <memory>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameScatterElements = "ScatterElements";
/// \brief Updates tensor values by using input indices and value.
/// Refer to Python API @ref mindspore.ops.ScatterElements for more details.
class MIND_API ScatterElements : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(ScatterElements);
  /// \brief Constructor.
  ScatterElements() : BaseOperator(kNameScatterElements) { InitIOName({"input_x", "indices", "update"}, {"output"}); }
  /// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.ScatterElements for the inputs.
  ///
  /// \param[in] axis The axis along which updates are scattered. Default: 0.
  void Init(const int64_t axis = 0);
  /// \brief Set axis.
  ///
  /// \param[in] axis The value stored in the primitive's 'axis' attribute.
  void set_axis(const int64_t axis);
  /// \brief Get axis.
  ///
  /// \return axis.
  int64_t get_axis() const;
};
abstract::AbstractBasePtr ScatterElementsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using kPrimScatterElementsPtr = std::shared_ptr<ScatterElements>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_SCATTER_ELEMENTS_H_

View File

@ -47,7 +47,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unsta
EmbeddingLookup, Unique, GatherD, Identity, Range, MaskedFill, MaskedSelect, SearchSorted,
TensorScatterUpdate, TensorScatterMax, TensorScatterMin, TensorScatterAdd, TensorScatterSub,
TensorScatterMul, TensorScatterDiv,
ScatterElements, ExtractVolumePatches, LowerBound, UpperBound, Cummax)
ScatterElements, ScatterAddWithAxis, ExtractVolumePatches, LowerBound, UpperBound, Cummax)
from .comm_ops import (AllGather, AllReduce, NeighborExchange, NeighborExchangeV2, AlltoAll, _AllSwap, ReduceScatter,
Broadcast,
_MirrorOperator, _MirrorMiniStepOperator, _MiniStepAllGather, ReduceOp, _VirtualDataset,
@ -503,6 +503,7 @@ __all__ = [
"TensorScatterMul",
"TensorScatterDiv",
"ScatterElements",
"ScatterAddWithAxis",
"NonZero",
"SoftShrink",
"FFT3D",

View File

@ -7233,6 +7233,58 @@ class ScatterElements(Primitive):
self.init_prim_io_names(inputs=['data', 'indices', 'updates'], outputs=['y'])
class ScatterAddWithAxis(Primitive):
    """
    ScatterAddWithAxis takes three inputs `data`, `updates`, and `indices` of the same rank r >= 1
    and an optional attribute axis that identifies an axis of `data` (default is 0).

    The output of the operation is produced by creating a copy of the input `data`, and then
    adding the values specified by `updates` at the index positions specified by `indices`.

    Args:
        axis (int): which axis to scatter, default is 0.

    Inputs:
        - **data** (Tensor) - The target tensor.
        - **indices** (Tensor) - The index of input tensor whose data type is int32 or int64.
        - **update** (Tensor) - The tensor to update the input tensor, has the same type as input,
          and update.shape should be equal to indices.shape.

    Outputs:
        Tensor, has the same shape and type as `data`.

    Raises:
        TypeError: If dtype of `indices` is neither int32 nor int64.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> op = ops.ScatterAddWithAxis(0)
        >>> data = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
        >>> indices = Tensor(np.array([[1, 0, 2], [0, 2, 1]]), mindspore.int32)
        >>> updates = Tensor(np.array([[1, 1, 1], [1, 1, 1]]), mindspore.float32)
        >>> output = op(data, indices, updates)
        >>> print(output)
        [[ 2.0 3.0 3.0]
         [ 5.0 5.0 7.0]
         [ 7.0 9.0 10.0]]
        >>> op = ops.ScatterAddWithAxis(1)
        >>> data = Tensor(np.array([[1, 2, 3, 4, 5]]), mindspore.int32)
        >>> indices = Tensor(np.array([[2, 4]]), mindspore.int32)
        >>> updates = Tensor(np.array([[8, 8]]), mindspore.int32)
        >>> output = op(data, indices, updates)
        >>> print(output)
        [[ 1 2 11 4 13]]
    """

    @prim_attr_register
    def __init__(self, axis=0):
        """Initialize ScatterAddWithAxis"""
        # Reject non-int axis values up front, before graph compilation.
        validator.check_value_type("axis", axis, [int], self.name)
        self.init_prim_io_names(
            inputs=['data', 'indices', 'updates'], outputs=['y'])
class ExtractVolumePatches(Primitive):
"""
Extract patches from input and put them in the "depth" output dimension. 3D extension of extract_image_patches.

View File

@ -0,0 +1,86 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import Parameter
from mindspore.ops import operations as P
# Run in graph mode on the CPU backend so the new CPU kernel is exercised.
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
def scatter_add_with_axis(input_x, indices, updates, axis):
    """NumPy reference implementation of ScatterAddWithAxis for 2-D inputs.

    For each position (i, j) of `indices`:
        axis == 0:  out[indices[i, j], j] += updates[i, j]
        axis == 1:  out[i, indices[i, j]] += updates[i, j]
    A negative `axis` and negative index values wrap around, mirroring the
    operator's semantics.

    Args:
        input_x: tensor-like object exposing asnumpy(); the data to update.
        indices: tensor-like object exposing asnumpy(); 2-D index tensor.
        updates: tensor-like object exposing asnumpy(); same shape as indices.
        axis: int, axis along which to scatter (may be negative).

    Returns:
        A NumPy array with the same shape as `input_x` holding the result.
    """
    out = input_x.asnumpy().copy()
    idx = indices.asnumpy().copy()
    upd = updates.asnumpy().copy()
    # Normalize a negative axis to its positive equivalent.
    norm_axis = axis if axis >= 0 else axis + len(out.shape)
    rows = idx.shape[0]
    cols = idx.shape[1]
    for r in range(rows):
        for c in range(cols):
            k = idx[r][c]
            # Negative indices count from the end of the scatter axis.
            if k < 0:
                k += out.shape[norm_axis]
            if norm_axis == 0:
                out[k][c] += upd[r][c]
            if norm_axis == 1:
                out[r][k] += upd[r][c]
    return out
class TestScatterAddWithAxis(nn.Cell):
    """Cell wrapping P.ScatterAddWithAxis so it can run as a graph-mode network."""

    def __init__(self, input_x, indices, updates, axis):
        super(TestScatterAddWithAxis, self).__init__()
        self.axis = axis
        # Inputs are held as Parameters of the cell so construct() takes no arguments.
        # NOTE(review): Parameters give the op mutable storage — presumably required
        # because the scatter updates in place; confirm against the kernel.
        self.input_x = Parameter(input_x, name="input_x")
        self.indices = Parameter(indices, name="indices")
        self.updates = Parameter(updates, name="updates")
        self.scatter_add_with_axis = P.ScatterAddWithAxis(self.axis)

    def construct(self):
        # Applies the scatter-add along self.axis and returns the updated tensor.
        return self.scatter_add_with_axis(self.input_x, self.indices, self.updates)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('dtype', [np.float32, np.float64, np.int32])
@pytest.mark.parametrize('index_dtype', [np.int32, np.int64])
@pytest.mark.parametrize('axis', [0, 1, -1])
def test_scatter_add_with_axis(dtype, index_dtype, axis):
    """
    Feature: Op ScatterAddWithAxis
    Description: Scatter-add update values into a copy of the input at the
        positions given by indices:
        output[indices[i][j]][j] += updates[i][j] if axis = 0,
        output[i][indices[i][j]] += updates[i][j] if axis = 1.
        Indices may be negative; axis may be negative.
    Expectation: Operator output matches the NumPy reference implementation.
    """
    data_np = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=dtype)
    idx_np = np.array([[1, -1, 2], [0, 2, 1]], dtype=index_dtype)
    upd_np = np.array([[1, 2, 2], [4, 5, 8]], dtype=dtype)
    data, idx, upd = Tensor(data_np), Tensor(idx_np), Tensor(upd_np)
    # Run the operator first, then compute the reference from the same tensors.
    ms_result = TestScatterAddWithAxis(data, idx, upd, axis)()
    expected = scatter_add_with_axis(data, idx, upd, axis)
    print("ms_output:\n", ms_result.asnumpy())
    assert np.allclose(ms_result.asnumpy(), expected)

View File

@ -0,0 +1,80 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import Parameter
from mindspore.ops import operations as P
# Run in graph mode on the CPU backend so the new CPU kernel is exercised.
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
def scatter_element_np(input_x, indices, updates, axis):
    """NumPy reference implementation of ScatterElements for 2-D inputs.

    For each position (i, j) of `indices`:
        axis == 0:  out[indices[i, j], j] = updates[i, j]
        axis == 1:  out[i, indices[i, j]] = updates[i, j]

    Args:
        input_x: tensor-like object exposing asnumpy(); the data to update.
        indices: tensor-like object exposing asnumpy(); 2-D index tensor.
        updates: tensor-like object exposing asnumpy(); same shape as indices.
        axis: int, 0 or 1, axis along which to scatter.

    Returns:
        A NumPy array with the same shape as `input_x` holding the result.
    """
    out = input_x.asnumpy().copy()
    idx = indices.asnumpy().copy()
    upd = updates.asnumpy().copy()
    rows = idx.shape[0]
    cols = idx.shape[1]
    for r in range(rows):
        for c in range(cols):
            if axis == 0:
                out[idx[r][c]][c] = upd[r][c]
            if axis == 1:
                out[r][idx[r][c]] = upd[r][c]
    return out
class TestScatterElements(nn.Cell):
    """Cell wrapping P.ScatterElements so it can run as a graph-mode network."""

    def __init__(self, input_x, indices, updates, axis):
        super(TestScatterElements, self).__init__()
        self.axis = axis
        # Inputs are held as Parameters of the cell so construct() takes no arguments.
        self.input_x = Parameter(input_x, name="input_x")
        self.indices = Parameter(indices, name="indices")
        self.updates = Parameter(updates, name="updates")
        self.scatter_elements = P.ScatterElements(self.axis)

    def construct(self):
        # Scatters self.updates into self.input_x along self.axis at self.indices.
        return self.scatter_elements(self.input_x, self.indices, self.updates)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('dtype', [np.float32, np.float64, np.int32])
@pytest.mark.parametrize('index_dtype', [np.int32])
@pytest.mark.parametrize('axis', [0, 1])
def test_scatter_elements(dtype, index_dtype, axis):
    """
    Feature: Op ScatterElements
    Description: Scatter update values into a copy of the input at the
        positions given by indices:
        output[indices[i][j]][j] = updates[i][j] if axis = 0,
        output[i][indices[i][j]] = updates[i][j] if axis = 1.
    Expectation: Operator output matches the NumPy reference implementation.
    """
    data_np = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=dtype)
    idx_np = np.array([[1, 0, 2], [0, 2, 1]], dtype=index_dtype)
    upd_np = np.array([[0, 0, 0], [0, 0, 0]], dtype=dtype)
    data, idx, upd = Tensor(data_np), Tensor(idx_np), Tensor(upd_np)
    # Run the operator first, then compute the reference from the same tensors.
    ms_result = TestScatterElements(data, idx, upd, axis)()
    expected = scatter_element_np(data, idx, upd, axis)
    print("ms_output:\n", ms_result.asnumpy())
    assert np.allclose(ms_result.asnumpy(), expected)