forked from mindspore-Ecosystem/mindspore
!39010 [assistant] [ScatterAddWithAxis] Add new AICPU operator ScatterAddWithAxis
Merge pull request !39010 from 谭青林/ScatterAddWithAxis
This commit is contained in:
commit
cac1150345
|
@ -0,0 +1,193 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/cpu/kernel/scatter_add_with_axis_cpu_kernel.h"
|
||||
|
||||
#include <atomic>
|
||||
#include <complex>
|
||||
|
||||
#include "kernel/common_utils.h"
|
||||
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
|
||||
|
||||
namespace {
|
||||
// Expands to one `case` label of a switch over the data dtype. The body calls
// LaunchKernel<TYPE, int32_t> when ITYPE equals kNumberTypeInt32 and
// LaunchKernel<TYPE, int64_t> for any other ITYPE, then leaves the switch.
#define DO_COMPUTE_CASE(DTYPE, TYPE, ITYPE, inputs, outputs) \
  case (DTYPE): {                                            \
    if ((ITYPE) != kNumberTypeInt32) {                       \
      LaunchKernel<TYPE, int64_t>(inputs, outputs);          \
    } else {                                                 \
      LaunchKernel<TYPE, int32_t>(inputs, outputs);          \
    }                                                        \
    break;                                                   \
  }
|
||||
} // namespace
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
namespace {
|
||||
// Builds one KernelAttr entry: three input dtypes (kNumberType<in0>,
// kNumberType<in1>, kNumberType<in2>) followed by one output dtype
// (kNumberType<out>). The arguments are pasted onto the `kNumberType` prefix.
#define ADD_KERNEL(in0, in1, in2, out) \
  KernelAttr()                         \
    .AddInputAttr(kNumberType##in0)    \
    .AddInputAttr(kNumberType##in1)    \
    .AddInputAttr(kNumberType##in2)    \
    .AddOutputAttr(kNumberType##out)
|
||||
// Number of input and output tensors the kernel expects at launch time.
constexpr int32_t kInputNum{3};
constexpr int32_t kOutputNum{1};
// Position of the third input (`updates`) in the input list.
constexpr uint32_t kInputIndex2{2};
// NOTE(review): not referenced anywhere in the visible part of this file.
constexpr int32_t KSplitSize{64 * 1024};
|
||||
} // namespace
|
||||
|
||||
void ScatterAddWithAxisCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
x_type_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
|
||||
indices_type_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 1);
|
||||
if (indices_type_ != kNumberTypeInt32 && indices_type_ != kNumberTypeInt64) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', the dtype of 'indices' must be int32 or int64, but got "
|
||||
<< indices_type_;
|
||||
}
|
||||
|
||||
// check parameters basic attribution are valid
|
||||
x_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
||||
indices_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
|
||||
updates_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kInputIndex2);
|
||||
axis_ = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "axis");
|
||||
|
||||
// Get and check 3 input dim info
|
||||
int64_t value_dim_num_x1 = static_cast<int64_t>(x_shape_.size());
|
||||
int64_t value_dim_num_x2 = static_cast<int64_t>(indices_shape_.size());
|
||||
int64_t value_dim_num_x3 = static_cast<int64_t>(updates_shape_.size());
|
||||
if (value_dim_num_x1 != value_dim_num_x2 || value_dim_num_x2 != value_dim_num_x3) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', the dim values of three inputs must be same, but got "
|
||||
<< "data: " << value_dim_num_x1 << ", indices: " << value_dim_num_x2
|
||||
<< ", update: " << value_dim_num_x3;
|
||||
}
|
||||
if (axis_ < value_dim_num_x1 * -1 || axis_ >= value_dim_num_x1) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', the value of axis is out of range!";
|
||||
}
|
||||
|
||||
int64_t sub_data_fix = 1;
|
||||
int64_t sub_index_fix = 1;
|
||||
for (int64_t i = value_dim_num_x2 - 1; i >= 0; --i) {
|
||||
if (x_shape_[i] < indices_shape_[i] || indices_shape_[i] != updates_shape_[i] || updates_shape_[i] <= 0) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', the " << i << " dimension verification failed: "
|
||||
<< "input0[" << x_shape_[i] << "], input1[" << indices_shape_[i] << "], input2["
|
||||
<< updates_shape_[i] << "]";
|
||||
}
|
||||
if (i > 0) {
|
||||
sub_data_fix *= x_shape_[i];
|
||||
data_dim_vec_.push_back(sub_data_fix);
|
||||
sub_index_fix *= indices_shape_[i];
|
||||
index_dim_vec_.push_back(sub_index_fix);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool ScatterAddWithAxisCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs,
|
||||
const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
// check param
|
||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputNum, kernel_name_);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputNum, kernel_name_);
|
||||
bool ret = true;
|
||||
switch (x_type_) {
|
||||
DO_COMPUTE_CASE(kNumberTypeFloat16, float16, indices_type_, inputs, outputs);
|
||||
DO_COMPUTE_CASE(kNumberTypeFloat32, float, indices_type_, inputs, outputs);
|
||||
DO_COMPUTE_CASE(kNumberTypeFloat64, double, indices_type_, inputs, outputs);
|
||||
DO_COMPUTE_CASE(kNumberTypeBool, bool, indices_type_, inputs, outputs);
|
||||
DO_COMPUTE_CASE(kNumberTypeInt8, int8_t, indices_type_, inputs, outputs);
|
||||
DO_COMPUTE_CASE(kNumberTypeInt16, int16_t, indices_type_, inputs, outputs);
|
||||
DO_COMPUTE_CASE(kNumberTypeInt32, int32_t, indices_type_, inputs, outputs);
|
||||
DO_COMPUTE_CASE(kNumberTypeInt64, int64_t, indices_type_, inputs, outputs);
|
||||
DO_COMPUTE_CASE(kNumberTypeUInt8, uint8_t, indices_type_, inputs, outputs);
|
||||
DO_COMPUTE_CASE(kNumberTypeUInt16, uint16_t, indices_type_, inputs, outputs);
|
||||
DO_COMPUTE_CASE(kNumberTypeUInt32, uint32_t, indices_type_, inputs, outputs);
|
||||
DO_COMPUTE_CASE(kNumberTypeUInt64, uint64_t, indices_type_, inputs, outputs);
|
||||
DO_COMPUTE_CASE(kNumberTypeComplex64, std::complex<float>, indices_type_, inputs, outputs);
|
||||
DO_COMPUTE_CASE(kNumberTypeComplex128, std::complex<double>, indices_type_, inputs, outputs);
|
||||
default:
|
||||
MS_EXCEPTION(TypeError) << "For '" << kernel_name_ << "', the data type of input x ["
|
||||
<< TypeIdToType(x_type_)->ToString()
|
||||
<< "] is unsupported. It should be "
|
||||
"float16|float|double|bool|int8|int16|int32|int64|uint8|unint32|"
|
||||
"unit64|complex16|complex32.";
|
||||
ret = false;
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <typename T, typename TI>
|
||||
void ScatterAddWithAxisCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
T *input_x1 = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
TI *input_x2 = reinterpret_cast<TI *>(inputs[1]->addr);
|
||||
T *input_x3 = reinterpret_cast<T *>(inputs[2]->addr);
|
||||
T *output_y = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
int64_t value_dim_num_x1 = static_cast<int64_t>(x_shape_.size());
|
||||
axis_ = axis_ < 0 ? axis_ + value_dim_num_x1 : axis_;
|
||||
int64_t axis_dim_value = x_shape_[axis_];
|
||||
int64_t initial_size = inputs[0]->size;
|
||||
int64_t total_value_num = initial_size / sizeof(T);
|
||||
int64_t update_value_num = inputs[2]->size / sizeof(T);
|
||||
|
||||
// using input to initial output
|
||||
auto ret = memcpy_s(output_y, outputs[0]->size, input_x1, inputs[0]->size);
|
||||
if (ret != EOK) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', launch kernel error: memcpy failed. Error no: " << ret;
|
||||
}
|
||||
// update data by adding "updates" according to indices
|
||||
for (int64_t i = 0; i < update_value_num; ++i) {
|
||||
int64_t remain_index = i;
|
||||
int64_t index_value = 0;
|
||||
int64_t counter = 0;
|
||||
if (input_x2[i] < axis_dim_value * -1 || input_x2[i] >= axis_dim_value) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', indices value " << input_x2[i] << " is out of bound "
|
||||
<< axis_dim_value << "!";
|
||||
}
|
||||
int64_t input_x2_value = input_x2[i] < 0 ? input_x2[i] + axis_dim_value : input_x2[i];
|
||||
for (int64_t j = static_cast<int64_t>(index_dim_vec_.size()) - 1; j >= 0; --j) {
|
||||
int64_t index_tmp = counter == axis_ ? input_x2_value : remain_index / index_dim_vec_[j];
|
||||
index_value += (index_tmp * data_dim_vec_[j]);
|
||||
remain_index %= index_dim_vec_[j];
|
||||
++counter;
|
||||
}
|
||||
index_value += (counter == axis_ ? input_x2_value : remain_index);
|
||||
if (index_value >= total_value_num) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', update index " << index_value << "greater than "
|
||||
<< total_value_num << "which is overflow!";
|
||||
}
|
||||
output_y[index_value] += input_x3[i];
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<KernelAttr> ScatterAddWithAxisCpuKernelMod::GetOpSupport() {
|
||||
static std::vector<KernelAttr> support_list = {
|
||||
ADD_KERNEL(UInt8, Int32, UInt8, UInt8), ADD_KERNEL(UInt8, Int64, UInt8, UInt8),
|
||||
ADD_KERNEL(UInt16, Int32, UInt16, UInt16), ADD_KERNEL(UInt16, Int64, UInt16, UInt16),
|
||||
ADD_KERNEL(UInt32, Int32, UInt32, UInt32), ADD_KERNEL(UInt32, Int64, UInt32, UInt32),
|
||||
ADD_KERNEL(UInt64, Int32, UInt64, UInt64), ADD_KERNEL(UInt64, Int64, UInt64, UInt64),
|
||||
ADD_KERNEL(Int8, Int32, Int8, Int8), ADD_KERNEL(Int8, Int64, Int8, Int8),
|
||||
ADD_KERNEL(Int16, Int32, Int16, Int16), ADD_KERNEL(Int16, Int64, Int16, Int16),
|
||||
ADD_KERNEL(Int32, Int32, Int32, Int32), ADD_KERNEL(Int32, Int64, Int32, Int32),
|
||||
ADD_KERNEL(Int64, Int32, Int64, Int64), ADD_KERNEL(Int64, Int64, Int64, Int64),
|
||||
ADD_KERNEL(Float16, Int32, Float16, Float16), ADD_KERNEL(Float16, Int64, Float16, Float16),
|
||||
ADD_KERNEL(Float32, Int32, Float32, Float32), ADD_KERNEL(Float32, Int64, Float32, Float32),
|
||||
ADD_KERNEL(Float64, Int32, Float64, Float64), ADD_KERNEL(Float64, Int64, Float64, Float64)};
|
||||
return support_list;
|
||||
}
|
||||
// Factory registration of ScatterAddWithAxisCpuKernelMod (NativeCpuKernelMod factory, key ScatterAddWithAxis).
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, ScatterAddWithAxis, ScatterAddWithAxisCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,56 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SCATTER_ADD_WITH_AXIS_CPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SCATTER_ADD_WITH_AXIS_CPU_KERNEL_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class ScatterAddWithAxisCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
||||
public:
|
||||
ScatterAddWithAxisCpuKernelMod() = default;
|
||||
~ScatterAddWithAxisCpuKernelMod() override = default;
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
template <typename T, typename TI>
|
||||
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
|
||||
std::vector<int64_t> x_shape_;
|
||||
std::vector<int64_t> indices_shape_;
|
||||
std::vector<int64_t> updates_shape_;
|
||||
std::vector<int64_t> data_dim_vec_;
|
||||
std::vector<int64_t> index_dim_vec_;
|
||||
int64_t axis_{0};
|
||||
TypeId x_type_;
|
||||
TypeId indices_type_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SCATTER_ADD_WITH_AXIS_CPU_KERNEL_H_
|
|
@ -198,6 +198,7 @@ constexpr auto kIndexFill = "IndexFill";
|
|||
constexpr auto kMeshgrid = "Meshgrid";
|
||||
constexpr auto kScatterNdMax = "ScatterNdMax";
|
||||
constexpr auto kScatterNdMin = "ScatterNdMin";
|
||||
// Operator name string of ScatterAddWithAxis.
constexpr auto kScatterAddWithAxis = "ScatterAddWithAxis";
|
||||
constexpr auto kCSRSparseMatrixToSparseTensor = "CSRSparseMatrixToSparseTensor";
|
||||
constexpr auto kSlice = "Slice";
|
||||
constexpr auto kAffineGrid = "AffineGrid";
|
||||
|
@ -520,6 +521,7 @@ GVAR_DEF(PrimitivePtr, kPrimScatterNdMin, std::make_shared<Primitive>("ScatterNd
|
|||
GVAR_DEF(PrimitivePtr, kPrimScatterNdMul, std::make_shared<Primitive>("ScatterNdMul"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimScatterNdDiv, std::make_shared<Primitive>("ScatterNdDiv"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimScatterUpdate, std::make_shared<Primitive>("ScatterUpdate"));
|
||||
// Primitive for ScatterAddWithAxis; its name is taken from the kScatterAddWithAxis constant.
GVAR_DEF(PrimitivePtr, kPrimScatterAddWithAxis, std::make_shared<Primitive>(kScatterAddWithAxis));
|
||||
GVAR_DEF(PrimitivePtr, kPrimTensorScatterElements, std::make_shared<Primitive>("TensorScatterElements"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimTensorScatterUpdate, std::make_shared<Primitive>("TensorScatterUpdate"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimTensorScatterAdd, std::make_shared<Primitive>("TensorScatterAdd"));
|
||||
|
|
|
@ -0,0 +1,91 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/scatter_add_with_axis.h"
|
||||
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <string>
|
||||
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
// Infers the output shape of ScatterAddWithAxis, which is the shape of input 0 (`input_x`).
// If any of the three input shapes is dynamic, no further checks are made. Otherwise raises
// ValueError when one of the inputs has rank 0 or when the shapes of `updates` and `indices` differ.
abstract::ShapePtr ScatterAddWithAxisInferShape(const PrimitivePtr &primitive,
                                                const std::vector<AbstractBasePtr> &input_args) {
  auto prim_name = primitive->name();
  auto input_x_shape_ptr = input_args[kInputIndex0]->BuildShape();
  MS_EXCEPTION_IF_NULL(input_x_shape_ptr);
  auto indices_shape_ptr = input_args[kInputIndex1]->BuildShape();
  MS_EXCEPTION_IF_NULL(indices_shape_ptr);
  auto updates_shape_ptr = input_args[kInputIndex2]->BuildShape();
  MS_EXCEPTION_IF_NULL(updates_shape_ptr);
  if (input_x_shape_ptr->IsDynamic() || indices_shape_ptr->IsDynamic() || updates_shape_ptr->IsDynamic()) {
    return input_x_shape_ptr->cast<abstract::ShapePtr>();
  }
  auto input_x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_x_shape_ptr)[kShape];
  auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(indices_shape_ptr)[kShape];
  auto updates_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(updates_shape_ptr)[kShape];
  // Rank 0 (an empty shape vector) is rejected; rank 1 is the smallest accepted rank.
  if (input_x_shape.empty() || indices_shape.empty() || updates_shape.empty()) {
    MS_EXCEPTION(ValueError) << "For " << prim_name << ", 'input_x_shape', 'indices_shape' and "
                             << "'updates_shape' dims must be greater than or equal to 1, but got input_x_shape: "
                             << input_x_shape << ", indices_shape: " << indices_shape
                             << ", updates_shape: " << updates_shape << ".";
  }
  if (updates_shape != indices_shape) {
    MS_EXCEPTION(ValueError) << "For " << prim_name << ", "
                             << "'updates_shape' must be as same as 'indices_shape' but got "
                                "indices_shape: "
                             << indices_shape << ", updates_shape: " << updates_shape << ".";
  }

  return input_x_shape_ptr->cast<abstract::ShapePtr>();
}
|
||||
|
||||
// Infers the output dtype of ScatterAddWithAxis. `indices` (input 1) must be an int32 or int64
// tensor; `input_x` (input 0) and `updates` (input 2) must share one dtype taken from
// common_valid_types plus bool. The result of the same-type check is returned.
TypePtr ScatterAddWithAxisInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  const auto op_name = primitive->name();

  // Index dtype check.
  auto indices_type = input_args[kInputIndex1]->BuildType();
  const std::set<TypePtr> valid_index_types = {kInt32, kInt64};
  (void)CheckAndConvertUtils::CheckTensorTypeValid("indices type", indices_type, valid_index_types, op_name);

  // Data dtype check: both data tensors are validated together.
  std::map<std::string, TypePtr> data_types;
  data_types.emplace("input_x", input_args[kInputIndex0]->BuildType());
  data_types.emplace("updates", input_args[kInputIndex2]->BuildType());
  std::set<TypePtr> valid_data_types(common_valid_types);
  valid_data_types.insert(kBool);
  return CheckAndConvertUtils::CheckTensorTypeSame(data_types, valid_data_types, op_name);
}
|
||||
} // namespace
|
||||
|
||||
// Entry point of shape/type inference for ScatterAddWithAxis: checks that at least three inputs
// are given, infers the dtype first and the shape second, and wraps both into an abstract value.
AbstractBasePtr ScatterAddWithAxisInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                        const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  const int64_t expected_inputs = 3;
  (void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, expected_inputs, primitive->name());
  auto out_type = ScatterAddWithAxisInferType(primitive, input_args);
  auto out_shape = ScatterAddWithAxisInferShape(primitive, input_args);
  return abstract::MakeAbstract(out_shape, out_type);
}
|
||||
|
||||
MIND_API_OPERATOR_IMPL(ScatterAddWithAxis, BaseOperator);
|
||||
void ScatterAddWithAxis::Init(const int64_t axis) { this->set_axis(axis); }
|
||||
void ScatterAddWithAxis::set_axis(const int64_t axis) { (void)AddAttr(kAxis, api::MakeValue(axis)); }
|
||||
int64_t ScatterAddWithAxis::get_axis() const { return GetValue<int64_t>(GetAttr(kAxis)); }
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(ScatterAddWithAxis, prim::kPrimScatterAddWithAxis, ScatterAddWithAxisInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,49 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_SCATTER_ADD_WITH_AXIS_H_
|
||||
#define MINDSPORE_CORE_OPS_SCATTER_ADD_WITH_AXIS_H_
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "mindapi/base/types.h"
|
||||
#include "ops/base_operator.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameScatterAddWithAxis = "ScatterAddWithAxis";
|
||||
/// \brief Updates tensor values by using input indices and value.
|
||||
/// Refer to Python API @ref mindspore.ops.ScatterAddWithAxis for more details.
|
||||
class MIND_API ScatterAddWithAxis : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(ScatterAddWithAxis);
|
||||
/// \brief Constructor.
|
||||
ScatterAddWithAxis() : BaseOperator(kNameScatterAddWithAxis) { InitIOName({"input_x", "indices", "update"}, {"y"}); }
|
||||
/// \brief Init. Refer to the parameters of Python API @ref
|
||||
/// mindspore.ops.ScatterAddWithAxis for the inputs.
|
||||
void Init(const int64_t axis = 0);
|
||||
/// \brief Set axis.
|
||||
void set_axis(const int64_t axis);
|
||||
/// \brief Get axis.
|
||||
int64_t get_axis() const;
|
||||
};
|
||||
abstract::AbstractBasePtr ScatterAddWithAxisInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
using kPrimScatterAddWithAxisPtr = std::shared_ptr<ScatterAddWithAxis>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_SCATTER_ADD_WITH_AXIS_H_
|
|
@ -14,7 +14,6 @@
|
|||
# ============================================================================
|
||||
|
||||
"""array_ops"""
|
||||
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops.primitive import constexpr
|
||||
from ...common import dtype as mstype
|
||||
|
@ -37,6 +36,7 @@ from ..operations.array_ops import SegmentMax
|
|||
from ..operations.array_ops import SegmentMin
|
||||
from ..operations.array_ops import SegmentSum
|
||||
from ..operations.array_ops import TensorScatterElements
|
||||
from ..operations.array_ops import ScatterAddWithAxis
|
||||
from ..operations.array_ops import Expand
|
||||
from ..operations.array_ops import SegmentMean
|
||||
from ..operations.array_ops import AffineGrid
|
||||
|
@ -636,6 +636,33 @@ def get_bprop_tensor_scatter_elements(self):
|
|||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(ScatterAddWithAxis)
def get_bprop_scatter_add_with_axis(self):
    """
    Generate bprop for ScatterAddWithAxis.

    The returned function yields three gradients:
    `dout` for the first input, zeros shaped like `indices`, and for the third
    input the values of `dout` gathered (GatherD) along `axis` at `indices`.
    """
    axis = self.axis
    slice_op = P.Slice()
    gather_d = P.GatherD()

    def bprop(x, indices, update, out, dout):
        grad_shape = dout.shape
        idx_shape = indices.shape
        if grad_shape == idx_shape:
            update_grad = gather_d(dout, axis, indices)
        else:
            # Pad `indices` at the end of every dimension up to the shape of
            # `dout`, gather, then slice the result back to the shape of `indices`.
            paddings = []
            begin = []
            for dim, extent in enumerate(grad_shape):
                paddings.append((0, extent - idx_shape[dim]))
                begin.append(0)
            padded_indices = P.Pad(tuple(paddings))(indices)
            gathered = gather_d(dout, axis, padded_indices)
            update_grad = slice_op(gathered, begin, idx_shape)
        return dout, zeros_like(indices), update_grad

    return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(Expand)
|
||||
def get_bprop_expand(self):
|
||||
"""Generate bprop for Expand"""
|
||||
|
|
|
@ -50,6 +50,7 @@ from .gather_nd import _gather_nd_aicpu
|
|||
from .scatter_nd import _scatter_nd_aicpu
|
||||
from .scatter_nd_update import _scatter_nd_update_aicpu
|
||||
from .scatter import _scatter_aicpu
|
||||
from .scatter_add_with_axis import _scatter_add_with_axis_aicpu
|
||||
from .exp import _exp_aicpu
|
||||
from .expm1 import _expm1_aicpu
|
||||
from .identity import _identity_aicpu
|
||||
|
|
|
@ -0,0 +1,53 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""ScatterAddWithAxis op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||
|
||||
# AiCPU registration info for ScatterAddWithAxis. Each dtype_format row is
# (input_x, indices, updates, y): the first group uses int32 indices, the
# second group repeats the same data dtypes with int64 indices.
scatter_add_with_axis_op_info = AiCPURegOp("ScatterAddWithAxis") \
    .fusion_type("OPAQUE") \
    .attr("axis", "int") \
    .input(0, "input_x", "required") \
    .input(1, "indices", "required") \
    .input(2, "updates", "required") \
    .output(0, "y", "required") \
    .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
    .dtype_format(DataType.I16_Default, DataType.I32_Default, DataType.I16_Default, DataType.I16_Default) \
    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.I64_Default, DataType.I64_Default) \
    .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \
    .dtype_format(DataType.U16_Default, DataType.I32_Default, DataType.U16_Default, DataType.U16_Default) \
    .dtype_format(DataType.U32_Default, DataType.I32_Default, DataType.U32_Default, DataType.U32_Default) \
    .dtype_format(DataType.U64_Default, DataType.I32_Default, DataType.U64_Default, DataType.U64_Default) \
    .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.F64_Default, DataType.F64_Default) \
    .dtype_format(DataType.I8_Default, DataType.I64_Default, DataType.I8_Default, DataType.I8_Default) \
    .dtype_format(DataType.I16_Default, DataType.I64_Default, DataType.I16_Default, DataType.I16_Default) \
    .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
    .dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.U8_Default, DataType.U8_Default) \
    .dtype_format(DataType.U16_Default, DataType.I64_Default, DataType.U16_Default, DataType.U16_Default) \
    .dtype_format(DataType.U32_Default, DataType.I64_Default, DataType.U32_Default, DataType.U32_Default) \
    .dtype_format(DataType.U64_Default, DataType.I64_Default, DataType.U64_Default, DataType.U64_Default) \
    .dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.F64_Default, DataType.F64_Default) \
    .get_op_info()
|
||||
|
||||
|
||||
@op_info_register(scatter_add_with_axis_op_info)
def _scatter_add_with_axis_aicpu():
    """ScatterAddWithAxis AiCPU register: the body is empty, the decorator receives the op info."""
    return None
|
|
@ -7127,6 +7127,65 @@ class ExtractVolumePatches(Primitive):
|
|||
self.add_prim_attr("padding", self.padding)
|
||||
|
||||
|
||||
class ScatterAddWithAxis(Primitive):
    """
    Adds `updates` into a copy of `input_x` at the positions selected by `indices` along `axis`.

    The three inputs `input_x`, `indices` and `updates` have the same rank r >= 1. The output
    starts as a copy of `input_x`; every element of `updates` is then added to the output element
    whose position equals the position of that element in `indices`, except along `axis`, where
    the position is the value stored in `indices`.

    Args:
        axis (int): Which axis to scatter along. Default: 0.

    Inputs:
        - **input_x** (Tensor) - The target tensor.
        - **indices** (Tensor) - The index tensor, whose data type is int32 or int64.
        - **updates** (Tensor) - The tensor added to the target tensor. It has the same data type
          as `input_x`, and its shape must be equal to the shape of `indices`.

    Outputs:
        Tensor, has the same shape and type as `input_x`.

    Raises:
        TypeError: If dtype of `indices` is neither int32 nor int64.
        ValueError: If the shape of `indices` is not equal to the shape of `updates`.

    Supported Platforms:
        ``Ascend`` ``CPU``

    Examples:
        >>> op = ops.ScatterAddWithAxis(0)
        >>> input_x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
        >>> indices = Tensor(np.array([[1, 0, 2], [0, 2, 1]]), mindspore.int32)
        >>> updates = Tensor(np.array([[1, 1, 1], [1, 1, 1]]), mindspore.float32)
        >>> output = op(input_x, indices, updates)
        >>> print(output)
        [[ 2.  3.  3.]
         [ 5.  5.  7.]
         [ 7.  9. 10.]]
        >>> op = ops.ScatterAddWithAxis(1)
        >>> input_x = Tensor(np.array([[1, 2, 3, 4, 5]]), mindspore.int32)
        >>> indices = Tensor(np.array([[2, 4]]), mindspore.int32)
        >>> updates = Tensor(np.array([[8, 8]]), mindspore.int32)
        >>> output = op(input_x, indices, updates)
        >>> print(output)
        [[ 1  2 11  4 13]]
    """

    # input_x and updates share dtype group T; indices uses its own group T1.
    __mindspore_signature__ = (
        sig.make_sig("input_x", sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
        sig.make_sig("indices", dtype=sig.sig_dtype.T1),
        sig.make_sig("updates", dtype=sig.sig_dtype.T),
    )

    @prim_attr_register
    def __init__(self, axis=0):
        """Initialize ScatterAddWithAxis"""
        validator.check_value_type("axis", axis, [int], self.name)
        input_names = ["input_x", "indices", "updates"]
        output_names = ["y"]
        self.init_prim_io_names(inputs=input_names, outputs=output_names)
|
||||
|
||||
|
||||
class Lstsq(Primitive):
|
||||
r"""
|
||||
Computes the solutions of the least squares and minimum norm problems of full-rank
|
||||
|
|
|
@ -63,6 +63,7 @@ from mindspore.ops.operations.array_ops import IdentityN
|
|||
from mindspore.ops.operations.array_ops import IndexFill
|
||||
from mindspore.ops.operations.array_ops import SegmentMean
|
||||
from mindspore.ops.operations.array_ops import SegmentProd
|
||||
from mindspore.ops.operations.array_ops import ScatterAddWithAxis
|
||||
from mindspore.ops.operations.random_ops import NonDeterministicInts
|
||||
from mindspore.ops.operations.random_ops import TruncatedNormal
|
||||
from mindspore.ops.operations.random_ops import ParameterizedTruncatedNormal
|
||||
|
@ -630,6 +631,19 @@ class ScatterNdAdd(nn.Cell):
|
|||
return out
|
||||
|
||||
|
||||
class ScatterAddWithAxisNet(nn.Cell):
    """
    ScatterAddWithAxis net definition.

    Holds an all-ones Parameter named "ref" of shape `ref_shape` and dtype `dtype`,
    and applies ScatterAddWithAxis(axis) to it with the given indices and updates.
    """

    def __init__(self, ref_shape, axis, dtype=np.float32):
        super().__init__()
        self.scatter_add_with_axis = ScatterAddWithAxis(axis)
        initial_value = Tensor(np.ones(ref_shape, dtype))
        self.ref = Parameter(initial_value, name="ref")

    def construct(self, indices, updates):
        return self.scatter_add_with_axis(self.ref, indices, updates)
|
||||
|
||||
|
||||
class ScatterNdMaxNet(nn.Cell):
|
||||
"""ScatterNdMax net definition"""
|
||||
|
||||
|
@ -3658,6 +3672,12 @@ test_case_array_ops = [
|
|||
Tensor(4.0, mstype.float32)],
|
||||
'desc_bprop': [Tensor(np.array([[1.0, 2.0, 3.0]]), mstype.float32)],
|
||||
}),
|
||||
('ScatterAddWithAxis', {
|
||||
'block': ScatterAddWithAxisNet((3, 3), 0),
|
||||
'desc_inputs': [Tensor(np.array([[0, 1], [1, 2]]), mstype.int32),
|
||||
Tensor(np.ones([2, 2]), mstype.float32)],
|
||||
'desc_bprop': [Tensor(np.ones([3, 3]), mstype.float32)],
|
||||
}),
|
||||
('MatrixDiag', {
|
||||
'block': inner.MatrixDiag(),
|
||||
'desc_inputs': [Tensor(np.array([1, -1]), mstype.float32),
|
||||
|
|
Loading…
Reference in New Issue