IndexAdd supports CPU, add function interface for IndexAdd

This commit is contained in:
looop5 2022-04-22 15:29:50 +08:00
parent 0577daefd3
commit 521ed19042
10 changed files with 733 additions and 11 deletions

View File

@ -308,6 +308,7 @@ Parameter操作算子
mindspore.ops.assign
mindspore.ops.assign_add
mindspore.ops.assign_sub
mindspore.ops.index_add
.. list-table::
:widths: 50 50

View File

@ -0,0 +1,28 @@
mindspore.ops.index_add
=======================
.. py:function:: mindspore.ops.index_add(x, indices, y, axis, use_lock=True, check_index_bound=True)
将Tensor y加到Parameter x的指定axis轴的指定indices位置。要求axis轴的取值范围为[0, len(x.dim) - 1]，indices元素的取值范围
为[0, x.shape[axis] - 1]。
**参数:**
- **x** (Parameter) - 被加Parameter。
- **indices** (Tensor) - 指定Tensor y加到`x`的指定axis轴的指定indices位置。
- **y** (Tensor) - 与`x`相加的Tensor。
- **axis** (int) - 指定Tensor y加到`x`的指定axis轴。
- **use_lock** (bool) - 计算时使用锁。默认值True。
- **check_index_bound** (bool) - indices边界检查。默认值True。
**返回:**
相加后的Tensor。shape和数据类型与输入 `x`相同。
**异常:**
- **TypeError** - `indices`或者`y`的类型不是Tensor。
- **ValueError** - `axis`的值超出`x` shape的维度范围。
- **ValueError** - `x` shape的维度和`y` shape的维度不一致。
- **ValueError** - `indices` shape的维度不是一维或者`indices` shape的大小与`y` shape在`axis`轴上的大小不一致。
- **ValueError** - 除`axis`轴外,`x` shape和`y` shape的大小不一致。

View File

@ -0,0 +1,224 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/cpu/kernel/index_add_cpu_kernel.h"
#include <algorithm>
#include <memory>
#include <utility>
#include <map>
#include "mindspore/core/ops/index_add.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
#include "include/common/thread_pool.h"
namespace mindspore {
namespace kernel {
namespace {
// Number of input tensors the kernel expects; LaunchKernel reads them as
// inputs[0] = x, inputs[1] = indices, inputs[2] = y.
constexpr size_t kIndexAddInputsNum = 3;
// Number of output tensors the kernel expects (the updated copy of x).
constexpr size_t kIndexAddOutputsNum = 1;
} // namespace
bool IndexAddCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
auto kernel_ptr = std::dynamic_pointer_cast<ops::IndexAdd>(base_operator);
if (!kernel_ptr) {
MS_LOG(ERROR) << "cast IndexAdd ops failed!";
return false;
}
kernel_name_ = kernel_ptr->name();
if (inputs.size() != kIndexAddInputsNum || outputs.size() != kIndexAddOutputsNum) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', input and output tensor number should be " << kIndexAddInputsNum
<< " and " << kIndexAddOutputsNum << ", but got " << inputs.size() << " and " << outputs.size();
return false;
}
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(ERROR) << "IndexAdd does not support this kernel data type: " << kernel_attr;
return false;
}
base_operator_ = base_operator;
kernel_func_ = func_list_[index].second;
return true;
}
// Refreshes the cached shapes and the 'axis' attribute after the base class
// has handled the resize. Returns false if the base-class resize fails.
bool IndexAddCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                  const std::vector<KernelTensorPtr> &outputs,
                                  const std::map<uint32_t, tensor::TensorPtr> &others) {
  const bool base_resized = NativeCpuKernelMod::Resize(base_operator, inputs, outputs, others);
  if (!base_resized) {
    MS_LOG(WARNING) << kernel_name_ << " resize failed.";
    return false;
  }
  // Inputs are ordered (x, indices, y).
  x_shape_ = inputs[kIndex0]->GetShapeVector();
  indices_shape_ = inputs[kIndex1]->GetShapeVector();
  y_shape_ = inputs[kIndex2]->GetShapeVector();
  // The axis comes from the operator stored by Init, not from the argument.
  axis_ = GetValue<int64_t>(base_operator_->GetAttr(AXIS));
  return true;
}
// Validates the cached shapes and axis, then derives the element counts used
// by LaunchKernel.
//
// Checks performed (each raises via MS_LOG(EXCEPTION) on failure):
//   - 'x' and 'y' have the same rank;
//   - 'indices' is one-dimensional;
//   - 'axis' lies in [-rank, rank); a negative axis is normalised in place;
//   - indices.shape[0] == y.shape[axis];
//   - every dimension of 'x' and 'y' is > 0;
//   - 'x' and 'y' agree on every dimension except 'axis'.
//
// On success it fills x_nums_, y_nums_, inner_size_ (product of the dimensions
// after 'axis'), x_axis_size_ and y_axis_size_.
void IndexAddCpuKernelMod::CheckParams() {
  // Check dimension(x) = dimension(y)
  if (x_shape_.size() != y_shape_.size()) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the 'x' and 'y' should have the same dimension, but got "
                      << x_shape_.size() << " vs " << y_shape_.size();
  }
  // Check dimension(indices) = 1
  if (indices_shape_.size() != 1) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the 'indices' should have one dimension, but got "
                      << indices_shape_.size();
  }
  // Check axis's value is valid
  const auto x_rank = SizeToLong(x_shape_.size());
  if (axis_ < -x_rank || axis_ >= x_rank) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', 'axis' should be in range [" << -x_rank << ", " << x_rank
                      << "), but got " << axis_;
  }
  if (axis_ < 0) {
    axis_ += x_rank;
  }
  // From here on only the validated, non-negative axis is used for indexing.
  const size_t axis = LongToSize(axis_);
  // Check indices's size = y.shape[axis]
  if (indices_shape_[0] != y_shape_[axis]) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_
                      << "', size of 'indices' should be same as size of 'y' in 'axis'th dimension, but got "
                      << indices_shape_[0] << " vs " << y_shape_[axis];
  }
  // Check x.shape[i] = y.shape[i], except i = axis
  x_nums_ = 1;
  y_nums_ = 1;
  inner_size_ = 1;
  for (size_t i = 0; i < x_shape_.size(); ++i) {
    if (x_shape_[i] <= 0 || y_shape_[i] <= 0) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', 'x' shape[" << i << "] or 'y' shape [" << i
                        << "] is invalid, which should > 0, but got " << x_shape_[i] << " and " << y_shape_[i];
    }
    if (i != axis && x_shape_[i] != y_shape_[i]) {
      MS_LOG(EXCEPTION)
        << "For '" << kernel_name_
        << "', the shape of 'x' and 'y' must be same except the 'axis'th dimension, but got different values: "
        << x_shape_[i] << " vs " << y_shape_[i] << " in dimension " << i;
    }
    x_nums_ *= LongToSize(x_shape_[i]);
    y_nums_ *= LongToSize(y_shape_[i]);
    if (i > axis) {
      inner_size_ *= LongToSize(x_shape_[i]);
    }
  }
  x_axis_size_ = LongToSize(x_shape_[axis]);
  y_axis_size_ = LongToSize(y_shape_[axis]);
}
// Typed launch body: accumulates 'y' into 'x' along the configured axis at the
// positions given by 'indices', then copies the updated 'x' into the output.
//
// inputs  : [0] x (updated in place), [1] indices (read as int), [2] y.
// outputs : [0] buffer receiving a copy of the updated x.
// Returns true on success; CheckParams() and a failing memcpy_s raise through
// MS_LOG(EXCEPTION).
//
// An index that is negative or >= x.shape[axis] is skipped silently: the cast
// to size_t turns a negative value into a large one that fails the range test.
template <typename T>
bool IndexAddCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
                                        const std::vector<kernel::AddressPtr> &,
                                        const std::vector<kernel::AddressPtr> &outputs) {
  CHECK_KERNEL_INPUTS_NUM(inputs.size(), kIndexAddInputsNum, kernel_name_);
  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kIndexAddOutputsNum, kernel_name_);
  auto *x = reinterpret_cast<T *>(inputs[0]->addr);
  const int *indices = reinterpret_cast<int *>(inputs[1]->addr);
  const T *y = reinterpret_cast<T *>(inputs[2]->addr);
  auto *output = reinterpret_cast<T *>(outputs[0]->addr);
  CheckParams();
  const size_t x_axis_inner_size = x_axis_size_ * inner_size_;
  const size_t y_axis_inner_size = y_axis_size_ * inner_size_;
  // Work is split by (outer, inner) lane rather than by element of 'y'. Every
  // element of 'x' belongs to exactly one lane, so repeated values in
  // 'indices' are accumulated by a single thread in a fixed order instead of
  // racing on `x[...] += ...` from several threads.
  // CheckParams() guarantees y_axis_size_ > 0, so the division is safe.
  const size_t lane_num = y_nums_ / y_axis_size_;
  auto task1 = [&](const size_t start, const size_t end) {
    for (size_t lane = start; lane < end; ++lane) {
      const size_t outer_idx = lane / inner_size_;
      const size_t inner_idx = lane % inner_size_;
      T *x_lane = x + outer_idx * x_axis_inner_size + inner_idx;
      const T *y_lane = y + outer_idx * y_axis_inner_size + inner_idx;
      for (size_t y_axis_idx = 0; y_axis_idx < y_axis_size_; ++y_axis_idx) {
        // calc idx_x in x.shape[axis]
        const size_t x_axis_idx = static_cast<size_t>(indices[y_axis_idx]);
        // only process add operation when idx_x is valid
        if (x_axis_idx < x_axis_size_) {
          x_lane[x_axis_idx * inner_size_] += y_lane[y_axis_idx * inner_size_];
        }
      }
    }
  };
  ParallelLaunchAutoSearch(task1, lane_num, this, &parallel_search_info_);
  // Copy the updated x into the output buffer, chunked across threads.
  auto task2 = [&](size_t start, size_t end) {
    size_t length = (end - start) * sizeof(T);
    int ret = memcpy_s(output + start, length, x + start, length);
    if (ret != 0) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', memcpy_s error. Error no: " << ret;
    }
  };
  ParallelLaunchAutoSearch(task2, x_nums_, this, &parallel_search_info_);
  return true;
}
// Dispatch table pairing each supported dtype signature with the matching
// LaunchKernel<T> instantiation. Each entry lists, in order, the dtypes of
// x, indices, y and the output; indices are int32 in every entry.
// Init() picks an entry by the index returned from MatchKernelAttr.
std::vector<std::pair<KernelAttr, IndexAddCpuKernelMod::IndexAddFunc>> IndexAddCpuKernelMod::func_list_ = {
  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat64)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeFloat64)
     .AddOutputAttr(kNumberTypeFloat64),
   &IndexAddCpuKernelMod::LaunchKernel<double>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeFloat32)
     .AddOutputAttr(kNumberTypeFloat32),
   &IndexAddCpuKernelMod::LaunchKernel<float>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeFloat16)
     .AddOutputAttr(kNumberTypeFloat16),
   &IndexAddCpuKernelMod::LaunchKernel<float16>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt32)
     .AddOutputAttr(kNumberTypeInt32),
   &IndexAddCpuKernelMod::LaunchKernel<int32_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeInt16)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt16)
     .AddOutputAttr(kNumberTypeInt16),
   &IndexAddCpuKernelMod::LaunchKernel<int16_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeInt8)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt8)
     .AddOutputAttr(kNumberTypeInt8),
   &IndexAddCpuKernelMod::LaunchKernel<int8_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeUInt8)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeUInt8)
     .AddOutputAttr(kNumberTypeUInt8),
   &IndexAddCpuKernelMod::LaunchKernel<uint8_t>}};
std::vector<KernelAttr> IndexAddCpuKernelMod::GetOpSupport() {
static std::vector<KernelAttr> support_list;
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, IndexAddFunc> &pair) { return pair.first; });
return support_list;
}
// Registers IndexAddCpuKernelMod with the NativeCpuKernelMod factory under the name IndexAdd.
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, IndexAdd, IndexAddCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,75 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_INDEX_ADD_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_INDEX_ADD_CPU_KERNEL_H_
#include <vector>
#include <memory>
#include <utility>
#include <map>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
// CPU kernel for the IndexAdd operator: adds tensor 'y' into 'x' along a
// configured axis at the positions given by a 1-D 'indices' tensor.
//
// NOTE(review): this header uses std::function but does not include
// <functional> directly; it presumably arrives through cpu_kernel.h — confirm.
class IndexAddCpuKernelMod : public NativeCpuKernelMod {
 public:
  IndexAddCpuKernelMod() = default;
  ~IndexAddCpuKernelMod() override = default;
  // Validates the operator and tensor counts and selects the typed launch function.
  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;
  // Re-reads the input shapes and the axis attribute.
  bool Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
              const std::vector<KernelTensorPtr> &outputs,
              const std::map<uint32_t, tensor::TensorPtr> &others = std::map<uint32_t, tensor::TensorPtr>()) override;
  // Forwards to the typed function chosen by Init; raises if none was selected.
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override {
    MS_EXCEPTION_IF_NULL(kernel_func_);
    return kernel_func_(this, inputs, workspace, outputs);
  }

 protected:
  // Returns the supported dtype signatures.
  std::vector<KernelAttr> GetOpSupport() override;

 private:
  // Validates the cached shapes/axis and derives the size members below.
  void CheckParams();
  // Typed implementation of the kernel for element type T.
  template <typename T>
  bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
                    const std::vector<kernel::AddressPtr> &outputs);
  // Signature shared by all LaunchKernel<T> instantiations, with the kernel object passed explicitly.
  using IndexAddFunc =
    std::function<bool(IndexAddCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
                       const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
  // Table of (dtype signature, typed launch function) pairs.
  static std::vector<std::pair<KernelAttr, IndexAddFunc>> func_list_;
  // Launch function selected for the current dtypes; null until selected.
  IndexAddFunc kernel_func_{nullptr};
  // Operator handle retained for reading attributes.
  BaseOperatorPtr base_operator_;
  // Cached shapes of the three inputs.
  std::vector<int64_t> x_shape_;
  std::vector<int64_t> y_shape_;
  std::vector<int64_t> indices_shape_;
  // Axis along which values are added.
  int64_t axis_{0};
  // Element counts of x and y.
  size_t x_nums_{1};
  size_t y_nums_{1};
  // Product of the dimensions that follow the axis.
  size_t inner_size_{1};
  // Extent of x and y along the axis.
  size_t x_axis_size_{1};
  size_t y_axis_size_{1};
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_INDEX_ADD_CPU_KERNEL_H_

View File

@ -33,27 +33,44 @@ abstract::ShapePtr IndexAddInferShape(const PrimitivePtr &primitive, const std::
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
auto x_shape_ptr = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, kInputIndex0);
auto idx_shape_ptr = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, kInputIndex1);
auto y_shape_ptr = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, kInputIndex2);
auto x_is_dynamic = x_shape_ptr->IsDynamic();
auto idx_is_dynamic = idx_shape_ptr->IsDynamic();
auto y_is_dynamic = y_shape_ptr->IsDynamic();
if (x_is_dynamic) {
return x_shape_ptr;
}
auto x_shape = x_shape_ptr->shape();
auto y_shape = y_shape_ptr->shape();
auto x_rank = SizeToLong(x_shape.size());
auto y_rank = SizeToLong(y_shape.size());
CheckAndConvertUtils::Check("x rank", x_rank, kEqual, y_rank, prim_name);
if (!y_is_dynamic) {
CheckAndConvertUtils::Check("x rank", x_rank, kEqual, y_rank, prim_name);
}
auto axis = GetValue<int64_t>(primitive->GetAttr(kAxis));
CheckAndConvertUtils::CheckInRange("axis", axis, kIncludeNeither, {-x_rank - 1, x_rank}, prim_name);
auto idx_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
auto idx_shape = idx_shape_ptr->shape();
auto idx_rank = SizeToLong(idx_shape.size());
(void)CheckAndConvertUtils::CheckInteger("idx size", idx_rank, kEqual, 1, prim_name);
auto axis_rank = axis;
if (axis < 0) {
axis_rank = axis + x_rank;
}
(void)CheckAndConvertUtils::Check("size of indices", idx_shape[0], kEqual, y_shape[axis_rank], prim_name);
if (y_is_dynamic) {
return x_shape_ptr;
}
if (!idx_is_dynamic) {
(void)CheckAndConvertUtils::Check("size of indices", idx_shape[0], kEqual, y_shape[axis_rank], prim_name);
}
for (int dim = 0; dim < x_rank; dim = dim + 1) {
if (dim != axis_rank) {
(void)CheckAndConvertUtils::Check("x dim", x_shape[dim], kEqual, y_shape[dim], prim_name);
}
}
return std::make_shared<abstract::Shape>(x_shape);
return x_shape_ptr;
}
TypePtr IndexAddInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {

View File

@ -23,7 +23,7 @@ from . import array_func, parameter_func, math_func
from .array_func import (unique, eye, fill, fill_, tile, size, ones, ones_like, shape, shape_, dyn_shape, rank,
reshape, reshape_, tensor_slice, slice, scalar_to_array, scalar_to_tensor, tuple_to_array,
expand_dims, transpose, scatter_nd, gather, gather_d, gather_nd, scalar_cast, masked_fill)
from .parameter_func import assign, assign_add, assign_sub
from .parameter_func import assign, assign_add, assign_sub, index_add
from .math_func import (addn, absolute, abs, tensor_add, add, neg_tensor, neg, tensor_lt, less, tensor_le, le,
tensor_gt, gt, tensor_ge, ge, tensor_sub, sub, tensor_mul, mul, tensor_div, div,
tensor_floordiv, floor_div, floordiv, tensor_pow, pow, pows, tensor_mod, floor_mod, floormod,

View File

@ -142,9 +142,59 @@ def assign_add(variable, value):
return assign_add_(variable, value)
def index_add(x, indices, y, axis, use_lock=True, check_index_bound=True):
    """
    Adds tensor `y` into Parameter `x` along `axis`, at the positions selected by `indices`.
    The axis should be in [0, len(x.dim) - 1], and every index should be in
    [0, the size of `x` - 1] at the axis dimension.

    Args:
        x (Parameter): The input Parameter to add to.
        indices (Tensor): 1-D tensor of data type int32 giving the positions along `axis` of `x`
            that receive the values of `y`. Its size must equal the size of `y` in the `axis`
            dimension, and its values should be in [0, b), where b is the size of `x` in the
            `axis` dimension.
        y (Tensor): The input tensor with the value to add. Must have same data type as `x`.
            The shape must be the same as `x` except the `axis` th dimension.
        axis (int): The dimension along which to index.
        use_lock (bool): If true, use lock mode. If false, don't use lock mode. Default: True.
        check_index_bound (bool): If true, check index boundary. If false, don't check index boundary.
            Default: True.

    Returns:
        Tensor, has the same shape and dtype as `x`.

    Raises:
        TypeError: If `x` is not a Parameter.
        TypeError: If `indices` or `y` is not a Tensor.
        ValueError: If axis is out of `x` rank's range.
        ValueError: If `x` rank is not the same as `y` rank.
        ValueError: If shape of `indices` is not 1D or size of `indices` is not equal to dimension of y[axis].
        ValueError: If `y`'s shape is not the same as `x` except the `axis` th dimension.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore
        >>> from mindspore import Tensor, Parameter
        >>> from mindspore import ops
        >>> x = Parameter(Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32), name="name_x")
        >>> indices = Tensor(np.array([0, 2]), mindspore.int32)
        >>> y = Tensor(np.array([[0.5, 1.0], [1.0, 1.5], [2.0, 2.5]]), mindspore.float32)
        >>> output = ops.index_add(x, indices, y, 1)
        >>> print(output)
        [[ 1.5  2.   4. ]
         [ 5.   5.   7.5]
         [ 9.   8.  11.5]]
    """
    # Build the primitive with the call's attributes, then apply it to the inputs.
    index_add_op = P.IndexAdd(axis, use_lock, check_index_bound)
    return index_add_op(x, indices, y)
__all__ = [
'assign',
'assign_sub',
'assign_add'
'assign_add',
'index_add'
]
__all__.sort()

View File

@ -978,6 +978,7 @@ tensor_operator_registry.register('tile', P.Tile)
tensor_operator_registry.register('logical_not', P.LogicalNot)
tensor_operator_registry.register('sum', P.ReduceSum)
tensor_operator_registry.register('split', P.Split)
tensor_operator_registry.register('index_add', P.IndexAdd)
# ms cannot support Tensor(True) compare
tensor_operator_registry.register('__eq__', equal)
tensor_operator_registry.register('__ne__', not_equal)

View File

@ -4979,7 +4979,7 @@ class LogMatrixDeterminant(Primitive):
class IndexAdd(Primitive):
"""
Adds tensor `y` to specified axis and indices of tensor `x`. The axis should be in [0, len(x.dim) - 1],
and indices should be in [0, the size of `x`] at the axis dimension.
and indices should be in [0, the size of `x` - 1] at the axis dimension.
Args:
axis (int): The dimension along which to index.
@ -5003,11 +5003,11 @@ class IndexAdd(Primitive):
TypeError: If neither `indices` nor `y` is a Tensor.
ValueError: If axis is out of `x` rank's range.
ValueError: If `x` rank is not the same as `y` rank.
ValueError: If size of `indices` is not equal to dimension of y[axis].
ValueError: If shape of `indices` is not 1D or size of `indices` is not equal to dimension of y[axis].
ValueError: If `y`'s shape is not the same as `x` except the `axis` th dimension.
Supported Platforms:
``Ascend`` ``GPU``
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> class Net(nn.Cell):

View File

@ -0,0 +1,326 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore
import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor, Parameter, ParameterTuple
class NetIndexAdd(nn.Cell):
    """Cell that holds `x` as a Parameter and applies IndexAdd along a fixed axis."""

    def __init__(self, x, axis):
        super().__init__()
        self.index_add = ops.IndexAdd(axis)
        self.input_x = Parameter(Tensor(x), name='x')

    def construct(self, idx, y):
        # Add `y` into the stored parameter at positions `idx`.
        result = self.index_add(self.input_x, idx, y)
        return result
def index_add_forward(nptype):
    """Run IndexAdd on a (2, 3, 4) input of dtype `nptype` and compare with a numpy reference."""
    axis = 1
    idx = np.array([0, 2]).astype(np.int32)
    x = np.arange(2 * 3 * 4).reshape(2, 3, 4).astype(nptype)
    y = np.ones((2, 2, 4), dtype=nptype)

    # Reference result, computed on a copy before the network touches `x`.
    expect = np.copy(x)
    expect[:, idx, :] = np.add(expect[:, idx, :], y)

    net = NetIndexAdd(x, axis)
    output = net(Tensor(idx), Tensor(y))
    assert (output.asnumpy() == expect).all()
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_index_add_float64():
    """
    Feature: test IndexAdd forward.
    Description: test float64 inputs in graph mode and then pynative mode.
    Expectation: the result match with numpy result
    """
    for mode in (context.GRAPH_MODE, context.PYNATIVE_MODE):
        context.set_context(mode=mode, device_target="CPU")
        index_add_forward(np.float64)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_index_add_float16():
    """
    Feature: test IndexAdd forward.
    Description: test float16 inputs in graph mode and then pynative mode.
    Expectation: the result match with numpy result
    """
    for mode in (context.GRAPH_MODE, context.PYNATIVE_MODE):
        context.set_context(mode=mode, device_target="CPU")
        index_add_forward(np.float16)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_index_add_int32():
    """
    Feature: test IndexAdd forward.
    Description: test int32 inputs in graph mode and then pynative mode.
    Expectation: the result match with numpy result
    """
    for mode in (context.GRAPH_MODE, context.PYNATIVE_MODE):
        context.set_context(mode=mode, device_target="CPU")
        index_add_forward(np.int32)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_index_add_int16():
    """
    Feature: test IndexAdd forward.
    Description: test int16 inputs in graph mode and then pynative mode.
    Expectation: the result match with numpy result
    """
    for mode in (context.GRAPH_MODE, context.PYNATIVE_MODE):
        context.set_context(mode=mode, device_target="CPU")
        index_add_forward(np.int16)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_index_add_int8():
    """
    Feature: test IndexAdd forward.
    Description: test int8 inputs in graph mode and then pynative mode.
    Expectation: the result match with numpy result
    """
    for mode in (context.GRAPH_MODE, context.PYNATIVE_MODE):
        context.set_context(mode=mode, device_target="CPU")
        index_add_forward(np.int8)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_index_add_uint8():
    """
    Feature: test IndexAdd forward.
    Description: test uint8 inputs in graph mode and then pynative mode.
    Expectation: the result match with numpy result
    """
    for mode in (context.GRAPH_MODE, context.PYNATIVE_MODE):
        context.set_context(mode=mode, device_target="CPU")
        index_add_forward(np.uint8)
class IndexAddGradNet(nn.Cell):
    """Wraps a network and returns gradients for both its inputs and its trainable parameters."""

    def __init__(self, network):
        super().__init__()
        self.network = network
        self.params = ParameterTuple(network.trainable_params())
        self.grad = ops.GradOperation(get_all=True, sens_param=True, get_by_list=True)

    def construct(self, idx, y, dout):
        # `dout` is passed as the sensitivity (sens_param=True).
        grad_fn = self.grad(self.network, self.params)
        return grad_fn(idx, y, dout)
def index_add_grad_with_type(nptype):
    """Check IndexAdd gradients w.r.t. `x` and `y` for dtype `nptype` on a (5, 3) input."""
    x = np.arange(15).reshape(5, 3).astype(nptype)
    grad_net = IndexAddGradNet(NetIndexAdd(x, 1))

    y = Tensor(np.arange(5).reshape(5, 1).astype(nptype))
    # Upstream gradient: the values 63..77 laid out as a (5, 3) matrix.
    dout_np = np.arange(63., 78.).reshape(5, 3).astype(nptype)
    dout = Tensor(dout_np)
    index = Tensor(np.array([1]), dtype=mindspore.int32)

    output = grad_net(index, y, dout)
    # output[0] holds the input gradients (idx, y); output[1] the parameter gradients.
    ygrad = output[0][1]
    xgrad = output[1][0]

    # The gradient of x is dout itself; the gradient of y is column 1 of dout.
    expect_xgrad = np.arange(63., 78.).reshape(5, 3).astype(nptype)
    expect_ygrad = np.array([[64.], [67.], [70.], [73.], [76.]]).astype(nptype)
    np.testing.assert_array_equal(xgrad.asnumpy(), expect_xgrad)
    np.testing.assert_array_equal(ygrad.asnumpy(), expect_ygrad)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_index_add_grad_float64():
    """
    Feature: test IndexAdd backward.
    Description: test float64 inputs in graph mode and then pynative mode.
    Expectation: the result match with numpy result
    """
    for mode in (context.GRAPH_MODE, context.PYNATIVE_MODE):
        context.set_context(mode=mode, device_target="CPU")
        index_add_grad_with_type(np.float64)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_index_add_grad_float32():
    """
    Feature: test IndexAdd backward.
    Description: test float32 inputs in graph mode and then pynative mode.
    Expectation: the result match with numpy result
    """
    for mode in (context.GRAPH_MODE, context.PYNATIVE_MODE):
        context.set_context(mode=mode, device_target="CPU")
        index_add_grad_with_type(np.float32)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_index_add_grad_float16():
    """
    Feature: test IndexAdd backward.
    Description: test float16 inputs in graph mode and then pynative mode.
    Expectation: the result match with numpy result
    """
    for mode in (context.GRAPH_MODE, context.PYNATIVE_MODE):
        context.set_context(mode=mode, device_target="CPU")
        index_add_grad_with_type(np.float16)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_index_add_grad_int32():
    """
    Feature: test IndexAdd backward.
    Description: test int32 inputs in graph mode and then pynative mode.
    Expectation: the result match with numpy result
    """
    for mode in (context.GRAPH_MODE, context.PYNATIVE_MODE):
        context.set_context(mode=mode, device_target="CPU")
        index_add_grad_with_type(np.int32)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_index_add_grad_int16():
    """
    Feature: test IndexAdd backward.
    Description: test int16 inputs in graph mode and then pynative mode.
    Expectation: the result match with numpy result
    """
    for mode in (context.GRAPH_MODE, context.PYNATIVE_MODE):
        context.set_context(mode=mode, device_target="CPU")
        index_add_grad_with_type(np.int16)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_index_add_grad_int8():
    """
    Feature: test IndexAdd backward.
    Description: test int8 inputs in graph mode and then pynative mode.
    Expectation: the result match with numpy result
    """
    for mode in (context.GRAPH_MODE, context.PYNATIVE_MODE):
        context.set_context(mode=mode, device_target="CPU")
        index_add_grad_with_type(np.int8)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_index_add_grad_uint8():
    """
    Feature: test IndexAdd backward.
    Description: test uint8 inputs in graph mode and then pynative mode.
    Expectation: the result match with numpy result
    """
    for mode in (context.GRAPH_MODE, context.PYNATIVE_MODE):
        context.set_context(mode=mode, device_target="CPU")
        index_add_grad_with_type(np.uint8)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_index_add_function():
    """
    Feature: test IndexAdd function interface.
    Description: call ops.index_add on a 3x3 float32 Parameter along axis 1.
    Expectation: the result match with numpy result
    """
    context.set_context(device_target="CPU")
    x_np = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    y_np = np.array([[0.5, 1.0], [1.0, 1.5], [2.0, 2.5]])
    x = Parameter(Tensor(x_np, mindspore.float32), name="name_x")
    indices = Tensor(np.array([0, 2]), mindspore.int32)
    y = Tensor(y_np, mindspore.float32)

    output = ops.index_add(x, indices, y, 1)

    # Columns 0 and 2 of x receive columns 0 and 1 of y respectively.
    expect = np.array([[1.5, 2, 4], [5, 5, 7.5], [9, 8, 11.5]])
    np.testing.assert_array_equal(output.asnumpy(), expect)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_index_add_dynamic():
    """
    Feature: test IndexAdd dynamic shape.
    Description: input y is declared with a dynamic middle dimension; run in pynative mode and then graph mode.
    Expectation: the result match with numpy result
    """
    axis = 1
    idx = np.array([0, 2]).astype(np.int32)
    x = np.arange(2 * 3 * 4).reshape(2, 3, 4).astype(np.float32)
    y = np.ones((2, 2, 4), dtype=np.float32)

    # Numpy reference, computed before any network is built from `x`.
    expect = np.copy(x)
    expect[:, idx, :] = np.add(expect[:, idx, :], y)

    # Placeholder for `y` whose size along the indexed axis is unknown.
    y_dyn = Tensor(shape=[2, None, 4], dtype=mindspore.float32)

    for mode in (context.PYNATIVE_MODE, context.GRAPH_MODE):
        context.set_context(mode=mode, device_target="CPU")
        net = NetIndexAdd(x, axis)
        net.set_inputs(Tensor(idx), y_dyn)
        output = net(Tensor(idx), Tensor(y))
        assert (output.asnumpy() == expect).all()