forked from mindspore-Ecosystem/mindspore
InPlaceUpdateV2 CPU.
parent 4cfe8bb5c3
commit 897f264ee3
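
This commit adds a CPU kernel for the `InplaceUpdateV2` operator, which overwrites the rows of `x` selected by `indices` with the corresponding rows of `v`. Below is a minimal usage sketch of the op this kernel backs; it assumes the primitive is exposed as `mindspore.ops.InplaceUpdateV2` and that `set_context(device_target="CPU")` routes it to the new CPU backend. The import path and example values are illustrative, not taken from the commit.

import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

ms.set_context(device_target="CPU")  # exercise the CPU backend added here

x = Tensor(np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]), ms.float32)
v = Tensor(np.array([[0.5, 1.0], [1.0, 1.5]]), ms.float32)
indices = (0, 1)  # rows of x to overwrite with the rows of v

inplace_update_v2 = ops.InplaceUpdateV2()
output = inplace_update_v2(x, indices, v)
print(output)
# Expected: rows 0 and 1 of x replaced by v, row 2 unchanged:
# [[0.5 1. ]
#  [1.  1.5]
#  [5.  6. ]]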
@@ -0,0 +1,214 @@
/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/cpu/kernel/in_place_op_v2_cpu_kernel.h"
#include <map>
#include <memory>
#include <string>
#include <algorithm>
#include <functional>
#include <numeric>
#include <unordered_map>
#include <unordered_set>
#include "mindspore/core/ops/inplace_update_v2.h"

namespace mindspore {
namespace kernel {
namespace {
// Update simply returns the new value: InplaceUpdateV2 overwrites the destination rather than accumulating into it.
struct Update {
  template <typename T>
  inline T operator()(const T &rhs) const {
    return rhs;
  }
};
// NoCheckUpdate applies Op to v[v_idx] and stores the result at x[x_idx], without any bounds checking.
template <typename Op>
struct NoCheckUpdate {
  template <typename T>
  static inline void compute(T *x, const int64_t x_idx, const T *v, const int64_t v_idx) {
    x[x_idx] = Op()(v[v_idx]);
  }
};
template <typename T>
class InplaceOpV2CpuTypeFunc : public CpuKernelFunc {
 public:
  InplaceOpV2CpuTypeFunc() = default;
  ~InplaceOpV2CpuTypeFunc() override = default;
  void InitFunc(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &,
                const std::vector<KernelTensorPtr> &) override {
    MS_EXCEPTION_IF_NULL(base_operator);
    MS_EXCEPTION_IF_NULL(base_operator->GetPrim());
    kernel_name_ = base_operator->GetPrim()->name();

    static std::unordered_map<std::string, TypeComputeFuncV2> inplaceOpV2FuncMap = {
      {prim::kPrimInplaceUpdateV2->name(), &InplaceOpV2CpuTypeFunc<T>::InplaceOpV2<NoCheckUpdate<Update>>},
    };
    if (inplaceOpV2FuncMap.find(kernel_name_) == inplaceOpV2FuncMap.end()) {
      MS_LOG(EXCEPTION) << "For 'InplaceOpV2', only supports operators in "
                        << Map2Str<std::unordered_map, TypeComputeFuncV2>(inplaceOpV2FuncMap) << ", but got "
                        << kernel_name_ << ".";
    }
    compute_func_ = inplaceOpV2FuncMap.at(kernel_name_);
  }

  int Resize(const BaseOperatorPtr &, const std::vector<KernelTensorPtr> &inputs, const std::vector<KernelTensorPtr> &,
             const std::map<uint32_t, tensor::TensorPtr> &) override {
    if (inputs.size() != kInplaceOpV2InputNum) {
      MS_LOG(ERROR) << "For 'InplaceOpV2', the size of inputs must be 3, but got " << inputs.size() << ".";
      return KRET_RESIZE_FAILED;
    }
    MS_EXCEPTION_IF_NULL(inputs[kIndex1]);
    MS_EXCEPTION_IF_NULL(inputs[kIndex2]);
    auto indice_shape = inputs[kIndex1]->GetShapeVector();
    auto v_shape = inputs[kIndex2]->GetShapeVector();
    if (indice_shape.empty() || v_shape.empty()) {
      MS_LOG(ERROR) << "For 'InplaceOpV2', the shape size of value:" << v_shape.size()
                    << " and indices:" << indice_shape.size() << " must not be 0.";
      return KRET_RESIZE_FAILED;
    }
    if (v_shape[0] != indice_shape[0]) {
      MS_LOG(ERROR) << "For 'InplaceOpV2', the size of indices must equal to input_v's shape[0].";
      return KRET_RESIZE_FAILED;
    }

    // band_size_ is the element count of one row of v (and of x); v_size_ is the total element count of v.
    band_size_ = std::accumulate(v_shape.begin() + 1, v_shape.end(), int64_t(1), std::multiplies{});
    v_size_ = band_size_ * v_shape[0];

    return KRET_OK;
  }

  // Copies the rows of v into the rows of x selected by indices. The flat range [0, v_size_) is split across
  // threads by ParallelLaunchAutoSearch; each task walks its slice band by band (one band = one row of v).
  template <typename Op>
  void InplaceOpV2(T *x, const std::vector<int64_t> &indices, const T *v) {
    const int64_t band_size = band_size_;
    auto task = [band_size, indices, x, v](size_t start, size_t end) {
      int64_t start_long = SizeToLong(start);
      const int64_t end_long = SizeToLong(end);
      while (start_long < end_long) {
        const int64_t v_row = start_long / band_size;
        const int64_t x_row = (indices.data())[v_row];

        int64_t offset = start_long % band_size;
        int64_t up_bound = (((v_row + 1) * band_size) > end_long) ? end_long % band_size : band_size;

        int64_t x_offset = x_row * band_size;
        int64_t v_offset = v_row * band_size;
        for (int64_t j = offset; j < up_bound; ++j) {
          Op::compute(x, x_offset + j, v, v_offset + j);
        }
        start_long = v_row * band_size + up_bound;
      }
    };
    ParallelLaunchAutoSearch(task, LongToSize(v_size_), this, &parallel_search_info_);
  }

  bool RunFunc(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
               const std::vector<AddressPtr> &outputs) override {
    auto *x = reinterpret_cast<T *>(inputs[0]->addr);
    const auto *v = reinterpret_cast<T *>(inputs[kIndex2]->addr);
    auto *output = reinterpret_cast<T *>(outputs[0]->addr);
    // x is first copied into the output buffer, then the selected rows of the output are overwritten with v.
    if (memcpy_s(output, outputs[0]->size, x, inputs[0]->size) != EOK) {
      MS_LOG(ERROR) << "Function memcpy_s failed in 'InplaceOpV2'.";
      return false;
    }

    std::vector<int64_t> indices;
    const auto *indice_ptr = reinterpret_cast<int *>(inputs[kIndex1]->addr);
    MS_EXCEPTION_IF_NULL(indice_ptr);
    for (size_t i = 0; i < inputs[kIndex1]->size / sizeof(int); ++i) {
      indices.emplace_back(IntToLong(indice_ptr[i]));
    }

    compute_func_(this, output, indices, v);
    return true;
  }

 private:
  std::string kernel_name_;
  int64_t band_size_{1};
  int64_t v_size_{1};

  using TypeComputeFuncV2 =
    std::function<void(InplaceOpV2CpuTypeFunc *, T *x, const std::vector<int64_t> &indices, const T *v)>;
  TypeComputeFuncV2 compute_func_{nullptr};
};

template <typename T>
std::shared_ptr<CpuKernelFunc> InplaceOpV2CpuFunc() {
  return std::make_shared<InplaceOpV2CpuTypeFunc<T>>();
}
using InplaceOpCpuFuncCreator = std::function<std::shared_ptr<CpuKernelFunc>()>;
using OpFuncList = std::vector<std::pair<KernelAttr, InplaceOpCpuFuncCreator>>;

// Registers one (x, indices, v) -> output dtype combination together with the typed functor creator.
#define DTYPE_REGISTER(INPUT_X, INPUT_INDICES, INPUT_V, OUTPUT, T)                                              \
  {                                                                                                             \
    KernelAttr().AddInputAttr(INPUT_X).AddInputAttr(INPUT_INDICES).AddInputAttr(INPUT_V).AddOutputAttr(OUTPUT), \
      InplaceOpV2CpuFunc<T>                                                                                     \
  }

static const mindspore::HashMap<std::string, OpFuncList> kernel_attr_list = {
  {ops::kNameInplaceUpdateV2,
   {
     DTYPE_REGISTER(kNumberTypeInt32, kNumberTypeInt32, kNumberTypeInt32, kNumberTypeInt32, int32_t),
     DTYPE_REGISTER(kNumberTypeFloat32, kNumberTypeInt32, kNumberTypeFloat32, kNumberTypeFloat32, float),
     DTYPE_REGISTER(kNumberTypeFloat16, kNumberTypeInt32, kNumberTypeFloat16, kNumberTypeFloat16, float16),
   }},
};
}  // namespace

bool InPlaceOpV2CpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                   const std::vector<KernelTensorPtr> &outputs) {
  kernel_name_ = base_operator->GetPrim()->name();
  if (kernel_name_ != kernel_type_) {
    MS_LOG(EXCEPTION) << "Need to be " << kernel_type_ << " but got kernel name as " << kernel_name_;
  }

  auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
  auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
  if (!is_match) {
    MS_LOG(EXCEPTION) << "InplaceOpV2 does not support this kernel data type: " << kernel_attr;
  }

  func_obj_ = kernel_attr_list.at(kernel_name_)[index].second();

  func_obj_->InitFunc(base_operator, inputs, outputs);

  return true;
}

int InPlaceOpV2CpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                    const std::vector<KernelTensorPtr> &outputs,
                                    const std::map<uint32_t, tensor::TensorPtr> &) {
  if (auto ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
    return ret;
  }
  return func_obj_->Resize(base_operator, inputs, outputs);
}

std::vector<KernelAttr> InPlaceOpV2CpuKernelMod::GetOpSupport() {
  auto iter = kernel_attr_list.find(kernel_type_);
  if (iter == kernel_attr_list.end()) {
    MS_LOG(EXCEPTION) << "InplaceOpV2 cpu does not support " << kernel_type_;
  }

  std::vector<KernelAttr> support_list;
  (void)std::transform(iter->second.begin(), iter->second.end(), std::back_inserter(support_list),
                       [](const std::pair<KernelAttr, InplaceOpCpuFuncCreator> &pair) { return pair.first; });

  return support_list;
}

MS_KERNEL_FACTORY_REG_WITH_NAME_PARAM(NativeCpuKernelMod, InplaceUpdateV2, InPlaceOpV2CpuKernelMod);
}  // namespace kernel
}  // namespace mindspore
@@ -0,0 +1,60 @@
/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_IN_PLACE_OP_V2_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_IN_PLACE_OP_V2_CPU_KERNEL_H_

#include <map>
#include <memory>
#include <string>
#include <vector>
#include <utility>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"

namespace mindspore {
namespace kernel {
constexpr size_t kInplaceOpV2InputNum = 3;
constexpr size_t kInplaceOpV2OutputNum = 1;
class InPlaceOpV2CpuKernelMod : public NativeCpuKernelMod {
 public:
  InPlaceOpV2CpuKernelMod() = default;
  explicit InPlaceOpV2CpuKernelMod(const std::string &kernel_type) : kernel_type_(kernel_type) {}
  ~InPlaceOpV2CpuKernelMod() override = default;

  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;

  int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
             const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override {
    CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInplaceOpV2InputNum, kernel_name_);
    CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kInplaceOpV2OutputNum, kernel_name_);
    return func_obj_->RunFunc(inputs, workspace, outputs);
  }

  std::vector<KernelAttr> GetOpSupport() override;

 private:
  std::shared_ptr<CpuKernelFunc> func_obj_;
  std::string kernel_type_{"Unknown"};
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_IN_PLACE_OP_V2_CPU_KERNEL_H_
@@ -1838,7 +1838,7 @@ class InplaceUpdateV2(Primitive):
        TypeError: If `indices` is a tuple and its element is not an int.

    Supported Platforms:
-        ``Ascend`` ``GPU``
+        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> indices = (0, 1)