forked from mindspore-Ecosystem/mindspore
[feat][assistant][I54KGU] add new aicpu operator IndexPut
This commit is contained in:
parent 03e968dfb9
commit 170910c62d
@@ -66,6 +66,7 @@ constexpr auto kGather = "Gather";
constexpr auto kHistogram = "Histogram";
constexpr auto kIdentity = "Identity";
constexpr auto kIdentityN = "IdentityN";
constexpr auto kIndexPut = "IndexPut";
constexpr auto kConcatOffset = "ConcatOffset";
constexpr auto kConcatOffsetV1 = "ConcatOffsetV1";
constexpr auto kRandomChoiceWithMask = "RandomChoiceWithMask";

@@ -298,6 +299,7 @@ const std::set<std::string> kDynamicInputOps{kRaggedTensorToTensor,
                                             kReservoirReplayBufferPush,
                                             kReservoirReplayBufferSample,
                                             kIdentityN,
                                             kIndexPut,
                                             kSparseConcat,
                                             kConcatOffsetV1};
const std::map<std::string, std::string> kOpNameToAicpuOpNameMap{
@@ -0,0 +1,280 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/cpu/kernel/index_put_cpu_kernel.h"

#include <algorithm>
#include <complex>
#include <functional>
#include <iostream>

#include "mindspore/core/ops/index_put.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"

namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kIndexPutInputsNum = 3;
constexpr size_t kIndexPutOutputsNum = 1;

#define INDEXPUT_LAUNCH_CASE(DTYPE, TYPE, DTYPE0, INPUTS, OUTPUTS) \
  case DTYPE: {                                                    \
    if ((DTYPE0) == kNumberTypeInt32) {                            \
      LaunchKernel<TYPE, int32_t>(INPUTS, OUTPUTS);                \
    } else {                                                       \
      LaunchKernel<TYPE, int64_t>(INPUTS, OUTPUTS);                \
    }                                                              \
    break;                                                         \
  }
}  // namespace

std::vector<std::vector<int64_t>> IndexPutCpuKernelMod::Transpose(const std::vector<std::vector<int64_t>> &A) {
  std::vector<std::vector<int64_t>> v;
  if (A.empty()) {
    return std::vector<std::vector<int64_t>>();
  }
  for (size_t i = 0; i < A[0].size(); ++i) {
    std::vector<int64_t> k;
    for (size_t j = 0; j < A.size(); ++j) {
      k.push_back(A[j][i]);
    }
    v.push_back(k);
  }
  return v;
}

int64_t IndexPutCpuKernelMod::Multiplicative(const std::vector<int64_t> &tensorshapes, int64_t start, int64_t end) {
  int64_t result = 1;
  for (int64_t i = start; i < end; i++) {
    result *= tensorshapes[i];
  }
  return result;
}

bool IndexPutCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                const std::vector<KernelTensorPtr> &outputs) {
  auto kernel_ptr = std::dynamic_pointer_cast<ops::IndexPut>(base_operator);
  MS_EXCEPTION_IF_NULL(kernel_ptr);
  kernel_name_ = kernel_ptr->name();
  x1_shape_ = inputs[0]->GetShapeVector();
  auto type_id = inputs[0]->GetDtype();
  input_info_.push_back(type_id);
  x2_shape_ = inputs[1]->GetShapeVector();
  type_id = inputs[1]->GetDtype();
  input_info_.push_back(type_id);
  for (size_t i = 2; i < inputs.size(); i++) {
    indices_shape_.push_back(inputs[i]->GetShapeVector());
    type_id = inputs[i]->GetDtype();
    input_info_.push_back(type_id);
  }
  inputs_nums = inputs.size();
  accumulate = GetValue<int64_t>(base_operator->GetAttr("accumulate"));
  return true;
}

void IndexPutCpuKernelMod::CheckParams() {
  constexpr size_t indices_start_pos = 2;
  if (input_info_[0] != input_info_[1]) {
    MS_EXCEPTION(TypeError) << "For '" << kernel_name_ << "', x1 and x2 must have the same type, but x1 has type "
                            << TypeIdLabel(input_info_[0]) << " and x2 has type " << TypeIdLabel(input_info_[1])
                            << ".";
  }
  for (size_t i = indices_start_pos; i < inputs_nums; i++) {
    if (input_info_[i] != kNumberTypeInt32 && input_info_[i] != kNumberTypeInt64) {
      MS_EXCEPTION(TypeError) << "For '" << kernel_name_
                              << "', the tensors in indices must be of type int32 or int64, but indices["
                              << i - indices_start_pos << "] has type " << TypeIdLabel(input_info_[i]) << ".";
    }
  }
  if (x1_shape_.size() < indices_shape_.size()) {
    MS_EXCEPTION(ValueError) << "For '" << kernel_name_
                             << "', rank(x1) must be no less than size(indices), but got rank(x1) "
                             << x1_shape_.size() << " vs size(indices) " << indices_shape_.size() << ".";
  }
  if (x2_shape_.size() != 1) {
    MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', rank(x2) must be 1, but got " << x2_shape_.size() << ".";
  }
  int64_t maxnum = 0;
  for (size_t i = 0; i < indices_shape_.size(); i++) {
    if (indices_shape_[i].size() != 1) {
      MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', rank of indices[" << i << "] must be 1, but got "
                               << indices_shape_[i].size() << ".";
    }
    maxnum = (maxnum < indices_shape_[i][0]) ? indices_shape_[i][0] : maxnum;
  }
  for (size_t i = 0; i < indices_shape_.size(); i++) {
    if (indices_shape_[i][0] != 1 && indices_shape_[i][0] != maxnum) {
      MS_EXCEPTION(ValueError) << "For '" << kernel_name_
                               << "', the tensors in indices must be broadcastable, but indices[" << i
                               << "].shape[0] is " << indices_shape_[i][0] << ".";
    }
  }
  bool x2_check = x2_shape_[0] != 1 && x2_shape_[0] != maxnum && x2_shape_[0] != x1_shape_[x1_shape_.size() - 1];
  if (x2_check) {
    MS_EXCEPTION(ValueError) << "For '" << kernel_name_
                             << "', the size of x2 must be 1, the max size of the tensors in indices, or "
                                "x1.shape[-1], but got "
                             << x2_shape_[0] << ".";
  }
  if (accumulate != 0 && accumulate != 1) {
    MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', accumulate must be 0 or 1, but got " << accumulate
                             << ".";
  }
}

template <typename T>
void IndexPutCpuKernelMod::ComputeNospecial(T *x2, size_t x2_nums, std::vector<std::vector<int64_t>> indices_value,
                                            T *y, int accumulate) {
  auto x1_shape = x1_shape_;
  size_t x1_shape_size = x1_shape.size();
  size_t idxli = indices_value.size();
  size_t idxcol = indices_value[0].size();
  if (x2_nums == 0) {
    MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', x2 must not be empty.";
  }
  for (size_t i = 0; i < idxli; ++i) {
    size_t offset = 0;
    for (size_t j = 0; j < idxcol; ++j) {
      offset += indices_value[i][j] * Multiplicative(x1_shape, j + 1, x1_shape_size);
    }
    size_t v_idx = i % x2_nums;
    y[offset] = (accumulate == 0) ? x2[v_idx] : y[offset] + x2[v_idx];
  }
}

template <typename T>
void IndexPutCpuKernelMod::ComputeSpecial(T *x2, size_t x2_nums, std::vector<std::vector<int64_t>> indices_value, T *y,
                                          int accumulate) {
  auto x1_shape = x1_shape_;
  size_t x1_shape_size = x1_shape.size();
  size_t idxli = indices_value.size();
  size_t idxcol = indices_value[0].size();
  size_t strides = Multiplicative(x1_shape, indices_value.size(), x1_shape_size);
  if (x2_nums == 0) {
    MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', x2 must not be empty.";
  }
  for (size_t i = 0; i < idxcol; i++) {
    size_t offset = 0;
    for (size_t j = 0; j < idxli; j++) {
      offset += indices_value[j][i] * Multiplicative(x1_shape, j + 1, x1_shape_size);
    }
    for (size_t j = 0; j < strides; j++) {
      size_t y_idx = offset + j;
      size_t v_idx = j % x2_nums;
      y[y_idx] = (accumulate == 0) ? x2[v_idx] : y[y_idx] + x2[v_idx];
    }
  }
}

template <typename T, typename T0>
bool IndexPutCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
                                        const std::vector<kernel::AddressPtr> &outputs) {
  auto *x1 = reinterpret_cast<T *>(inputs[0]->addr);
  auto *x2 = reinterpret_cast<T *>(inputs[1]->addr);
  auto *y = reinterpret_cast<T *>(outputs[0]->addr);
  size_t x1_nums =
    std::accumulate(x1_shape_.begin(), x1_shape_.end(), static_cast<size_t>(1), std::multiplies<size_t>());
  size_t x2_nums =
    std::accumulate(x2_shape_.begin(), x2_shape_.end(), static_cast<size_t>(1), std::multiplies<size_t>());
  constexpr size_t indices_start_pos = 2;
  std::vector<std::vector<int64_t>> indices_value;
  for (size_t i = indices_start_pos; i < inputs.size(); i++) {
    auto *linetensor = reinterpret_cast<T0 *>(inputs[i]->addr);
    std::vector<int64_t> iline;
    for (size_t j = 0; static_cast<int64_t>(j) < indices_shape_[i - indices_start_pos][0]; j++) {
      linetensor[j] = (linetensor[j] < 0) ? linetensor[j] + x1_shape_[i - indices_start_pos] : linetensor[j];
      if (linetensor[j] < 0) {
        MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', indices[" << i - indices_start_pos
                                 << "] contains an out-of-range negative index.";
      }
      if (linetensor[j] >= x1_shape_[i - indices_start_pos]) {
        MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', each element in indices["
                                 << i - indices_start_pos << "] must be smaller than x1.shape["
                                 << i - indices_start_pos << "], but got " << linetensor[j] << " while x1.shape["
                                 << i - indices_start_pos << "] is " << x1_shape_[i - indices_start_pos] << ".";
      }
      iline.push_back(linetensor[j]);
    }
    indices_value.push_back(iline);
  }
  auto task = [&](size_t start, size_t end) {
    size_t length = (end - start) * sizeof(T);
    auto ret = memcpy_s(y + start, length, x1 + start, length);
    if (ret != 0) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', memcpy_s error. Error no: " << ret << ".";
    }
  };
  ParallelLaunchAutoSearch(task, x1_nums, this, &parallel_search_info_);
  size_t maxl = 0;
  for (size_t i = 0; i < indices_value.size(); i++) {
    if (indices_value[i].size() > maxl) {
      maxl = indices_value[i].size();
    }
  }
  for (size_t i = 0; i < indices_value.size(); i++) {
    while (indices_value[i].size() != maxl) {
      indices_value[i].push_back(indices_value[i][0]);
    }
  }
  if (indices_value.size() == x1_shape_.size()) {
    std::vector<std::vector<int64_t>> rindices_value = Transpose(indices_value);
    (void)ComputeNospecial<T>(x2, x2_nums, rindices_value, y, accumulate);
  } else {
    (void)ComputeSpecial<T>(x2, x2_nums, indices_value, y, accumulate);
  }
  return true;
}

bool IndexPutCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
                                  const std::vector<AddressPtr> &outputs) {
  constexpr int indices_start_pos = 2;
  CheckParams();
  TypeId input_type = input_info_[0];
  TypeId indices_type = input_info_[indices_start_pos];
  switch (input_type) {
    INDEXPUT_LAUNCH_CASE(kNumberTypeFloat16, float16, indices_type, inputs, outputs)
    INDEXPUT_LAUNCH_CASE(kNumberTypeFloat32, float, indices_type, inputs, outputs)
    INDEXPUT_LAUNCH_CASE(kNumberTypeFloat64, double, indices_type, inputs, outputs)
    INDEXPUT_LAUNCH_CASE(kNumberTypeInt32, int32_t, indices_type, inputs, outputs)
    INDEXPUT_LAUNCH_CASE(kNumberTypeUInt8, uint8_t, indices_type, inputs, outputs)
    INDEXPUT_LAUNCH_CASE(kNumberTypeInt16, int16_t, indices_type, inputs, outputs)
    INDEXPUT_LAUNCH_CASE(kNumberTypeInt8, int8_t, indices_type, inputs, outputs)
    INDEXPUT_LAUNCH_CASE(kNumberTypeComplex64, std::complex<float>, indices_type, inputs, outputs)
    INDEXPUT_LAUNCH_CASE(kNumberTypeInt64, int64_t, indices_type, inputs, outputs)
    INDEXPUT_LAUNCH_CASE(kNumberTypeUInt16, uint16_t, indices_type, inputs, outputs)
    INDEXPUT_LAUNCH_CASE(kNumberTypeComplex128, std::complex<double>, indices_type, inputs, outputs)
    INDEXPUT_LAUNCH_CASE(kNumberTypeUInt32, uint32_t, indices_type, inputs, outputs)
    INDEXPUT_LAUNCH_CASE(kNumberTypeUInt64, uint64_t, indices_type, inputs, outputs)
    default:
      MS_LOG(EXCEPTION) << kernel_name_ << " does not support this kernel data type: " << TypeIdLabel(input_type)
                        << ".";
  }
  return true;
}

const std::vector<std::pair<KernelAttr, IndexPutCpuKernelMod::KernelRunFunc>> &IndexPutCpuKernelMod::GetFuncList()
  const {
  static const std::vector<std::pair<KernelAttr, IndexPutCpuKernelMod::KernelRunFunc>> func_list = {
    {KernelAttr().AddSkipCheckAttr(true), &IndexPutCpuKernelMod::Launch},
  };
  return func_list;
}

MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, IndexPut, IndexPutCpuKernelMod);
}  // namespace kernel
}  // namespace mindspore
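The two compute paths above map onto simple array semantics: when one index tensor is supplied per dimension of x1 (ComputeNospecial), each column of the broadcast index matrix names a full coordinate of x1; when fewer index tensors than dimensions are supplied (ComputeSpecial), each coordinate prefix selects a trailing slice of x1, and x2 is cycled across that slice. A minimal NumPy sketch of these semantics, for orientation only; the function name is illustrative, not part of this patch:

import numpy as np

def index_put_ref(x1, x2, indices, accumulate=0):
    """Illustrative model of the kernel's two compute paths."""
    y = x1.copy()
    maxlen = max(idx.shape[0] for idx in indices)
    # Size-1 index tensors are broadcast, mirroring the padding loop in LaunchKernel.
    cols = [np.broadcast_to(idx, (maxlen,)) for idx in indices]
    for k, coord in enumerate(zip(*cols)):
        if len(indices) == x1.ndim:
            # ComputeNospecial: each column of the index matrix is a full coordinate.
            v = x2[k % x2.size]
            y[coord] = y[coord] + v if accumulate else v
        else:
            # ComputeSpecial: the coordinate prefix selects a trailing slice of x1,
            # and x2 is cycled across it (cf. the strides loop above).
            view = y[coord].reshape(-1)
            vals = np.resize(x2, view.shape)
            view[...] = view + vals if accumulate else vals
    return y

For example, index_put_ref(np.array([[1, 2, 3], [4, 5, 6]]), np.array([3]), [np.array([0, 0]), np.array([0, 1])], accumulate=1) yields [[4, 5, 3], [4, 5, 6]], matching the documented behavior of the operator.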
@@ -0,0 +1,68 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_INDEX_PUT_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_INDEX_PUT_CPU_KERNEL_H_

#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"

namespace mindspore {
namespace kernel {
class IndexPutCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper<IndexPutCpuKernelMod> {
 public:
  IndexPutCpuKernelMod() = default;
  ~IndexPutCpuKernelMod() override = default;

  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override;

  const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;

 protected:
  std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }

 private:
  void CheckParams();
  std::vector<std::vector<int64_t>> Transpose(const std::vector<std::vector<int64_t>> &A);
  int64_t Multiplicative(const std::vector<int64_t> &tensorshapes, int64_t start, int64_t end);
  template <typename T, typename T0>
  bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
  template <typename T>
  void ComputeNospecial(T *x2, size_t x2_nums, std::vector<std::vector<int64_t>> indices_value, T *y, int accumulate);
  template <typename T>
  void ComputeSpecial(T *x2, size_t x2_nums, std::vector<std::vector<int64_t>> indices_value, T *y, int accumulate);
  BaseOperatorPtr base_operator_;
  std::vector<int64_t> x1_shape_;
  std::vector<int64_t> x2_shape_;
  std::vector<std::vector<int64_t>> indices_shape_;
  size_t inputs_nums = 0;
  int64_t accumulate{0};
  std::vector<TypeId> input_info_;
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_INDEX_PUT_CPU_KERNEL_H_
@@ -227,6 +227,7 @@ constexpr auto kTril = "Tril";
constexpr auto kEye = "Eye";
constexpr auto kTriu = "Triu";
constexpr auto kIndexFill = "IndexFill";
constexpr auto kIndexPut = "IndexPut";
constexpr auto kMeshgrid = "Meshgrid";
constexpr auto kScatterNdMax = "ScatterNdMax";
constexpr auto kScatterNdMin = "ScatterNdMin";

@@ -1396,6 +1397,7 @@ GVAR_DEF(PrimitivePtr, kPrimMatrixDeterminant, std::make_shared<Primitive>(kMatr
GVAR_DEF(PrimitivePtr, kPrimLogMatrixDeterminant, std::make_shared<Primitive>(kLogMatrixDeterminant));
GVAR_DEF(PrimitivePtr, kPrimIndexAdd, std::make_shared<Primitive>("IndexAdd"));
GVAR_DEF(PrimitivePtr, kPrimIndexFill, std::make_shared<Primitive>(kIndexFill));
GVAR_DEF(PrimitivePtr, kPrimIndexPut, std::make_shared<Primitive>(kIndexPut));
GVAR_DEF(PrimitivePtr, kPrimIdentityMath, std::make_shared<Primitive>("Identity", kSideEffectPropagate));
GVAR_DEF(PrimitivePtr, kPrimInvGrad, std::make_shared<Primitive>("InvGrad"));
GVAR_DEF(PrimitivePtr, kPrimErfinv, std::make_shared<Primitive>("Erfinv"));
@@ -0,0 +1,128 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ops/index_put.h"

#include <map>
#include <set>
#include <string>
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"

namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr IndexPutInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto prim_name = primitive->name();
  auto x1_shape_ptr = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, kInputIndex0);
  auto x2_shape_ptr = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, kInputIndex1);
  auto x2_shape = x2_shape_ptr->shape();
  auto x2_rank = SizeToLong(x2_shape.size());
  (void)CheckAndConvertUtils::CheckInteger("rank of x2", x2_rank, kEqual, 1, prim_name);
  auto idx_shape = input_args[kInputIndex2]->isa<abstract::AbstractTuple>()
                     ? input_args[kInputIndex2]->cast<abstract::AbstractTuplePtr>()->elements()
                     : input_args[kInputIndex2]->cast<abstract::AbstractListPtr>()->elements();
  auto x1_shape = x1_shape_ptr->shape();
  int64_t maxsize = 0;
  for (size_t idx = 0; idx < idx_shape.size(); ++idx) {
    auto shape = idx_shape[idx]->cast<abstract::AbstractTensorPtr>();
    auto shape_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(shape->BuildShape())[kShape];
    auto idx_rank = SizeToLong(shape_shape.size());
    (void)CheckAndConvertUtils::CheckInteger("rank of indices[" + std::to_string(idx) + "]", idx_rank, kEqual, 1,
                                             prim_name);
    if (maxsize < shape_shape[0]) {
      maxsize = shape_shape[0];
    }
  }
  for (size_t idx = 0; idx < idx_shape.size(); ++idx) {
    auto shape = idx_shape[idx]->cast<abstract::AbstractTensorPtr>();
    auto shape_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(shape->BuildShape())[kShape];
    if (maxsize != shape_shape[0] && shape_shape[0] != 1) {
      MS_EXCEPTION(ValueError) << "For '" << primitive->name()
                               << "', the tensors in indices must be broadcastable, but the size of indices[" << idx
                               << "] is " << shape_shape[0] << ".";
    }
  }
  auto accumulate = GetValue<int64_t>(primitive->GetAttr("accumulate"));
  if (accumulate != 0 && accumulate != 1) {
    MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', accumulate must be 0 or 1, but got " << accumulate
                             << ".";
  }
  if (idx_shape.size() > x1_shape.size()) {
    MS_EXCEPTION(ValueError) << "For '" << primitive->name()
                             << "', size(indices) must not exceed rank(x1), but got size(indices) " << idx_shape.size()
                             << " vs rank(x1) " << x1_shape.size() << ".";
  } else if (idx_shape.size() < x1_shape.size()) {
    if (x2_shape[0] != 1 && x2_shape[0] != x1_shape[x1_shape.size() - 1]) {
      MS_EXCEPTION(ValueError) << "For '" << primitive->name()
                               << "', the size of x2 must be 1 or x1.shape[-1] if rank(x1) > size(indices), but got "
                               << x2_shape[0] << ".";
    }
  } else {
    if (x2_shape[0] != 1 && x2_shape[0] != maxsize) {
      MS_EXCEPTION(ValueError) << "For '" << primitive->name()
                               << "', the size of x2 must be 1 or the max size of the tensors in indices if rank(x1) "
                                  "== size(indices), but got "
                               << x2_shape[0] << ".";
    }
  }
  return x1_shape_ptr;
}

TypePtr IndexPutInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto prim_name = primitive->name();
  const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64, kUInt8, kUInt16,    kUInt32,    kUInt64,
                                         kInt8,    kInt16,   kInt32,   kInt64, kComplex64, kComplex128};
  const std::set<TypePtr> idx_valid_types = {kInt32, kInt64};
  auto x1_type = input_args[kInputIndex0]->BuildType();
  auto x2_type = input_args[kInputIndex1]->BuildType();
  if (!input_args[kInputIndex2]->isa<abstract::AbstractTuple>() &&
      !input_args[kInputIndex2]->isa<abstract::AbstractList>()) {
    MS_EXCEPTION(TypeError) << "For '" << primitive->name()
                            << "', the input indices must be a list or tuple of tensors.";
  }
  auto idx_type = input_args[kInputIndex2]->isa<abstract::AbstractTuple>()
                    ? input_args[kInputIndex2]->cast<abstract::AbstractTuplePtr>()->elements()
                    : input_args[kInputIndex2]->cast<abstract::AbstractListPtr>()->elements();
  std::map<std::string, TypePtr> idx_types;
  for (size_t idx = 0; idx < idx_type.size(); ++idx) {
    (void)idx_types.emplace("indices[" + std::to_string(idx) + "]:", idx_type[idx]->BuildType());
  }
  (void)CheckAndConvertUtils::CheckTensorTypeSame(idx_types, idx_valid_types, prim_name);
  std::map<std::string, TypePtr> types;
  (void)types.emplace("x1", x1_type);
  (void)types.emplace("x2", x2_type);
  (void)CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim_name);
  return x1_type;
}
}  // namespace

MIND_API_OPERATOR_IMPL(IndexPut, BaseOperator);
AbstractBasePtr IndexPutInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                              const std::vector<AbstractBasePtr> &input_args) {
  const int64_t kInputsNum = 3;
  CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputsNum, primitive->name());
  auto type = IndexPutInferType(primitive, input_args);
  auto shape = IndexPutInferShape(primitive, input_args);
  return abstract::MakeAbstract(shape, type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(IndexPut, prim::kPrimIndexPut, IndexPutInfer, nullptr, true);
}  // namespace ops
}  // namespace mindspore
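The validation performed by IndexPutInferShape condenses to a handful of rules. A hedged Python restatement, where the helper name is ours for illustration and not part of the op's API:

def check_index_put_shapes(x1_shape, x2_shape, idx_shapes, accumulate):
    """Mirrors the checks in IndexPutInferShape; fails on the same conditions."""
    assert len(x2_shape) == 1, "rank(x2) must be 1"
    assert all(len(s) == 1 for s in idx_shapes), "each tensor in indices must be 1-D"
    maxsize = max(s[0] for s in idx_shapes)
    assert all(s[0] in (1, maxsize) for s in idx_shapes), "indices must be broadcastable"
    assert accumulate in (0, 1), "accumulate must be 0 or 1"
    assert len(idx_shapes) <= len(x1_shape), "size(indices) must not exceed rank(x1)"
    if len(idx_shapes) < len(x1_shape):
        assert x2_shape[0] in (1, x1_shape[-1]), "size(x2) must be 1 or x1.shape[-1]"
    else:
        assert x2_shape[0] in (1, maxsize), "size(x2) must be 1 or the max index size"
    return x1_shape  # the output shape is x1's shape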
@@ -0,0 +1,41 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CORE_OPS_INDEX_PUT_H_
#define MINDSPORE_CORE_OPS_INDEX_PUT_H_
#include <memory>
#include <vector>
#include "mindapi/base/types.h"
#include "ops/base_operator.h"

namespace mindspore {
namespace ops {
constexpr auto kNameIndexPut = "IndexPut";
/// \brief Adds values x2 to specified indices of tensor x1.
/// Refer to Python API @ref mindspore.ops.IndexPut for more details.
class MIND_API IndexPut : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(IndexPut);
  /// \brief Constructor.
  IndexPut() : BaseOperator(kNameIndexPut) { InitIOName({"x1", "x2", "indices"}, {"y"}); }
};

abstract::AbstractBasePtr IndexPutInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                        const std::vector<abstract::AbstractBasePtr> &input_args);
}  // namespace ops
}  // namespace mindspore

#endif  // MINDSPORE_CORE_OPS_INDEX_PUT_H_
@@ -33,6 +33,7 @@ from mindspore.ops.operations.array_ops import Mvlgamma
from mindspore.ops.operations.array_ops import Triu
from mindspore.ops.operations.array_ops import IdentityN
from mindspore.ops.operations.array_ops import IndexFill
from mindspore.ops.operations.array_ops import IndexPut
from mindspore.ops.operations.array_ops import CheckNumerics
from mindspore.ops.operations.array_ops import ConjugateTranspose
from mindspore.ops.operations.array_ops import SegmentMax

@@ -47,6 +48,7 @@ from mindspore.ops.operations.array_ops import Im2Col
from mindspore.ops.operations.array_ops import Col2Im
from mindspore.ops.operations.array_ops import StridedSliceV2
from mindspore.ops.operations.array_ops import MaskedScatter
from mindspore.ops.operations.array_ops import MaskedSelect
from mindspore.ops.operations.array_ops import CountNonZero
from mindspore.ops.operations._grad_ops import StridedSliceV2Grad
from mindspore.ops.operations.random_ops import LogNormalReverse

@@ -258,6 +260,43 @@ def get_bprop_index_fill(self):
    return bprop


@bprop_getters.register(IndexPut)
def get_bprop_index_put(self):
    """Generate bprop for IndexPut"""
    gather_nd = P.GatherNd()
    stack = P.Stack()
    tile = P.Tile()
    masked_select = MaskedSelect()
    masked_scatter = MaskedScatter()
    accumulate_grad = self.accumulate
    index_put = IndexPut(accumulate=accumulate_grad)
    is_ascend = context.get_context("device_target") == 'Ascend'

    # Negative values are not supported for GatherNd indices on Ascend, so convert them to positive values.
    def convert_idx_positive(indices_i, x_shape_i):
        mask = indices_i < 0
        idx_pos = masked_select(indices_i + x_shape_i, mask)
        idx = masked_scatter(indices_i, mask, idx_pos)
        return idx

    def bprop(x1, x2, indices, out, dout):
        maxsize = max(x.shape[0] for x in indices)
        indices_ms = [tile(x, (maxsize,)) if x.shape[0] == 1 else x for x in indices]
        if is_ascend:
            indices_ms = [convert_idx_positive(indices_ms[i], x1.shape[i]) for i in range(len(indices_ms))]
        indices_grad = stack(indices_ms).T
        values_grad = gather_nd(dout, indices_grad)
        if x2.shape[0] == 1:
            values_grad = values_grad.sum().reshape(1)
        if values_grad.shape != x2.shape and len(indices) < len(x1.shape):
            _, values_grad = binop_grad_common(x1, x2, dout, values_grad)
        if accumulate_grad == 0:
            dout = index_put(dout, zeros_like(x2), indices)
        return dout, values_grad, [zeros_like(item) for item in indices]

    return bprop
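Read as a whole, the bprop gathers dout at the written coordinates to produce the gradient for x2, and, when accumulate == 0, zeroes those same coordinates of dout to produce the gradient for x1, since overwritten entries of x1 contribute nothing to the output. A NumPy restatement of that logic, covering only the full-coordinate case (one index tensor per dimension of x1); the helper name is illustrative:

import numpy as np

def index_put_bprop_ref(x1, x2, indices, dout, accumulate):
    """Sketch of the gradient when len(indices) == x1.ndim."""
    maxsize = max(i.shape[0] for i in indices)
    # Normalize negative indices and broadcast size-1 tensors, as the bprop does on Ascend.
    cols = tuple(np.broadcast_to(np.where(i < 0, i + d, i), (maxsize,))
                 for i, d in zip(indices, x1.shape))
    dx2 = dout[cols]                 # gather_nd(dout, stacked indices)
    if x2.shape[0] == 1:
        dx2 = dx2.sum().reshape(1)   # size-1 x2 was broadcast, so its grad is summed
    dx1 = dout.copy()
    if accumulate == 0:
        dx1[cols] = 0                # overwritten positions of x1 get no gradient
    return dx1, dx2                  # the index tensors receive zero gradients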


@bprop_getters.register(P.TensorScatterSub)
def get_bprop_tensor_scatter_sub(self):
    """Generate bprop for TensorScatterSub"""
@@ -0,0 +1,50 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""IndexPut op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType

index_put_op_info = AiCPURegOp("IndexPut") \
    .fusion_type("OPAQUE") \
    .attr("accumulate", "int") \
    .input(0, "x1", "required") \
    .input(1, "x2", "required") \
    .input(2, "indices", "dynamic") \
    .output(0, "y", "required") \
    .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.I32_Default, DataType.U8_Default) \
    .dtype_format(DataType.U16_Default, DataType.U16_Default, DataType.I32_Default, DataType.U16_Default) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.I32_Default, DataType.F64_Default) \
    .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I32_Default, DataType.I8_Default) \
    .dtype_format(DataType.I16_Default, DataType.I16_Default, DataType.I32_Default, DataType.I16_Default) \
    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I32_Default, DataType.I64_Default) \
    .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.I64_Default, DataType.U8_Default) \
    .dtype_format(DataType.U16_Default, DataType.U16_Default, DataType.I64_Default, DataType.U16_Default) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.I64_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \
    .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.I64_Default, DataType.F64_Default) \
    .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I64_Default, DataType.I8_Default) \
    .dtype_format(DataType.I16_Default, DataType.I16_Default, DataType.I64_Default, DataType.I16_Default) \
    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I64_Default, DataType.I32_Default) \
    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
    .get_op_info()


@op_info_register(index_put_op_info)
def _index_put_aicpu():
    """IndexPut aicpu register"""
    return
@@ -7295,6 +7295,59 @@ class IndexFill(Primitive):
        self.init_prim_io_names(inputs=['x', 'dim', 'index', 'value'], outputs=['y'])


class IndexPut(Primitive):
    r"""
    Replaces the elements of `x1` at the positions specified by `indices` with the values in `x2`.

    Args:
        accumulate (int): If 1, the elements in `x2` are added to the corresponding elements of `x1`;
            if 0, the elements in `x2` overwrite them. Must be 0 or 1. Default: 0.

    Inputs:
        - **x1** (Tensor) - The target tensor to be updated, 1-D or higher dimensional.
        - **x2** (Tensor) - A 1-D tensor of the same type as `x1`. If its size is 1, it is broadcast.
        - **indices** (tuple[Tensor], list[Tensor]) - Index tensors of type int32 or int64 used to index
          into `x1`. Each tensor in `indices` must be 1-D, size(indices) must not exceed rank(x1), and
          the tensors in `indices` must be broadcastable against each other.

    Outputs:
        Tensor, of the same type and shape as `x1`.

    Raises:
        TypeError: If the dtype of `x1` is not equal to the dtype of `x2`.
        TypeError: If `indices` is not a tuple[Tensor] or list[Tensor].
        TypeError: If the dtype of the tensors in `indices` is neither int32 nor int64.
        TypeError: If the dtypes of the tensors in `indices` are inconsistent.
        TypeError: If `accumulate` is not an int.
        ValueError: If `x2` is not a 1-D tensor.
        ValueError: If size(x2) is neither 1 nor the max size of the tensors in `indices`
            when rank(x1) == size(indices).
        ValueError: If size(x2) is neither 1 nor x1.shape[-1] when rank(x1) > size(indices).
        ValueError: If any tensor in `indices` is not 1-D.
        ValueError: If the tensors in `indices` are not broadcastable.
        ValueError: If size(indices) > rank(x1).
        ValueError: If `accumulate` is not equal to 0 or 1.

    Supported Platforms:
        ``CPU``

    Examples:
        >>> x1 = Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32))
        >>> x2 = Tensor(np.array([3]).astype(np.int32))
        >>> indices = [Tensor(np.array([0, 0]).astype(np.int32)), Tensor(np.array([0, 1]).astype(np.int32))]
        >>> accumulate = 1
        >>> op = ops.IndexPut(accumulate=accumulate)
        >>> output = op(x1, x2, indices)
        >>> print(output)
        [[4 5 3]
         [4 5 6]]
    """

    @prim_attr_register
    def __init__(self, accumulate=0):
        self.accumulate = accumulate
        validator.check_value_type('accumulate', accumulate, [int], self.name)
        self.init_prim_io_names(inputs=['x1', 'x2', 'indices'], outputs=['y'])


class SegmentMax(Primitive):
    r"""
    Computes the maximum along segments of a tensor.
@@ -78,6 +78,7 @@ from mindspore.ops.operations.array_ops import SegmentMin
from mindspore.ops.operations.array_ops import SegmentSum
from mindspore.ops.operations.array_ops import IdentityN
from mindspore.ops.operations.array_ops import IndexFill
from mindspore.ops.operations.array_ops import IndexPut
from mindspore.ops.operations.array_ops import SegmentMean
from mindspore.ops.operations.array_ops import SegmentProd
from mindspore.ops.operations.array_ops import ScatterAddWithAxis

@@ -4031,6 +4032,12 @@ test_case_array_ops = [
        Tensor(4.0, mstype.float32)],
        'desc_bprop': [Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), mstype.float32)],
    }),
    ('IndexPut', {
        'block': IndexPut(1),
        'desc_inputs': [(Tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], mstype.float32)),
                        (Tensor([3.0], mstype.float32)),
                        (Tensor([0, 1], mstype.int32),)],
        'desc_bprop': [Tensor([[1, 1, 1], [1, 1, 1]], mstype.float32)]}),
    ('MaskedScatter', {
        'block': MaskedScatter(),
        'desc_inputs': [Tensor(np.array([[1.0, 2.0, 3.0]]), mstype.float32),