!33634 [MS][LITE]add scatter_elements and scatter_add_with_axis cpu kernel
Merge pull request !33634 from mengyuanli/scatter_element
This commit is contained in:
commit
cb1dc9f402
|
@ -0,0 +1,397 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/cpu/kernel/scatter_elements_cpu_kernel.h"
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
#include <functional>
|
||||
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
constexpr size_t kScatterElementsInputsNum = 3;
|
||||
constexpr size_t kScatterElementsOutputsNum = 1;
|
||||
|
||||
namespace {
|
||||
template <class T>
|
||||
struct ReductionAdd {
|
||||
void operator()(T *a, const T &b) const { (*a) += b; }
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct ReductionAssignment {
|
||||
void operator()(T *a, const T &b) const { (*a) = b; }
|
||||
};
|
||||
} // namespace
|
||||
|
||||
bool ScatterElementsCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
|
||||
const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &others) {
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(base_operator, false);
|
||||
if (!NativeCpuKernelMod::Resize(base_operator, inputs, outputs, others)) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', resize failed.";
|
||||
return false;
|
||||
}
|
||||
kernel_name_ = base_operator->name();
|
||||
auto input_shape = inputs[kIndex0]->GetShapeVector();
|
||||
input_dims_ = input_shape.size();
|
||||
if (input_dims_ < 1) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_
|
||||
<< "', the dimension of 'input_x' should be greater than or equal to 1, but got " << input_dims_
|
||||
<< ".";
|
||||
return false;
|
||||
}
|
||||
indices_shape_ = inputs[kIndex1]->GetShapeVector();
|
||||
auto update_shape = inputs[kIndex2]->GetShapeVector();
|
||||
if (indices_shape_ != update_shape) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_
|
||||
<< "', the shape of 'indice' and the shape of 'update' should be same, but got "
|
||||
<< "indice shape: " << indices_shape_ << "; "
|
||||
<< "update shape: " << update_shape << ".";
|
||||
return false;
|
||||
}
|
||||
if (input_dims_ != indices_shape_.size()) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_
|
||||
<< "', the dimension of 'input_x', 'indice' and 'update' should be same, but got "
|
||||
<< "input_x dims: " << input_dims_ << "; "
|
||||
<< "indice dims: " << indices_shape_.size() << "; "
|
||||
<< "update dims: " << update_shape.size() << ".";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (base_operator->HasAttr(kAttrAxis)) {
|
||||
axis_ = GetValue<int64_t>(base_operator->GetAttr(kAttrAxis));
|
||||
if (axis_ < 0) {
|
||||
axis_ += input_dims_;
|
||||
}
|
||||
}
|
||||
|
||||
if (axis_ >= static_cast<int64_t>(input_dims_) || axis_ < 0) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_
|
||||
<< "', the 'axis' should be less than input dims and greater than or equal 0, but got " << axis_
|
||||
<< ", while input dims is: " << input_dims_;
|
||||
return false;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < input_dims_; ++i) {
|
||||
if (axis_ != static_cast<int64_t>(i) && input_shape[i] < indices_shape_[i]) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_
|
||||
<< "', the indices dims should be less than input dims, but got indice dim is: "
|
||||
<< indices_shape_[i] << " at axis: " << i << ", while input dim is:" << input_shape[i];
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
input_axis_size_ = SizeToInt(input_shape[axis_]);
|
||||
input_size_ = std::accumulate(input_shape.begin(), input_shape.end(), size_t(1), std::multiplies<size_t>());
|
||||
indices_total_num_ =
|
||||
std::accumulate(indices_shape_.begin(), indices_shape_.end(), size_t(1), std::multiplies<size_t>());
|
||||
adjusted_indices_.resize(indices_total_num_);
|
||||
|
||||
output_dim_stride_.resize(input_dims_);
|
||||
output_dim_stride_.back() = 1;
|
||||
for (int i = static_cast<int>(input_dims_ - 2); i >= 0; --i) {
|
||||
output_dim_stride_[i] = input_shape[i + 1] * output_dim_stride_[i + 1];
|
||||
}
|
||||
output_dim_index_.resize(input_dims_);
|
||||
output_dim_index_.assign(input_dims_, 0);
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename S>
|
||||
bool ScatterElementsCpuKernelMod::AdjustIndices(S *in_indices) {
|
||||
for (size_t i = 0; i < indices_total_num_; i++) {
|
||||
auto index = in_indices[i];
|
||||
if (index >= input_axis_size_ || index < -input_axis_size_) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', index: " << index << " is expected to be within bounds ["
|
||||
<< -input_axis_size_ << ", " << input_axis_size_ << ")";
|
||||
return false;
|
||||
}
|
||||
if (index < 0) {
|
||||
index += input_axis_size_;
|
||||
}
|
||||
adjusted_indices_[i] = index;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
size_t ScatterElementsCpuKernelMod::ComputeOutoutOffset(const int64_t &index) {
|
||||
size_t output_offset = 0;
|
||||
for (size_t i = 0; i < input_dims_; ++i) {
|
||||
if (static_cast<int64_t>(i) == axis_) {
|
||||
output_offset += index * output_dim_stride_[i];
|
||||
} else {
|
||||
output_offset += output_dim_index_[i] * output_dim_stride_[i];
|
||||
}
|
||||
}
|
||||
return output_offset;
|
||||
}
|
||||
|
||||
void ScatterElementsCpuKernelMod::UpdateOutputDimIndex() {
|
||||
for (int i = static_cast<int>(input_dims_ - 1); i >= 0; --i) {
|
||||
auto cur = ++output_dim_index_[i];
|
||||
if (static_cast<int64_t>(cur) < indices_shape_[i]) {
|
||||
break;
|
||||
}
|
||||
output_dim_index_[i] = 0;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T, typename ReductionT>
|
||||
bool ScatterElementsCpuKernelMod::Scatter(const ReductionT &reduction_func, T *output, const T *updates) {
|
||||
for (size_t i = 0; i < indices_total_num_;) {
|
||||
auto index = adjusted_indices_[i];
|
||||
auto output_offset = ComputeOutoutOffset(index);
|
||||
reduction_func(output + output_offset, *(updates + i));
|
||||
if (++i == indices_total_num_) {
|
||||
break;
|
||||
}
|
||||
UpdateOutputDimIndex();
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T, typename S, typename ReductionT>
|
||||
bool ScatterElementsCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kScatterElementsInputsNum, kernel_name_);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kScatterElementsOutputsNum, kernel_name_);
|
||||
auto *input = reinterpret_cast<T *>(inputs[kIndex0]->addr);
|
||||
auto *indices = reinterpret_cast<S *>(inputs[kIndex1]->addr);
|
||||
auto *updates = reinterpret_cast<T *>(inputs[kIndex2]->addr);
|
||||
auto *output = reinterpret_cast<T *>(outputs[kIndex0]->addr);
|
||||
auto bufferSize = outputs[kIndex0]->size;
|
||||
auto ret = memcpy_s(output, bufferSize, input, input_size_ * sizeof(T));
|
||||
if (ret != EOK) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', memory copy failed. Error no: " << ret;
|
||||
return false;
|
||||
}
|
||||
if (!AdjustIndices(indices)) {
|
||||
return false;
|
||||
}
|
||||
ReductionT reduction_func;
|
||||
return Scatter(reduction_func, output, updates);
|
||||
}
|
||||
|
||||
std::map<std::string, std::vector<std::pair<KernelAttr, ScatterElementsCpuKernelMod::ScatterElementsLaunchFunc>>>
|
||||
ScatterElementsCpuKernelMod::func_map_ = {
|
||||
{kScatterElements,
|
||||
{{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt8)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt8)
|
||||
.AddOutputAttr(kNumberTypeInt8),
|
||||
&ScatterElementsCpuKernelMod::LaunchKernel<int8_t, int32_t, ReductionAssignment<int8_t>>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeUInt8)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeUInt8)
|
||||
.AddOutputAttr(kNumberTypeUInt8),
|
||||
&ScatterElementsCpuKernelMod::LaunchKernel<uint8_t, int32_t, ReductionAssignment<uint8_t>>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeInt32),
|
||||
&ScatterElementsCpuKernelMod::LaunchKernel<int32_t, int32_t, ReductionAssignment<int32_t>>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
&ScatterElementsCpuKernelMod::LaunchKernel<float16, int32_t, ReductionAssignment<float16>>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
&ScatterElementsCpuKernelMod::LaunchKernel<float, int32_t, ReductionAssignment<float>>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddOutputAttr(kNumberTypeFloat64),
|
||||
&ScatterElementsCpuKernelMod::LaunchKernel<double, int32_t, ReductionAssignment<double>>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeInt64),
|
||||
&ScatterElementsCpuKernelMod::LaunchKernel<int64_t, int32_t, ReductionAssignment<int64_t>>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt8)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt8)
|
||||
.AddOutputAttr(kNumberTypeInt8),
|
||||
&ScatterElementsCpuKernelMod::LaunchKernel<int8_t, int64_t, ReductionAssignment<int8_t>>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeUInt8)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeUInt8)
|
||||
.AddOutputAttr(kNumberTypeUInt8),
|
||||
&ScatterElementsCpuKernelMod::LaunchKernel<uint8_t, int64_t, ReductionAssignment<uint8_t>>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeInt32),
|
||||
&ScatterElementsCpuKernelMod::LaunchKernel<int32_t, int64_t, ReductionAssignment<int32_t>>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
&ScatterElementsCpuKernelMod::LaunchKernel<float16, int64_t, ReductionAssignment<float16>>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
&ScatterElementsCpuKernelMod::LaunchKernel<float, int64_t, ReductionAssignment<float>>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddOutputAttr(kNumberTypeFloat64),
|
||||
&ScatterElementsCpuKernelMod::LaunchKernel<double, int64_t, ReductionAssignment<double>>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeInt64),
|
||||
&ScatterElementsCpuKernelMod::LaunchKernel<int64_t, int64_t, ReductionAssignment<int64_t>>}}},
|
||||
{kScatterAddWithAxis,
|
||||
{{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt8)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt8)
|
||||
.AddOutputAttr(kNumberTypeInt8),
|
||||
&ScatterElementsCpuKernelMod::LaunchKernel<int8_t, int32_t, ReductionAdd<int8_t>>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeUInt8)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeUInt8)
|
||||
.AddOutputAttr(kNumberTypeUInt8),
|
||||
&ScatterElementsCpuKernelMod::LaunchKernel<uint8_t, int32_t, ReductionAdd<uint8_t>>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeInt32),
|
||||
&ScatterElementsCpuKernelMod::LaunchKernel<int32_t, int32_t, ReductionAdd<int32_t>>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
&ScatterElementsCpuKernelMod::LaunchKernel<float16, int32_t, ReductionAdd<float16>>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
&ScatterElementsCpuKernelMod::LaunchKernel<float, int32_t, ReductionAdd<float>>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddOutputAttr(kNumberTypeFloat64),
|
||||
&ScatterElementsCpuKernelMod::LaunchKernel<double, int32_t, ReductionAdd<double>>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeInt64),
|
||||
&ScatterElementsCpuKernelMod::LaunchKernel<int64_t, int32_t, ReductionAdd<int64_t>>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt8)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt8)
|
||||
.AddOutputAttr(kNumberTypeInt8),
|
||||
&ScatterElementsCpuKernelMod::LaunchKernel<int8_t, int64_t, ReductionAdd<int8_t>>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeUInt8)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeUInt8)
|
||||
.AddOutputAttr(kNumberTypeUInt8),
|
||||
&ScatterElementsCpuKernelMod::LaunchKernel<uint8_t, int64_t, ReductionAdd<uint8_t>>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeInt32),
|
||||
&ScatterElementsCpuKernelMod::LaunchKernel<int32_t, int64_t, ReductionAdd<int32_t>>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
&ScatterElementsCpuKernelMod::LaunchKernel<float16, int64_t, ReductionAdd<float16>>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
&ScatterElementsCpuKernelMod::LaunchKernel<float, int64_t, ReductionAdd<float>>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddOutputAttr(kNumberTypeFloat64),
|
||||
&ScatterElementsCpuKernelMod::LaunchKernel<double, int64_t, ReductionAdd<double>>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeInt64),
|
||||
&ScatterElementsCpuKernelMod::LaunchKernel<int64_t, int64_t, ReductionAdd<int64_t>>}}}};
|
||||
|
||||
bool ScatterElementsCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(base_operator, false);
|
||||
kernel_name_ = base_operator->name();
|
||||
if (kernel_name_ != kernel_type_) {
|
||||
MS_LOG(ERROR) << "Need to be " << kernel_type_ << " but got kernel name as " << kernel_name_;
|
||||
return false;
|
||||
}
|
||||
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
|
||||
auto pair = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!pair.first) {
|
||||
MS_LOG(ERROR) << "'" << kernel_name_ << "' does not support this kernel data type: " << kernel_attr;
|
||||
return false;
|
||||
}
|
||||
kernel_func_ = func_map_[kernel_name_][pair.second].second;
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<KernelAttr> ScatterElementsCpuKernelMod::GetOpSupport() {
|
||||
auto iter = func_map_.find(kernel_type_);
|
||||
if (iter == func_map_.end()) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "' cpu does not support " << kernel_type_;
|
||||
}
|
||||
|
||||
std::vector<KernelAttr> support_list;
|
||||
(void)std::transform(iter->second.begin(), iter->second.end(), std::back_inserter(support_list),
|
||||
[](const std::pair<KernelAttr, ScatterElementsLaunchFunc> &pair) { return pair.first; });
|
||||
return support_list;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, ScatterElements,
|
||||
[]() { return std::make_shared<ScatterElementsCpuKernelMod>(kScatterElements); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, ScatterAddWithAxis,
|
||||
[]() { return std::make_shared<ScatterElementsCpuKernelMod>(kScatterAddWithAxis); });
|
||||
} // namespace mindspore::kernel
|
|
@ -0,0 +1,88 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SCATTER_ELEMENTS_CPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SCATTER_ELEMENTS_CPU_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <utility>
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
constexpr auto kUnKnown = "UnKnown";
|
||||
constexpr auto kScatterElements = "ScatterElements";
|
||||
constexpr auto kScatterAddWithAxis = "ScatterAddWithAxis";
|
||||
|
||||
class ScatterElementsCpuKernelMod : public NativeCpuKernelMod {
|
||||
public:
|
||||
ScatterElementsCpuKernelMod() = default;
|
||||
|
||||
explicit ScatterElementsCpuKernelMod(const std::string &kernel_type) : kernel_type_(kernel_type) {}
|
||||
~ScatterElementsCpuKernelMod() override = default;
|
||||
|
||||
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) override;
|
||||
|
||||
bool Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &others = std::map<uint32_t, tensor::TensorPtr>()) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override {
|
||||
return kernel_func_(this, inputs, outputs);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
template <typename T, typename S, typename ReductionT>
|
||||
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
|
||||
|
||||
template <typename T, typename ReductionT>
|
||||
bool Scatter(const ReductionT &reduction_func, T *output, const T *updates);
|
||||
|
||||
template <typename S>
|
||||
bool AdjustIndices(S *in_indices);
|
||||
|
||||
size_t ComputeOutoutOffset(const int64_t &index);
|
||||
|
||||
void UpdateOutputDimIndex();
|
||||
|
||||
private:
|
||||
using ScatterElementsLaunchFunc = std::function<bool(
|
||||
ScatterElementsCpuKernelMod *, const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
|
||||
|
||||
static std::map<std::string, std::vector<std::pair<KernelAttr, ScatterElementsLaunchFunc>>> func_map_;
|
||||
ScatterElementsLaunchFunc kernel_func_;
|
||||
std::string kernel_type_{kUnKnown};
|
||||
int input_axis_size_{0};
|
||||
size_t input_size_{1};
|
||||
size_t indices_total_num_{1};
|
||||
size_t input_dims_{0};
|
||||
int64_t axis_{0};
|
||||
std::vector<int64_t> indices_shape_{};
|
||||
std::vector<size_t> output_dim_stride_{};
|
||||
std::vector<size_t> output_dim_index_{};
|
||||
std::vector<int64_t> adjusted_indices_{};
|
||||
std::string kernel_name_;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SCATTER_ELEMENTS_CPU_KERNEL_H_
|
|
@ -355,6 +355,7 @@ GVAR_DEF(PrimitivePtr, kPrimScatterNdMin, std::make_shared<Primitive>("ScatterNd
|
|||
GVAR_DEF(PrimitivePtr, kPrimScatterNdMul, std::make_shared<Primitive>("ScatterNdMul"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimScatterUpdate, std::make_shared<Primitive>("ScatterUpdate"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimScatterElements, std::make_shared<Primitive>("ScatterElements"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimScatterAddWithAxis, std::make_shared<Primitive>("ScatterAddWithAxis"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimTensorScatterUpdate, std::make_shared<Primitive>("TensorScatterUpdate"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimTensorScatterAdd, std::make_shared<Primitive>("TensorScatterAdd"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimTensorScatterSub, std::make_shared<Primitive>("TensorScatterSub"));
|
||||
|
|
|
@ -0,0 +1,87 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/scatter_add_with_axis.h"
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr ScatterAddWithAxisInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto prim_name = primitive->name();
|
||||
auto input_x_shape_ptr = input_args[kInputIndex0]->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(input_x_shape_ptr);
|
||||
auto indices_shape_ptr = input_args[kInputIndex1]->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(indices_shape_ptr);
|
||||
auto updates_shape_ptr = input_args[kInputIndex2]->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(updates_shape_ptr);
|
||||
if (input_x_shape_ptr->IsDynamic() || indices_shape_ptr->IsDynamic() || updates_shape_ptr->IsDynamic()) {
|
||||
return input_args[kInputIndex0]->BuildShape()->cast<abstract::ShapePtr>();
|
||||
}
|
||||
auto input_x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_x_shape_ptr)[kShape];
|
||||
auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(indices_shape_ptr)[kShape];
|
||||
auto updates_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(updates_shape_ptr)[kShape];
|
||||
if (input_x_shape.size() < 1 || indices_shape.size() < 1 || updates_shape.size() < 1) {
|
||||
MS_EXCEPTION(ValueError) << "For " << prim_name << ", 'input_x_shape', 'indices_shape' and "
|
||||
<< "'updates_shape' dims should be greater than 1. but got input_x_shape:" << input_x_shape
|
||||
<< ", indices_shape:" << indices_shape << ", updates_shape: " << updates_shape << ".";
|
||||
}
|
||||
if (updates_shape != indices_shape) {
|
||||
MS_EXCEPTION(ValueError) << "For " << prim_name << ", "
|
||||
<< "'updates_shape' should be as same as 'indices_shape' but got indices_shape: "
|
||||
<< indices_shape << ", updates_shape: " << updates_shape << ".";
|
||||
}
|
||||
|
||||
return input_args[kInputIndex0]->BuildShape()->cast<abstract::ShapePtr>();
|
||||
}
|
||||
|
||||
TypePtr ScatterAddWithAxisInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto prim_name = primitive->name();
|
||||
auto indiecs_type_ptr = input_args[kInputIndex1]->BuildType();
|
||||
std::set<TypePtr> type_set = {kInt32, kInt64};
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("indices type", indiecs_type_ptr, type_set, prim_name);
|
||||
std::map<std::string, TypePtr> type_dict;
|
||||
type_dict.emplace("input_x", input_args[kInputIndex0]->BuildType());
|
||||
type_dict.emplace("updates", input_args[kInputIndex2]->BuildType());
|
||||
std::set<TypePtr> check_list(common_valid_types);
|
||||
check_list.insert(kBool);
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(type_dict, check_list, prim_name);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_OPERATOR_IMPL(ScatterAddWithAxis, BaseOperator);
|
||||
AbstractBasePtr ScatterAddWithAxisInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t input_num = 3;
|
||||
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, primitive->name());
|
||||
auto infer_type = ScatterAddWithAxisInferType(primitive, input_args);
|
||||
auto infer_shape = ScatterAddWithAxisInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||
}
|
||||
void ScatterAddWithAxis::Init(const int64_t axis) { this->set_axis(axis); }
|
||||
void ScatterAddWithAxis::set_axis(const int64_t axis) { (void)AddAttr(kAxis, api::MakeValue(axis)); }
|
||||
int64_t ScatterAddWithAxis::get_axis() const { return GetValue<int64_t>(GetAttr(kAxis)); }
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(ScatterAddWithAxis, prim::kPrimScatterAddWithAxis, ScatterAddWithAxisInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,52 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_SCATTER_ADD_WITH_AXIS_H_
|
||||
#define MINDSPORE_CORE_OPS_SCATTER_ADD_WITH_AXIS_H_
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameScatterAddWithAxis = "ScatterAddWithAxis";
|
||||
/// \brief Updates tensor values by using input indices and value.
|
||||
/// Refer to Python API @ref mindspore.ops.ScatterAddWithAxis for more details.
|
||||
class MIND_API ScatterAddWithAxis : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(ScatterAddWithAxis);
|
||||
/// \brief Constructor.
|
||||
ScatterAddWithAxis() : BaseOperator(kNameScatterAddWithAxis) {
|
||||
InitIOName({"input_x", "indices", "update"}, {"output"});
|
||||
}
|
||||
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.ScatterAddWithAxis for the inputs.
|
||||
void Init(const int64_t axis = 0);
|
||||
/// \brief Set axis.
|
||||
void set_axis(const int64_t axis);
|
||||
/// \brief Get axis.
|
||||
///
|
||||
/// \return axis.
|
||||
int64_t get_axis() const;
|
||||
};
|
||||
abstract::AbstractBasePtr ScatterAddWithAxisInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
using kPrimScatterAddWithAxisPtr = std::shared_ptr<ScatterAddWithAxis>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_SCATTER_ADD_WITH_AXIS_H_
|
|
@ -0,0 +1,88 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/scatter_elements.h"
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr ScatterElementsInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto prim_name = primitive->name();
|
||||
auto input_x_shape_ptr = input_args[kInputIndex0]->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(input_x_shape_ptr);
|
||||
auto indices_shape_ptr = input_args[kInputIndex1]->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(indices_shape_ptr);
|
||||
auto updates_shape_ptr = input_args[kInputIndex2]->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(updates_shape_ptr);
|
||||
if (input_x_shape_ptr->IsDynamic() || indices_shape_ptr->IsDynamic() || updates_shape_ptr->IsDynamic()) {
|
||||
return input_args[kInputIndex0]->BuildShape()->cast<abstract::ShapePtr>();
|
||||
}
|
||||
auto input_x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_x_shape_ptr)[kShape];
|
||||
auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(indices_shape_ptr)[kShape];
|
||||
auto updates_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(updates_shape_ptr)[kShape];
|
||||
if (input_x_shape.size() < 1 || indices_shape.size() < 1 || updates_shape.size() < 1) {
|
||||
MS_EXCEPTION(ValueError) << "For " << prim_name << ", 'input_x_shape', 'indices_shape' and "
|
||||
<< "'updates_shape' dims should be greater than 1. but got input_x_shape:" << input_x_shape
|
||||
<< ", indices_shape:" << indices_shape << ", updates_shape: " << updates_shape << ".";
|
||||
}
|
||||
|
||||
if (updates_shape != indices_shape) {
|
||||
MS_EXCEPTION(ValueError) << "For " << prim_name << ", "
|
||||
<< "'updates_shape' should be as same as 'indices_shape' but got indices_shape: "
|
||||
<< indices_shape << ", updates_shape: " << updates_shape << ".";
|
||||
}
|
||||
|
||||
return input_args[kInputIndex0]->BuildShape()->cast<abstract::ShapePtr>();
|
||||
}
|
||||
|
||||
TypePtr ScatterElementsInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto prim_name = primitive->name();
|
||||
auto indiecs_type_ptr = input_args[kInputIndex1]->BuildType();
|
||||
std::set<TypePtr> type_set = {kInt32, kInt64};
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("indices type", indiecs_type_ptr, type_set, prim_name);
|
||||
std::map<std::string, TypePtr> type_dict;
|
||||
type_dict.emplace("input_x", input_args[kInputIndex0]->BuildType());
|
||||
type_dict.emplace("updates", input_args[kInputIndex2]->BuildType());
|
||||
std::set<TypePtr> check_list(common_valid_types);
|
||||
check_list.insert(kBool);
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(type_dict, check_list, prim_name);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_OPERATOR_IMPL(ScatterElements, BaseOperator);
|
||||
AbstractBasePtr ScatterElementsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t input_num = 3;
|
||||
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, primitive->name());
|
||||
auto infer_type = ScatterElementsInferType(primitive, input_args);
|
||||
auto infer_shape = ScatterElementsInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||
}
|
||||
void ScatterElements::Init(const int64_t axis) { this->set_axis(axis); }
|
||||
void ScatterElements::set_axis(const int64_t axis) { (void)AddAttr(kAxis, api::MakeValue(axis)); }
|
||||
int64_t ScatterElements::get_axis() const { return GetValue<int64_t>(GetAttr(kAxis)); }
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(ScatterElements, prim::kPrimScatterElements, ScatterElementsInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,50 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_SCATTER_ELEMENTS_H_
|
||||
#define MINDSPORE_CORE_OPS_SCATTER_ELEMENTS_H_
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameScatterElements = "ScatterElements";
|
||||
/// \brief Updates tensor values by using input indices and value.
|
||||
/// Refer to Python API @ref mindspore.ops.ScatterElements for more details.
|
||||
class MIND_API ScatterElements : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(ScatterElements);
|
||||
/// \brief Constructor.
|
||||
ScatterElements() : BaseOperator(kNameScatterElements) { InitIOName({"input_x", "indices", "update"}, {"output"}); }
|
||||
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.ScatterElements for the inputs.
|
||||
void Init(const int64_t axis = 0);
|
||||
/// \brief Set axis.
|
||||
void set_axis(const int64_t axis);
|
||||
/// \brief Get axis.
|
||||
///
|
||||
/// \return axis.
|
||||
int64_t get_axis() const;
|
||||
};
|
||||
abstract::AbstractBasePtr ScatterElementsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
using kPrimScatterElementsPtr = std::shared_ptr<ScatterElements>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_SCATTER_ELEMENTS_H_
|
|
@ -47,7 +47,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unsta
|
|||
EmbeddingLookup, Unique, GatherD, Identity, Range, MaskedFill, MaskedSelect, SearchSorted,
|
||||
TensorScatterUpdate, TensorScatterMax, TensorScatterMin, TensorScatterAdd, TensorScatterSub,
|
||||
TensorScatterMul, TensorScatterDiv,
|
||||
ScatterElements, ExtractVolumePatches, LowerBound, UpperBound, Cummax)
|
||||
ScatterElements, ScatterAddWithAxis, ExtractVolumePatches, LowerBound, UpperBound, Cummax)
|
||||
from .comm_ops import (AllGather, AllReduce, NeighborExchange, NeighborExchangeV2, AlltoAll, _AllSwap, ReduceScatter,
|
||||
Broadcast,
|
||||
_MirrorOperator, _MirrorMiniStepOperator, _MiniStepAllGather, ReduceOp, _VirtualDataset,
|
||||
|
@ -503,6 +503,7 @@ __all__ = [
|
|||
"TensorScatterMul",
|
||||
"TensorScatterDiv",
|
||||
"ScatterElements",
|
||||
"ScatterAddWithAxis",
|
||||
"NonZero",
|
||||
"SoftShrink",
|
||||
"FFT3D",
|
||||
|
|
|
@ -7233,6 +7233,58 @@ class ScatterElements(Primitive):
|
|||
self.init_prim_io_names(inputs=['data', 'indices', 'updates'], outputs=['y'])
|
||||
|
||||
|
||||
class ScatterAddWithAxis(Primitive):
|
||||
"""
|
||||
ScatterAddWithAxis takes three inputs data, updates, and indices of the same rank r >= 1
|
||||
and an optional attribute axis that identifies an axis of data (default is 0).
|
||||
The output of the operation is produced by creating a copy of the input data, and then add updating its value to
|
||||
values specified by updates at specific index positions specified by indices.
|
||||
|
||||
Args:
|
||||
axis (int): which axis to scatter, default is 0.
|
||||
|
||||
Inputs:
|
||||
- **data** (Tensor) - The target tensor. c
|
||||
- **indices** (Tensor) - The index of input tensor whose data type is int32 or int64.
|
||||
- **update** (Tensor) - The tensor to update the input tensor, has the same type as input,
|
||||
and update.shape should be equal to indices.shape.
|
||||
|
||||
Outputs:
|
||||
Tensor, has the same shape and type as `data`.
|
||||
|
||||
Raises:
|
||||
TypeError: If dtype of `indices` is neither int32 nor int64.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> op = ops.ScatterAddWithAxis(0)
|
||||
>>> data = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
|
||||
>>> indices = Tensor(np.array([[1, 0, 2], [0, 2, 1]]), mindspore.int32)
|
||||
>>> updates = Tensor(np.array([[1, 1, 1], [1, 1, 1]]), mindspore.float32)
|
||||
>>> output = op(data, indices, updates)
|
||||
>>> print(output)
|
||||
[[ 2.0 3.0 3.0]
|
||||
[ 5.0 5.0 7.0]
|
||||
[ 7.0 9.0 10.0]]
|
||||
>>> op = ops.ScatterAddWithAxis(1)
|
||||
>>> data = Tensor(np.array([[1, 2, 3, 4, 5]), mindspore.int32)
|
||||
>>> indices = Tensor(np.array([[2, 4]), mindspore.int32)
|
||||
>>> updates = Tensor(np.array([[8, 8]]), mindspore.int32)
|
||||
>>> output = op(data, indices, updates)
|
||||
>>> print(output)
|
||||
[[ 1 2 11 4 13]]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, axis=0):
|
||||
"""Initialize ScatterAddWithAxis"""
|
||||
validator.check_value_type("axis", axis, [int], self.name)
|
||||
self.init_prim_io_names(
|
||||
inputs=['data', 'indices', 'updates'], outputs=['y'])
|
||||
|
||||
|
||||
class ExtractVolumePatches(Primitive):
|
||||
"""
|
||||
Extract patches from input and put them in the "depth" output dimension. 3D extension of extract_image_patches.
|
||||
|
|
|
@ -0,0 +1,86 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore import Parameter
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
|
||||
|
||||
|
||||
def scatter_add_with_axis(input_x, indices, updates, axis):
|
||||
result = input_x.asnumpy().copy()
|
||||
indices_np = indices.asnumpy().copy()
|
||||
updates_np = updates.asnumpy().copy()
|
||||
|
||||
i_len = indices_np.shape[0]
|
||||
j_len = indices_np.shape[1]
|
||||
|
||||
if axis < 0:
|
||||
axis += len(result.shape)
|
||||
|
||||
for i in range(i_len):
|
||||
for j in range(j_len):
|
||||
index = indices_np[i][j]
|
||||
if index < 0:
|
||||
index += result.shape[axis]
|
||||
if axis == 0:
|
||||
result[index][j] += updates_np[i][j]
|
||||
if axis == 1:
|
||||
result[i][index] += updates_np[i][j]
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class TestScatterAddWithAxis(nn.Cell):
|
||||
def __init__(self, input_x, indices, updates, axis):
|
||||
super(TestScatterAddWithAxis, self).__init__()
|
||||
self.axis = axis
|
||||
self.input_x = Parameter(input_x, name="input_x")
|
||||
self.indices = Parameter(indices, name="indices")
|
||||
self.updates = Parameter(updates, name="updates")
|
||||
self.scatter_add_with_axis = P.ScatterAddWithAxis(self.axis)
|
||||
|
||||
def construct(self):
|
||||
return self.scatter_add_with_axis(self.input_x, self.indices, self.updates)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('dtype', [np.float32, np.float64, np.int32])
|
||||
@pytest.mark.parametrize('index_dtype', [np.int32, np.int64])
|
||||
@pytest.mark.parametrize('axis', [0, 1, -1])
|
||||
def test_scatter_add_with_axis(dtype, index_dtype, axis):
|
||||
"""
|
||||
Feature: Op ScatterAddWithAxis
|
||||
Description: Scatter update value according indices to output.
|
||||
output[indices[i][j]][j] += updates[i][j] if axis = 0,
|
||||
output[i][indices[i][j]] += updates[i][j] if axis = 1.
|
||||
Expectation: Ans is same as expected.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=dtype))
|
||||
indices = Tensor(np.array([[1, -1, 2], [0, 2, 1]], dtype=index_dtype))
|
||||
update = Tensor(np.array([[1, 2, 2], [4, 5, 8]], dtype=dtype))
|
||||
|
||||
ms_output = TestScatterAddWithAxis(x, indices, update, axis)()
|
||||
np_output = scatter_add_with_axis(x, indices, update, axis)
|
||||
print("ms_output:\n", ms_output.asnumpy())
|
||||
assert np.allclose(ms_output.asnumpy(), np_output)
|
|
@ -0,0 +1,80 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore import Parameter
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
|
||||
|
||||
|
||||
def scatter_element_np(input_x, indices, updates, axis):
|
||||
result = input_x.asnumpy().copy()
|
||||
indices_np = indices.asnumpy().copy()
|
||||
updates_np = updates.asnumpy().copy()
|
||||
|
||||
i_len = indices_np.shape[0]
|
||||
j_len = indices_np.shape[1]
|
||||
|
||||
for i in range(i_len):
|
||||
for j in range(j_len):
|
||||
if axis == 0:
|
||||
result[indices_np[i][j]][j] = updates_np[i][j]
|
||||
if axis == 1:
|
||||
result[i][indices_np[i][j]] = updates_np[i][j]
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class TestScatterElements(nn.Cell):
|
||||
def __init__(self, input_x, indices, updates, axis):
|
||||
super(TestScatterElements, self).__init__()
|
||||
self.axis = axis
|
||||
self.input_x = Parameter(input_x, name="input_x")
|
||||
self.indices = Parameter(indices, name="indices")
|
||||
self.updates = Parameter(updates, name="updates")
|
||||
self.scatter_elements = P.ScatterElements(self.axis)
|
||||
|
||||
def construct(self):
|
||||
return self.scatter_elements(self.input_x, self.indices, self.updates)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('dtype', [np.float32, np.float64, np.int32])
|
||||
@pytest.mark.parametrize('index_dtype', [np.int32])
|
||||
@pytest.mark.parametrize('axis', [0, 1])
|
||||
def test_scatter_elements(dtype, index_dtype, axis):
|
||||
"""
|
||||
Feature: Op ScatterElements
|
||||
Description: Scatter update value according indices to output.
|
||||
output[indices[i][j]][j] = updates[i][j] if axis = 0,
|
||||
output[i][indices[i][j]] = updates[i][j] if axis = 1.
|
||||
Expectation: Ans is same as expected.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=dtype))
|
||||
indices = Tensor(np.array([[1, 0, 2], [0, 2, 1]], dtype=index_dtype))
|
||||
update = Tensor(np.array([[0, 0, 0], [0, 0, 0]], dtype=dtype))
|
||||
|
||||
ms_output = TestScatterElements(x, indices, update, axis)()
|
||||
np_output = scatter_element_np(x, indices, update, axis)
|
||||
print("ms_output:\n", ms_output.asnumpy())
|
||||
assert np.allclose(ms_output.asnumpy(), np_output)
|
Loading…
Reference in New Issue