forked from mindspore-Ecosystem/mindspore
!34757 add inplace_update operator and vmaprule
Merge pull request !34757 from wangjun/inplace_ops_0523
commit 05afdee563
@@ -105,6 +105,9 @@ functional operators are initialized Primitives and can be called directly as functions
     mindspore.ops.floor
     mindspore.ops.floor_div
     mindspore.ops.floor_mod
+    mindspore.ops.inplace_add
+    mindspore.ops.inplace_sub
+    mindspore.ops.inplace_update
     mindspore.ops.invert
     mindspore.ops.lerp
     mindspore.ops.log
@@ -530,6 +530,69 @@ mindspore.Tensor
     The initialized Tensor.

+    .. py:method:: inplace_add(v, indices)
+
+        Adds `v` to the original Tensor according to `indices`.
+
+        .. note::
+            `indices` can only index along the highest axis.
+
+        **Parameters:**
+
+        - **v** (Tensor) - The value to add.
+        - **indices** (Union[int, tuple]) - Indices of the values to be updated in the original Tensor.
+
+        **Returns:**
+
+        Tensor, the updated Tensor.
+
+        **Raises:**
+
+        - **TypeError** - `indices` is neither int nor tuple.
+        - **TypeError** - `indices` is a tuple, but one of its elements is not int.
+
+    .. py:method:: inplace_sub(v, indices)
+
+        Subtracts `v` from the original Tensor according to `indices`.
+
+        .. note::
+            `indices` can only index along the highest axis.
+
+        **Parameters:**
+
+        - **v** (Tensor) - The value to subtract.
+        - **indices** (Union[int, tuple]) - Indices of the values to be updated in the original Tensor.
+
+        **Returns:**
+
+        Tensor, the updated Tensor.
+
+        **Raises:**
+
+        - **TypeError** - `indices` is neither int nor tuple.
+        - **TypeError** - `indices` is a tuple, but one of its elements is not int.
+
+    .. py:method:: inplace_update(v, indices)
+
+        Updates values of the Tensor to `v` according to `indices`.
+
+        .. note::
+            `indices` can only index along the highest axis.
+
+        **Parameters:**
+
+        - **v** (Tensor) - The values to update with.
+        - **indices** (Union[int, tuple]) - Indices of the values to be updated in the original Tensor.
+
+        **Returns:**
+
+        Tensor, the updated Tensor.
+
+        **Raises:**
+
+        - **TypeError** - `indices` is neither int nor tuple.
+        - **TypeError** - `indices` is a tuple, but one of its elements is not int.
+
     .. py:method:: item(index=None)

         Gets the element at the specified index of the Tensor.
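A quick sanity check of the three new Tensor methods, as a minimal sketch; the values are taken from the English docstrings added later in this PR, and PyNative mode is assumed:

import numpy as np
import mindspore
from mindspore import Tensor

x = Tensor(np.array([[1, 2], [3, 4], [5, 6]]), mindspore.float32)
v = Tensor(np.array([[0.1, 0.2], [0.3, 0.4]]), mindspore.float32)
print(x.inplace_update(v, (0, 1)))  # rows 0 and 1 replaced by v
print(x.inplace_add(v, (0, 1)))     # v added to rows 0 and 1
print(x.inplace_sub(v, (0, 1)))     # v subtracted from rows 0 and 1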
@@ -21,3 +21,4 @@ mindspore.ops.InplaceUpdate
 - **TypeError** - `indices` is neither int nor Tuple.
+- **TypeError** - `indices` is a Tuple, but one of its elements is not of type int.
@@ -0,0 +1,24 @@
+mindspore.ops.inplace_add
+=========================
+
+.. py:function:: inplace_add(x, v, indices)
+
+    Adds `v` to the corresponding positions of `x` according to `indices`.
+
+    .. note::
+        `indices` can only index along the highest axis.
+
+    **Parameters:**
+
+    - **x** (Tensor) - The Tensor to be updated.
+    - **v** (Tensor) - The value to add.
+    - **indices** (Union[int, tuple]) - Indices of the values to be updated in the original Tensor.
+
+    **Returns:**
+
+    Tensor, the updated Tensor.
+
+    **Raises:**
+
+    - **TypeError** - `indices` is neither int nor tuple.
+    - **TypeError** - `indices` is a tuple, but one of its elements is not int.
@@ -0,0 +1,24 @@
+mindspore.ops.inplace_sub
+=========================
+
+.. py:function:: inplace_sub(x, v, indices)
+
+    Subtracts `v` from `x` according to `indices`.
+
+    .. note::
+        `indices` can only index along the highest axis.
+
+    **Parameters:**
+
+    - **x** (Tensor) - The Tensor to be updated.
+    - **v** (Tensor) - The value to subtract.
+    - **indices** (Union[int, tuple]) - Indices of the values to be updated in the original Tensor.
+
+    **Returns:**
+
+    Tensor, the updated Tensor.
+
+    **Raises:**
+
+    - **TypeError** - `indices` is neither int nor tuple.
+    - **TypeError** - `indices` is a tuple, but one of its elements is not int.
@@ -0,0 +1,24 @@
+mindspore.ops.inplace_update
+============================
+
+.. py:function:: inplace_update(x, v, indices)
+
+    Updates some values of `x` to `v` according to `indices`.
+
+    .. note::
+        `indices` can only index along the highest axis.
+
+    **Parameters:**
+
+    - **x** (Tensor) - The Tensor to be updated.
+    - **v** (Tensor) - The values to update with.
+    - **indices** (Union[int, tuple]) - Indices of the values to be updated in the original Tensor.
+
+    **Returns:**
+
+    Tensor, the updated Tensor.
+
+    **Raises:**
+
+    - **TypeError** - `indices` is neither int nor tuple.
+    - **TypeError** - `indices` is a tuple, but one of its elements is not int.
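The three functional wrappers documented above resolve to the corresponding primitives. A minimal usage sketch, with values reused from the InplaceUpdate docstring example later in this PR:

import numpy as np
import mindspore
from mindspore import Tensor, ops

x = Tensor(np.array([[1, 2], [3, 4], [5, 6]]), mindspore.float32)
v = Tensor(np.array([[0.5, 1.0], [1.0, 1.5]]), mindspore.float32)
output = ops.inplace_update(x, v, (0, 1))
print(output)
# [[0.5 1. ]
#  [1.  1.5]
#  [5.  6. ]]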
@@ -105,6 +105,9 @@ Element-by-Element Operations
     mindspore.ops.floor
     mindspore.ops.floor_div
     mindspore.ops.floor_mod
+    mindspore.ops.inplace_add
+    mindspore.ops.inplace_sub
+    mindspore.ops.inplace_update
     mindspore.ops.invert
     mindspore.ops.lerp
     mindspore.ops.log
@@ -200,6 +200,9 @@ BuiltInTypeMap &GetMethodMap() {
     {"astype", std::string("astype")},                  // P.cast()
     {"cumsum", std::string("cumsum")},                  // P.cumsum()
     {"copy", std::string("copy")},                      // copy()
+    {"inplace_update", std::string("inplace_update")},  // P.InplaceUpdate
+    {"inplace_add", std::string("inplace_add")},        // P.InplaceAdd
+    {"inplace_sub", std::string("inplace_sub")},        // P.InplaceSub
     {"lerp", std::string("lerp")},                      // lerp()
     {"log_matrix_determinant", std::string("log_matrix_determinant")},  // log_matrix_determinant()
     {"matrix_determinant", std::string("matrix_determinant")},          // matrix_determinant()
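These GetMethodMap entries are what let the new Tensor methods resolve inside a compiled Cell, not just in PyNative. A minimal sketch, mirroring the test Cell added at the end of this PR:

import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor

class UpdateRows(nn.Cell):
    def __init__(self, indices):
        super(UpdateRows, self).__init__()
        self.indices = indices

    def construct(self, x, v):
        # resolved through the "inplace_update" entry registered above
        return x.inplace_update(v, self.indices)

net = UpdateRows((0, 1))
x = Tensor(np.array([[1, 2], [3, 4], [5, 6]]), ms.float32)
v = Tensor(np.array([[0.1, 0.2], [0.3, 0.4]]), ms.float32)
print(net(x, v))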
@@ -20,6 +20,7 @@
 #include <algorithm>
 #include "mindspore/core/ops/inplace_add.h"
 #include "mindspore/core/ops/inplace_sub.h"
+#include "mindspore/core/ops/inplace_update.h"

 namespace mindspore {
 namespace kernel {
@@ -37,6 +38,13 @@ struct Sub {
     return lhs - rhs;
   }
 };

+struct Update {
+  template <typename T>
+  inline T operator()(const T &lhs, const T &rhs) const {
+    return rhs;
+  }
+};
 template <typename T>
 class InplaceOpCpuTypeFunc : public DeprecatedCpuKernelFunc {
  public:
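The Add/Sub/Update functors only define the per-element combine; the kernel applies them row-wise at the attribute indices. A NumPy model of the end-to-end semantics, mirroring the inplace_op_np helper in the tests at the end of this PR:

import numpy as np

def inplace_op(op, x, v, indices):
    # copy first: the MindSpore ops return a new tensor rather than
    # mutating x in the Python sense
    result = x.copy()
    if op == 'add':
        result[indices, :] += v
    elif op == 'sub':
        result[indices, :] -= v
    elif op == 'update':
        result[indices, :] = v   # Update functor: return rhs
    return result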
@@ -54,6 +62,9 @@ class InplaceOpCpuTypeFunc : public DeprecatedCpuKernelFunc {
     } else if (kernel_name_ == ops::kNameInplaceSub) {
       auto kernel_ptr = std::make_shared<ops::InplaceSub>(base_operator->GetPrim());
       indices_ = kernel_ptr->get_indices();
+    } else if (kernel_name_ == ops::kNameInplaceUpdate) {
+      auto kernel_ptr = std::make_shared<ops::InplaceUpdate>(base_operator->GetPrim());
+      indices_ = kernel_ptr->get_indices();
     } else {
       MS_LOG(EXCEPTION) << "InplaceOp cpu does not support " << kernel_name_;
     }
@@ -112,6 +123,7 @@ class InplaceOpCpuTypeFunc : public DeprecatedCpuKernelFunc {
     static std::unordered_map<std::string, TypeComputeFunc> inplaceOpFuncMap = {
       {prim::kPrimInplaceAdd->name(), &InplaceOpCpuTypeFunc<T>::InplaceOp<Add>},
       {prim::kPrimInplaceSub->name(), &InplaceOpCpuTypeFunc<T>::InplaceOp<Sub>},
+      {prim::kPrimInplaceUpdate->name(), &InplaceOpCpuTypeFunc<T>::InplaceOp<Update>},
     };
     if (inplaceOpFuncMap.find(kernel_name_) == inplaceOpFuncMap.end()) {
       MS_LOG(EXCEPTION) << "For 'InplaceOp', only supports operators in " << Unorderedmap2Str(inplaceOpFuncMap)
@@ -158,6 +170,15 @@ static const mindspore::HashMap<std::string, OpFuncList> kernel_attr_list = {
    {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
     InplaceOpCpuFunc<float16>},
   }},
+  {ops::kNameInplaceUpdate,
+   {
+    {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+     InplaceOpCpuFunc<int32_t>},
+    {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+     InplaceOpCpuFunc<float>},
+    {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+     InplaceOpCpuFunc<float16>},
+   }},
 };
 }  // namespace

@@ -193,5 +214,6 @@ std::vector<KernelAttr> InPlaceOpCpuKernelMod::GetOpSupport() {

 MS_KERNEL_FACTORY_REG_WITH_NAME_PARAM(NativeCpuKernelMod, InplaceAdd, InPlaceOpCpuKernelMod);
 MS_KERNEL_FACTORY_REG_WITH_NAME_PARAM(NativeCpuKernelMod, InplaceSub, InPlaceOpCpuKernelMod);
+MS_KERNEL_FACTORY_REG_WITH_NAME_PARAM(NativeCpuKernelMod, InplaceUpdate, InPlaceOpCpuKernelMod);
 }  // namespace kernel
 }  // namespace mindspore
@@ -805,6 +805,7 @@ GVAR_DEF(PrimitivePtr, kPrimCumSum, std::make_shared<Primitive>("CumSum"));
 GVAR_DEF(PrimitivePtr, kPrimCumProd, std::make_shared<Primitive>("CumProd"));
 GVAR_DEF(PrimitivePtr, kPrimSubscalar, std::make_shared<Primitive>("Subscalar"));
 GVAR_DEF(PrimitivePtr, kPrimInplaceAdd, std::make_shared<Primitive>("InplaceAdd"));
+GVAR_DEF(PrimitivePtr, kPrimInplaceUpdate, std::make_shared<Primitive>("InplaceUpdate"));
 GVAR_DEF(PrimitivePtr, kPrimLpNorm, std::make_shared<Primitive>(kLpNorm));
 GVAR_DEF(PrimitivePtr, kPrimInplaceSub, std::make_shared<Primitive>("InplaceSub"));
 GVAR_DEF(PrimitivePtr, kPrimPow, std::make_shared<Primitive>("Pow"));
@@ -16,6 +16,7 @@

 #include "ops/inplace_add.h"
 #include "ops/inplace_sub.h"
+#include "ops/inplace_update.h"
 #include <string>
 #include <algorithm>
 #include <memory>
@@ -88,6 +89,7 @@ TypePtr InplaceOpInferType(const PrimitivePtr &prim, const std::vector<AbstractB
 }  // namespace
 void InplaceAdd::set_indices(std::vector<int64_t> indices) { AddAttr(kIndices, api::MakeValue(indices)); }
 void InplaceSub::set_indices(std::vector<int64_t> indices) { AddAttr(kIndices, api::MakeValue(indices)); }
+void InplaceUpdate::set_indices(std::vector<int64_t> indices) { AddAttr(kIndices, api::MakeValue(indices)); }

 std::vector<int64_t> InplaceAdd::get_indices() const {
   auto value_ptr = GetAttr(kIndices);
@@ -107,8 +109,18 @@ std::vector<int64_t> InplaceSub::get_indices() const {
   }
 }

+std::vector<int64_t> InplaceUpdate::get_indices() const {
+  auto value_ptr = GetAttr(kIndices);
+  if (value_ptr->isa<mindspore::api::ValueSequence>()) {
+    return GetValue<std::vector<int64_t>>(value_ptr);
+  } else {
+    return {GetValue<int64_t>(value_ptr)};
+  }
+}
+
 MIND_API_OPERATOR_IMPL(InplaceAdd, BaseOperator);
 MIND_API_OPERATOR_IMPL(InplaceSub, BaseOperator);
+MIND_API_OPERATOR_IMPL(InplaceUpdate, BaseOperator);
 AbstractBasePtr InplaceOpInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                const std::vector<AbstractBasePtr> &input_args) {
   auto dtype = InplaceOpInferType(primitive, input_args);
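get_indices() accepts either a scalar or a sequence attribute and normalizes both to a vector. A Python analog of that branch, with a hypothetical helper name:

def normalize_indices(value):
    # sequence attr -> as-is; scalar attr -> one-element list
    if isinstance(value, (list, tuple)):
        return [int(i) for i in value]
    return [int(value)]

assert normalize_indices((0, 1)) == [0, 1]
assert normalize_indices(3) == [3]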
@@ -117,5 +129,6 @@ AbstractBasePtr InplaceOpInfer(const abstract::AnalysisEnginePtr &, const Primit
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(InplaceAdd, prim::kPrimInplaceAdd, InplaceOpInfer, nullptr, true);
 REGISTER_PRIMITIVE_EVAL_IMPL(InplaceSub, prim::kPrimInplaceSub, InplaceOpInfer, nullptr, true);
+REGISTER_PRIMITIVE_EVAL_IMPL(InplaceUpdate, prim::kPrimInplaceUpdate, InplaceOpInfer, nullptr, true);
 }  // namespace ops
 }  // namespace mindspore
@@ -0,0 +1,53 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CORE_OPS_INPLACE_UPDATE_H_
+#define MINDSPORE_CORE_OPS_INPLACE_UPDATE_H_
+#include <algorithm>
+#include <map>
+#include <memory>
+#include <set>
+#include <string>
+#include <vector>
+#include "ops/base_operator.h"
+#include "mindapi/base/types.h"
+
+namespace mindspore {
+namespace ops {
+constexpr auto kNameInplaceUpdate = "InplaceUpdate";
+/// \brief InplaceUpdate operation. Refer to Python API @ref mindspore.ops.InplaceUpdate for more details.
+class MIND_API InplaceUpdate : public BaseOperator {
+ public:
+  MIND_API_BASE_MEMBER(InplaceUpdate);
+  /// \brief Constructor.
+  InplaceUpdate() : BaseOperator(kNameInplaceUpdate) { InitIOName({"x", "v"}, {"y"}); }
+  /// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.InplaceUpdate for the inputs.
+  void Init(std::vector<int64_t> indices) { set_indices(indices); }
+  /// \brief Set indices.
+  void set_indices(std::vector<int64_t> indices);
+  /// \brief Get indices.
+  ///
+  /// \return indices.
+  std::vector<int64_t> get_indices() const;
+};
+
+abstract::AbstractBasePtr InplaceUpdateInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                             const std::vector<abstract::AbstractBasePtr> &input_args);
+using PrimInplaceUpdatePtr = std::shared_ptr<InplaceUpdate>;
+}  // namespace ops
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CORE_OPS_INPLACE_UPDATE_H_
@@ -1803,6 +1803,27 @@ def masked_select(x, mask):
     return F.masked_select(x, mask)


+def inplace_update(x, v, indices):
+    """
+    Update specified rows of x with values in v according to indices.
+    """
+    return F.inplace_update(x, v, indices)
+
+
+def inplace_add(x, v, indices):
+    """
+    Add v into specified rows of x according to indices.
+    """
+    return F.inplace_add(x, v, indices)
+
+
+def inplace_sub(x, v, indices):
+    """
+    Subtract v from specified rows of x according to indices.
+    """
+    return F.inplace_sub(x, v, indices)
+
+
 def coo_to_csr(x):
     """convert coo to csr."""
     row_indices = x.indices[:, 0]
@@ -1979,6 +1979,98 @@ class Tensor(Tensor_):
             return tensor_operator_registry.get('cumsum')()(x, axis).astype(dtype, copy=False)
         return tensor_operator_registry.get('cumsum')()(x, axis)

+    def inplace_update(self, v, indices):
+        """
+        Update some rows of a tensor with values of v according to the specified indices.
+
+        Args:
+            v (Tensor): A tensor with the same type and same dimension size except the first dimension,
+                which must be the same as the size of indices.
+            indices (Union[int, tuple]): Indices into the left-most dimension, determining which rows are updated.
+
+        Returns:
+            Tensor, with updated values.
+
+        Supported Platforms:
+            ``Ascend`` ``CPU``
+
+        Examples:
+            >>> import numpy as np
+            >>> from mindspore import Tensor
+            >>> import mindspore
+            >>> x = Tensor(np.array([[1, 2], [3, 4], [5, 6]]), mindspore.float32)
+            >>> v = Tensor(np.array([[0.1, 0.2], [0.3, 0.4]]), mindspore.float32)
+            >>> indices = (0, 1)
+            >>> output = x.inplace_update(v, indices)
+            >>> print(output)
+            [[0.1 0.2]
+             [0.3 0.4]
+             [5.  6. ]]
+        """
+        self._init_check()
+        return tensor_operator_registry.get('inplace_update')(indices)(self, v)
+
+    def inplace_add(self, v, indices):
+        """
+        Add v into specified rows of a tensor according to indices.
+
+        Args:
+            v (Tensor): A tensor with the same type and same dimension size except the first dimension,
+                which must be the same as the size of indices.
+            indices (Union[int, tuple]): Indices into the left-most dimension, determining which rows v is added to.
+
+        Returns:
+            Tensor, with values after adding.
+
+        Supported Platforms:
+            ``Ascend`` ``CPU``
+
+        Examples:
+            >>> import numpy as np
+            >>> from mindspore import Tensor
+            >>> import mindspore
+            >>> x = Tensor(np.array([[1, 2], [3, 4], [5, 6]]), mindspore.float32)
+            >>> v = Tensor(np.array([[0.1, 0.2], [0.3, 0.4]]), mindspore.float32)
+            >>> indices = (0, 1)
+            >>> output = x.inplace_add(v, indices)
+            >>> print(output)
+            [[1.1 2.2]
+             [3.3 4.4]
+             [5.  6. ]]
+        """
+        self._init_check()
+        return tensor_operator_registry.get('inplace_add')(indices)(self, v)
+
+    def inplace_sub(self, v, indices):
+        """
+        Subtract v from specified rows of a tensor according to indices.
+
+        Args:
+            v (Tensor): A tensor with the same type and same dimension size except the first dimension,
+                which must be the same as the size of indices.
+            indices (Union[int, tuple]): Indices into the left-most dimension, determining which rows v is
+                subtracted from.
+
+        Returns:
+            Tensor, with values after subtracting.
+
+        Supported Platforms:
+            ``Ascend`` ``CPU``
+
+        Examples:
+            >>> import numpy as np
+            >>> from mindspore import Tensor
+            >>> import mindspore
+            >>> x = Tensor(np.array([[1, 2], [3, 4], [5, 6]]), mindspore.float32)
+            >>> v = Tensor(np.array([[0.1, 0.2], [0.3, 0.4]]), mindspore.float32)
+            >>> indices = (0, 1)
+            >>> output = x.inplace_sub(v, indices)
+            >>> print(output)
+            [[0.9 1.8]
+             [2.7 3.6]
+             [5.  6. ]]
+        """
+        self._init_check()
+        return tensor_operator_registry.get('inplace_sub')(indices)(self, v)
+
     def copy(self):
         """
         Return a copy of the tensor.
@@ -241,6 +241,32 @@ def get_broadcast_to_vmap_rule(prim, axis_size):
     return vmap_rule


+@vmap_rules_getters.register(P.InplaceAdd)
+@vmap_rules_getters.register(P.InplaceSub)
+@vmap_rules_getters.register(P.InplaceUpdate)
+def get_inplace_ops_vmap_rule(prim, axis_size):
+    """VmapRule for `InplaceAdd`, `InplaceSub` and `InplaceUpdate` operations."""
+
+    def vmap_rule(x_bdim, v_bdim):
+        is_all_none, result = vmap_general_preprocess(prim, x_bdim, v_bdim)
+        if is_all_none:
+            return result
+
+        x, x_dim = x_bdim
+        v, v_dim = v_bdim
+        if x_dim is None:
+            x = _broadcast_by_axis(x, -1, axis_size)
+        else:
+            x = mnp.moveaxis(x, x_dim, -1)
+        if v_dim is None:
+            v = _broadcast_by_axis(v, -1, axis_size)
+        else:
+            v = mnp.moveaxis(v, v_dim, -1)
+        out = prim(x, v)
+        return (out, out.ndim - 1)
+    return vmap_rule
+
+
 @constexpr
 def _get_reduce_batch_axis(axis, x_dim, x_ndim):
     """get batch_axis for reduce* operation."""
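With the rule above registered, the inplace primitives can sit under vmap. A minimal sketch in the spirit of the test added at the end of this PR (batch on axis 0 of x, with v shared across the batch); the helper name update_rows is illustrative:

import numpy as np
import mindspore as ms
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.ops.functional import vmap

def update_rows(x, v):
    # per-example: replace rows 0 and 1 of x with v
    return P.InplaceUpdate((0, 1))(x, v)

x = Tensor(np.random.random((4, 5, 3)).astype(np.float32))  # batch of 4
v = Tensor(np.random.random((2, 3)).astype(np.float32))     # one shared update
out = vmap(update_rows, in_axes=(0, None), out_axes=0)(x, v)
print(out.shape)  # (4, 5, 3)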
@@ -143,6 +143,9 @@ from .math_func import (
     logaddexp,
     logaddexp2,
     mv,
+    inplace_add,
+    inplace_sub,
+    inplace_update,
     inv,
     invert,
     minimum,
@@ -730,6 +730,119 @@ def floor(x):
     return floor_(x)


+def inplace_update(x, v, indices):
+    """
+    Updates specified rows of `x` with values in `v`.
+
+    Args:
+        x (Tensor): A tensor to be updated in place. It can be one of the following data types:
+            float32, float16 and int32.
+        v (Tensor): A tensor with the same type as `x` and the same dimension size as `x` except
+            the first dimension, which must be the same as the size of `indices`.
+        indices (Union[int, tuple]): Indices into the left-most dimension of `x`, determining which rows of `x`
+            to update with `v`. It is an int or tuple, whose values are in [0, the first dimension size of `x`).
+
+    Returns:
+        Tensor, with the same type and shape as the input `x`.
+
+    Raises:
+        TypeError: If `indices` is neither int nor tuple.
+        TypeError: If `indices` is a tuple and its element is not an int.
+
+    Supported Platforms:
+        ``Ascend`` ``CPU``
+
+    Examples:
+        >>> indices = (0, 1)
+        >>> x = Tensor(np.array([[1, 2], [3, 4], [5, 6]]), mindspore.float32)
+        >>> v = Tensor(np.array([[0.5, 1.0], [1.0, 1.5]]), mindspore.float32)
+        >>> output = ops.inplace_update(x, v, indices)
+        >>> print(output)
+        [[0.5 1. ]
+         [1.  1.5]
+         [5.  6. ]]
+    """
+    inplace_update_inner = P.InplaceUpdate(indices)
+    return inplace_update_inner(x, v)
+
+
+def inplace_add(x, v, indices):
+    """
+    Adds `v` into specified rows of `x`. Computes `y` = `x`; y[i,] += `v`.
+
+    Args:
+        x (Tensor): The first input is a tensor whose data type is float16, float32 or int32, with shape
+            :math:`(N, *)` where :math:`*` means any number of additional dimensions. Its rank should be
+            less than 8.
+        v (Tensor): The second input is a tensor that has the same dimension sizes as `x` except
+            the first dimension, which must be the same as the size of `indices`. It has the same data type as `x`.
+        indices (Union[int, tuple]): Indices into the left-most dimension of `x`, determining which rows of `x`
+            to add `v` to. It is an int or tuple, whose values are in [0, the first dimension size of `x`).
+
+    Returns:
+        Tensor, has the same shape and dtype as `x`.
+
+    Raises:
+        TypeError: If `indices` is neither int nor tuple.
+        TypeError: If `indices` is a tuple whose elements are not all int.
+        ValueError: If the length of the shape of `x` is not equal to the length of the shape of `v`.
+
+    Supported Platforms:
+        ``Ascend`` ``CPU``
+
+    Examples:
+        >>> indices = (0, 1)
+        >>> x = Tensor(np.array([[1, 2], [3, 4], [5, 6]]), mindspore.float32)
+        >>> input_v = Tensor(np.array([[0.5, 1.0], [1.0, 1.5]]), mindspore.float32)
+        >>> output = ops.inplace_add(x, input_v, indices)
+        >>> print(output)
+        [[1.5 3. ]
+         [4.  5.5]
+         [5.  6. ]]
+    """
+    inplace_add_inner = P.InplaceAdd(indices)
+    return inplace_add_inner(x, v)
+
+
+def inplace_sub(x, v, indices):
+    """
+    Subtracts `v` from specified rows of `x`. Computes `y` = `x`; y[i,] -= `v`.
+
+    Args:
+        x (Tensor): The first input is a tensor whose data type is float16, float32 or int32, with shape
+            :math:`(N, *)` where :math:`*` means any number of additional dimensions. Its rank should be
+            less than 8.
+        v (Tensor): The second input is a tensor that has the same dimension sizes as `x` except
+            the first dimension, which must be the same as the size of `indices`. It has the same data type as `x`.
+        indices (Union[int, tuple]): Indices into the left-most dimension of `x`, determining which rows of `x`
+            to subtract `v` from. It is an int or tuple, whose values are in [0, the first dimension size of `x`).
+
+    Returns:
+        Tensor, has the same shape and dtype as `x`.
+
+    Raises:
+        TypeError: If `indices` is neither int nor tuple.
+        TypeError: If `indices` is a tuple whose elements are not all int.
+        ValueError: If the length of the shape of `x` is not equal to the length of the shape of `v`.
+
+    Supported Platforms:
+        ``Ascend`` ``CPU``
+
+    Examples:
+        >>> indices = (0, 1)
+        >>> x = Tensor(np.array([[1, 2], [3, 4], [5, 6]]), mindspore.float32)
+        >>> input_v = Tensor(np.array([[0.5, 1.0], [1.0, 1.5]]), mindspore.float32)
+        >>> output = ops.inplace_sub(x, input_v, indices)
+        >>> print(output)
+        [[0.5 1. ]
+         [2.  2.5]
+         [5.  6. ]]
+    """
+    inplace_sub_inner = P.InplaceSub(indices)
+    return inplace_sub_inner(x, v)
+
+
 def logical_not(x):
     """
     Computes the "logical NOT" of a tensor element-wise.
@@ -2894,6 +3007,9 @@ __all__ = [
     'equal',
     'not_equal',
     'ne',
+    'inplace_update',
+    'inplace_add',
+    'inplace_sub',
     'isfinite',
     'isnan',
     'isreal',
@@ -953,6 +953,9 @@ tensor_operator_registry.register('svd', linalg_ops.Svd)
 tensor_operator_registry.register('diag', P.Diag)
 tensor_operator_registry.register('unique_consecutive', UniqueConsecutive)
 tensor_operator_registry.register('pdist', NN.Pdist)
+tensor_operator_registry.register('inplace_update', P.InplaceUpdate)
+tensor_operator_registry.register('inplace_add', P.InplaceAdd)
+tensor_operator_registry.register('inplace_sub', P.InplaceSub)
 # ms cannot support Tensor(True) compare
 tensor_operator_registry.register('__eq__', equal)
 tensor_operator_registry.register('__ne__', not_equal)
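Together with the registrations above, the PR exposes three equivalent entry points for each operator. A sketch checking that they agree (PyNative mode on CPU or Ascend assumed):

import numpy as np
import mindspore as ms
from mindspore import Tensor, ops
from mindspore.ops import operations as P

x = Tensor(np.array([[1, 2], [3, 4], [5, 6]]), ms.float32)
v = Tensor(np.array([[0.1, 0.2], [0.3, 0.4]]), ms.float32)
indices = (0, 1)

a = P.InplaceAdd(indices)(x, v)      # Primitive
b = ops.inplace_add(x, v, indices)   # functional wrapper
c = x.inplace_add(v, indices)        # Tensor method via the registry
assert np.allclose(a.asnumpy(), b.asnumpy())
assert np.allclose(b.asnumpy(), c.asnumpy())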
@@ -43,7 +43,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unsta
                         Squeeze, StridedSlice, Tile, EditDistance, Sort, Transpose, TupleToArray,
                         UnsortedSegmentMin, UnsortedSegmentMax,
                         UnsortedSegmentProd, UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch,
-                        BatchToSpace, SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence,
+                        BatchToSpace, SpaceToBatchND, BatchToSpaceND, BroadcastTo, ReverseSequence,
                         EmbeddingLookup, Unique, GatherD, Identity, Range, MaskedFill, MaskedSelect, SearchSorted,
                         TensorScatterUpdate, TensorScatterMax, TensorScatterMin, TensorScatterAdd, TensorScatterSub,
                         TensorScatterMul, TensorScatterDiv,
@@ -63,7 +63,7 @@ from .inner_ops import (ScalarCast, Randperm, NoRepeatNGram, LambApplyOptimizerA
                         FusedAdaFactorWithGlobalNorm)
 from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, AssignSub, Atan2, BatchMatMul,
                         BitwiseAnd, BitwiseOr, Ger,
-                        BitwiseXor, Inv, Invert, ApproximateEqual, InplaceAdd, InplaceSub,
+                        BitwiseXor, Inv, Invert, ApproximateEqual, InplaceAdd, InplaceSub, InplaceUpdate,
                         ReduceMax, ReduceMin, ReduceMean, ReduceSum, ReduceAll, ReduceProd, CumProd, Cdist, ReduceAny,
                         Cos, Cross, Div, DivNoNan, Equal, EqualCount, Exp, Expm1, Erf, Erfc, Floor, FloorDiv, FloorMod,
                         Ceil, Acosh, Greater, GreaterEqual, Lerp, Less, LessEqual, Log, Log1p, LogicalAnd, Mod,
@@ -5885,73 +5885,6 @@ class Meshgrid(PrimitiveWithInfer):
         return x_type


-class InplaceUpdate(PrimitiveWithInfer):
-    r"""
-    Updates specified rows with values in `v`.
-
-    Args:
-        indices (Union[int, tuple]): Indices into the left-most dimension of `x`, and determines which rows of x
-            to update with v. It is an int or tuple, whose value is in [0, the first dimension size of x).
-
-    Inputs:
-        - **x** (Tensor) - A tensor which to be inplace updated. It can be one of the following data types:
-          float32, float16 and int32.
-        - **v** (Tensor) - A tensor with the same type as `x` and the same dimension size as `x` except
-          the first dimension, which must be the same as the size of `indices`.
-
-    Outputs:
-        Tensor, with the same type and shape as the input `x`.
-
-    Raises:
-        TypeError: If `indices` is neither int nor tuple.
-        TypeError: If `indices` is a tuple and its element is not an int.
-
-    Supported Platforms:
-        ``Ascend``
-
-    Examples:
-        >>> indices = (0, 1)
-        >>> x = Tensor(np.array([[1, 2], [3, 4], [5, 6]]), mindspore.float32)
-        >>> v = Tensor(np.array([[0.5, 1.0], [1.0, 1.5]]), mindspore.float32)
-        >>> inplace_update = ops.InplaceUpdate(indices)
-        >>> output = inplace_update(x, v)
-        >>> print(output)
-        [[0.5 1. ]
-         [1.  1.5]
-         [5.  6. ]]
-    """
-
-    @prim_attr_register
-    def __init__(self, indices):
-        """Initialize InplaceUpdate"""
-        self.init_prim_io_names(inputs=['x', 'v'], outputs=['y'])
-        self.indices = indices
-        validator.check_value_type("indices", indices, [int, tuple], self.name)
-        if isinstance(indices, int):
-            self.indices = (indices,)
-        for item in self.indices:
-            validator.check_value_type("item of indices", item, [int], self.name)
-
-    def infer_dtype(self, x_dtype, v_dtype):
-        args = {'x': x_dtype, 'v': v_dtype}
-        valid_type = [mstype.int32, mstype.float16, mstype.float32]
-        validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name)
-        return x_dtype
-
-    def infer_shape(self, x_shape, v_shape):
-        validator.check("x", len(x_shape), "v", len(v_shape), Rel.EQ, self.name)
-        validator.check("size of indices", len(self.indices), "v's first dimension", v_shape[0],
-                        Rel.EQ, self.name)
-        for i in self.indices:
-            if i < 0 or i >= x_shape[0]:
-                raise ValueError(f"For '{self.name}', the value of indices must be in [0, {x_shape[0]}), "
-                                 f"but got {i}.")
-        x_rank = len(x_shape)
-        for idx in range(x_rank)[1:]:
-            validator.check('v dim %d' % idx, v_shape[idx], "x dim %d" % idx, x_shape[idx], Rel.EQ, self.name)
-        return x_shape
-
-
 class ReverseSequence(PrimitiveWithInfer):
     """
     Reverses variable length slices.
@@ -1606,6 +1606,73 @@ class Neg(Primitive):
         self.init_prim_io_names(inputs=['x'], outputs=['y'])


+class InplaceUpdate(PrimitiveWithInfer):
+    r"""
+    Updates specified rows with values in `v`.
+
+    Args:
+        indices (Union[int, tuple]): Indices into the left-most dimension of `x`, and determines which rows of `x`
+            to update with `v`. It is an int or tuple, whose values are in [0, the first dimension size of `x`).
+
+    Inputs:
+        - **x** (Tensor) - A tensor to be updated in place. It can be one of the following data types:
+          float32, float16 and int32.
+        - **v** (Tensor) - A tensor with the same type as `x` and the same dimension size as `x` except
+          the first dimension, which must be the same as the size of `indices`.
+
+    Outputs:
+        Tensor, with the same type and shape as the input `x`.
+
+    Raises:
+        TypeError: If `indices` is neither int nor tuple.
+        TypeError: If `indices` is a tuple and its element is not an int.
+
+    Supported Platforms:
+        ``Ascend`` ``CPU``
+
+    Examples:
+        >>> indices = (0, 1)
+        >>> x = Tensor(np.array([[1, 2], [3, 4], [5, 6]]), mindspore.float32)
+        >>> v = Tensor(np.array([[0.5, 1.0], [1.0, 1.5]]), mindspore.float32)
+        >>> inplace_update = ops.InplaceUpdate(indices)
+        >>> output = inplace_update(x, v)
+        >>> print(output)
+        [[0.5 1. ]
+         [1.  1.5]
+         [5.  6. ]]
+    """
+
+    @prim_attr_register
+    def __init__(self, indices):
+        """Initialize InplaceUpdate"""
+        self.init_prim_io_names(inputs=['x', 'v'], outputs=['y'])
+        self.indices = indices
+        validator.check_value_type("indices", indices, [int, tuple], self.name)
+        if isinstance(indices, int):
+            self.indices = (indices,)
+        for item in self.indices:
+            validator.check_value_type("item of indices", item, [int], self.name)
+
+    def infer_dtype(self, x_dtype, v_dtype):
+        args = {'x': x_dtype, 'v': v_dtype}
+        valid_type = [mstype.int32, mstype.float16, mstype.float32]
+        validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name)
+        return x_dtype
+
+    def infer_shape(self, x_shape, v_shape):
+        validator.check("x", len(x_shape), "v", len(v_shape), Rel.EQ, self.name)
+        validator.check("size of indices", len(self.indices), "v's first dimension", v_shape[0],
+                        Rel.EQ, self.name)
+        for i in self.indices:
+            if i < 0 or i >= x_shape[0]:
+                raise ValueError(f"For '{self.name}', the value of indices must be in [0, {x_shape[0]}), "
+                                 f"but got {i}.")
+        x_rank = len(x_shape)
+        for idx in range(x_rank)[1:]:
+            validator.check('v dim %d' % idx, v_shape[idx], "x dim %d" % idx, x_shape[idx], Rel.EQ, self.name)
+        return x_shape
+
+
 class InplaceAdd(PrimitiveWithInfer):
     """
     Adds `v` into specified rows of `x`. Computes `y` = `x`; y[i,] += `v`.
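The validator paths above reject out-of-range indices and mismatched ranks. A small sketch of the failure mode (PyNative assumed, since infer_shape runs when the primitive is called):

import numpy as np
import mindspore
from mindspore import Tensor, ops

x = Tensor(np.array([[1, 2], [3, 4], [5, 6]]), mindspore.float32)
v = Tensor(np.array([[0.5, 1.0]]), mindspore.float32)
try:
    ops.InplaceUpdate((5,))(x, v)  # 5 is outside [0, 3)
except ValueError as e:
    print(e)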
@@ -17,6 +17,17 @@ import pytest

 from mindspore import context, Tensor
 from mindspore.ops import operations as P
+from mindspore.ops.functional import vmap
+import mindspore.nn as nn
+
+
+class InplaceOps(nn.Cell):
+    def __init__(self, indices):
+        super(InplaceOps, self).__init__()
+        self.indices = indices
+
+    def construct(self, x, v):
+        return x.inplace_update(v, self.indices)


 def inplace_op_np(op, x: np.ndarray, v: np.ndarray, indices):
@@ -27,6 +38,8 @@ def inplace_op_np(op, x: np.ndarray, v: np.ndarray, indices):
         result[indices, :] += v
     elif op == 'sub':
         result[indices, :] -= v
+    elif op == 'update':
+        result[indices, :] = v
     return result
@@ -112,3 +125,50 @@ def test_inplace_sub_1d(shape, indice, dtype):
     result = P.InplaceSub(indice)(Tensor(x), Tensor(v))
     expected = inplace_op_np('sub', x, v, indice)
     assert np.allclose(result.asnumpy(), expected)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('shape, indice_len', [((10, 4, 3, 2), 4), ((5, 2, 4, 6), 3)])
+@pytest.mark.parametrize('dtype', [np.float32, np.float16, np.int32])
+def test_inplace_update(shape, indice_len, dtype):
+    """
+    Feature: test InplaceUpdate
+    Description: test InplaceUpdate
+    Expectation: result is the same as expected
+    """
+    context.set_context(device_target='CPU')
+    x = np.random.random(shape).astype(dtype)
+    v = np.random.random((indice_len,) + shape[1:]).astype(dtype)
+    indices = np.random.choice(list(range(shape[0])), indice_len, replace=False)
+    indices = tuple(int(i) for i in indices)
+
+    result = P.InplaceUpdate(indices)(Tensor(x), Tensor(v))
+    expected = inplace_op_np('update', x, v, indices)
+    assert np.allclose(result.asnumpy(), expected)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('shape, indice_len', [((10, 4, 3, 2), 2)])
+@pytest.mark.parametrize('dtype', [np.float32, np.float16, np.int32])
+def test_vmap_inplace_ops(shape, indice_len, dtype):
+    """
+    Feature: test vmap inplace operators
+    Description: test vmap inplace operators
+    Expectation: result is the same as expected
+    """
+    context.set_context(device_target='CPU')
+    x = np.random.random(shape).astype(dtype)
+    v = np.random.random((indice_len,) + shape[2:]).astype(dtype)
+    indices = np.random.choice(list(range(shape[1])), indice_len, replace=False)
+    indices = tuple(int(i) for i in indices)
+
+    inplace_op = InplaceOps(indices)
+    result = vmap(inplace_op, in_axes=(0, None), out_axes=0)(Tensor(x), Tensor(v))
+    expected = np.zeros(shape=shape)
+    for i in range(shape[0]):
+        expected[i] = inplace_op_np('update', x[i], v, indices)
+    assert np.allclose(result.asnumpy(), expected)