forked from mindspore-Ecosystem/mindspore
!3020 Add embedding look up kernels
Merge pull request !3020 from ZPaC/add-ps-embedding-look-up-kernel
This commit is contained in:
commit
da9452ee5e
|
@ -26,7 +26,10 @@ if (ENABLE_CPU)
|
|||
"cpu/*.cc"
|
||||
)
|
||||
|
||||
list(REMOVE_ITEM CPU_SRC_LIST "cpu/ps/push_kernel.cc" "cpu/ps/pull_kernel.cc")
|
||||
list(REMOVE_ITEM CPU_SRC_LIST "cpu/ps/push_kernel.cc"
|
||||
"cpu/ps/pull_kernel.cc"
|
||||
"cpu/ps/embedding_look_up_ps_kernel.cc"
|
||||
"cpu/ps/embedding_look_up_proxy_kernel.cc")
|
||||
|
||||
if (NOT ENABLE_MPI)
|
||||
list(REMOVE_ITEM CPU_SRC_LIST "cpu/allgather_cpu_kernel.cc")
|
||||
|
|
|
@ -0,0 +1,75 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "kernel/cpu/ps/embedding_look_up_proxy_kernel.h"
|
||||
#include <vector>
|
||||
#include "parallel/ps/worker.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
namespace ps {
|
||||
void EmbeddingLookUpProxyKernel::InitKernel(const CNodePtr &kernel_node) {
|
||||
EmbeddingLookUpCPUKernel::InitKernel(kernel_node);
|
||||
|
||||
for (auto dim : input_shape_) {
|
||||
input_dims_ *= dim;
|
||||
}
|
||||
|
||||
if (mindspore::parallel::ps::Util::IsRoleOfWorker()) {
|
||||
key_ = AnfAlgo::GetNodeAttr<size_t>(kernel_node, kAttrPsKey);
|
||||
}
|
||||
std::vector<size_t> keys{key_, key_, key_};
|
||||
std::vector<size_t> values;
|
||||
values.insert(values.end(), input_shape_.begin(), input_shape_.end());
|
||||
values.insert(values.end(), indices_shape_.begin(), indices_shape_.end());
|
||||
values.insert(values.end(), output_shape_.begin(), output_shape_.end());
|
||||
std::vector<int> lens{SizeToInt(input_shape_.size()), SizeToInt(indices_shape_.size()),
|
||||
SizeToInt(output_shape_.size())};
|
||||
const char *env_role = getenv(mindspore::parallel::ps::kEnvRole);
|
||||
if (env_role != nullptr && strcmp(env_role, mindspore::parallel::ps::kEnvRoleOfWorker) == 0) {
|
||||
parallel::ps::Worker<float>::GetInstance().AddEmbeddingTable(key_, input_shape_[axis_]);
|
||||
parallel::ps::Worker<float>::GetInstance().InitPSEmbeddingTable(keys, values, lens);
|
||||
}
|
||||
}
|
||||
|
||||
bool EmbeddingLookUpProxyKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
                                        const std::vector<kernel::AddressPtr> & /*workspace*/,
                                        const std::vector<kernel::AddressPtr> &outputs) {
  // inputs[0] is the embedding table (served remotely, unused here);
  // inputs[1] holds the int32 lookup indices; outputs[0] receives the
  // float32 gathered rows.
  auto indices_addr = reinterpret_cast<int *>(inputs[1]->addr);
  auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
  size_t input_size = inputs[1]->size;
  size_t output_size = outputs[0]->size;

  // Number of indices. The buffer holds int32 values, so divide by
  // sizeof(int) (the original divided by sizeof(float), which is only
  // coincidentally the same size).
  size_t size = input_size / sizeof(int);
  ::ps::SArray<float> lookup_ids(size, 0);
  // Brace-initializing an int element from a size_t narrows; convert
  // explicitly, matching the SizeToInt usage elsewhere in this file.
  ::ps::SArray<int> lengths{SizeToInt(size)};
  ::ps::SArray<float> lookup_result;

  // NOTE(review): the PS transport carries float arrays, so the raw int32
  // index bits are memcpy'd into a float buffer and reinterpreted on the
  // server side — confirm against the worker/server protocol.
  auto ret = memcpy_s(lookup_ids.data(), input_size, indices_addr, input_size);
  if (ret != EOK) {
    MS_LOG(EXCEPTION) << "Lookup id memcpy failed.";
  }
  parallel::ps::Worker<float>::GetInstance().DoPSEmbeddingLookup({key_}, lookup_ids, lengths, lookup_result,
                                                                 parallel::ps::kEmbeddingLookupCmd);

  // Copy the gathered rows back into the kernel output.
  auto ret2 = memcpy_s(output_addr, output_size, lookup_result.data(), output_size);
  if (ret2 != EOK) {
    MS_LOG(EXCEPTION) << "Lookup result memcpy failed.";
  }
  return true;
}
|
||||
} // namespace ps
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,49 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_PROXY_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_PROXY_KERNEL_H_
|
||||
|
||||
#include "kernel/cpu/embedding_look_up_cpu_kernel.h"
|
||||
#include <vector>
|
||||
#include "kernel/cpu/cpu_kernel_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
namespace ps {
|
||||
// Worker-side proxy for EmbeddingLookup: instead of gathering rows locally,
// it forwards the lookup indices to the parameter server and copies the
// returned rows into the kernel output.
class EmbeddingLookUpProxyKernel : public EmbeddingLookUpCPUKernel {
 public:
  EmbeddingLookUpProxyKernel() = default;
  ~EmbeddingLookUpProxyKernel() override = default;

  // Initializes shapes via the base CPU kernel and, in the worker role,
  // registers/initializes the remote embedding table under key_.
  void InitKernel(const CNodePtr &kernel_node) override;

  // Sends inputs[1] (int32 indices) to the parameter server and writes the
  // looked-up float32 rows into outputs[0]. Workspace is unused.
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override;

 private:
  size_t key_{0};         // PS key identifying this embedding table.
  size_t input_dims_{1};  // Flat element count of the embedding-table input.
};
|
||||
|
||||
// Register the proxy kernel for EmbeddingLookupProxy nodes:
// float32 table + int32 indices -> float32 output.
MS_REG_CPU_KERNEL(
  EmbeddingLookupProxy,
  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
  EmbeddingLookUpProxyKernel);
|
||||
} // namespace ps
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_PROXY_KERNEL_H_
|
|
@ -0,0 +1,87 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "kernel/cpu/ps/embedding_look_up_ps_kernel.h"
|
||||
#include <functional>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "kernel/common_utils.h"
|
||||
#include "parallel/ps/util.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
namespace ps {
|
||||
using mindspore::parallel::ps::Util;
|
||||
// Initializes the server-side kernel from three packed shape vectors:
// shapes[0] = embedding table (input), shapes[1] = indices, shapes[2] = output.
// Computes this server's shard offset, then shards the input shape.
void EmbeddingLookUpPSKernel::InitKernel(
  const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes) {
  const std::vector<std::shared_ptr<std::vector<size_t>>> &shape_vec = *shapes;
  input_shape_ = *(shape_vec[0]);
  // Flat element count of the full (unsharded) embedding table.
  input_lens_ = 1;
  for (auto shape : input_shape_) {
    input_lens_ = input_lens_ * shape;
  }
  indices_shape_ = *(shape_vec[1]);
  // Flat element count of the indices tensor.
  indices_lens_ = 1;
  for (auto shape : indices_shape_) {
    indices_lens_ = indices_lens_ * shape;
  }
  output_shape_ = *(shape_vec[2]);
  // NOTE(review): axis_ = 2 presumably matches the 4D-expanded layout used by
  // the base CPU embedding-lookup kernel (shapes are padded to 4D below) —
  // confirm against EmbeddingLookUpCPUKernel.
  axis_ = 2;
  reduce_scatter_flag_ = false;

  // This server owns one contiguous shard of the table along axis_; its row
  // offset is the sum of the shard sizes assigned to lower-ranked servers.
  size_t offset = 0;
  for (size_t i = 0; i < rank_id_; i++) {
    offset += Util::LocalShard(input_shape_[axis_], i, pserver_num_);
  }
  offset_ = offset;
  split_num_ = pserver_num_;

  // input shape should be sharded after computing offset_;
  Shard(input_shape_, axis_);

  // Output byte size = product of output dims * sizeof(float)
  // (sizeof(float) is the accumulate seed).
  size_t output_size =
    std::accumulate(output_shape_.begin(), output_shape_.end(), sizeof(float), std::multiplies<size_t>());
  output_size_list_.emplace_back(output_size);
  // Base CPU kernel indexing expects 4D shapes.
  CPUKernelUtils::ExpandDimsTo4(&input_shape_);
  CPUKernelUtils::ExpandDimsTo4(&output_shape_);
}
|
||||
|
||||
void EmbeddingLookUpPSKernel::ReInit(const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes) {
|
||||
const std::vector<std::shared_ptr<std::vector<size_t>>> &shape_vec = *shapes;
|
||||
const auto &indices_shape_ = *(shape_vec[0]);
|
||||
indices_lens_ = indices_shape_[0];
|
||||
|
||||
size_t output_size = sizeof(float) * indices_lens_;
|
||||
for (size_t i = axis_ + 1; i < input_shape_.size(); i++) {
|
||||
output_size *= input_shape_[i];
|
||||
}
|
||||
output_size_list_.clear();
|
||||
output_size_list_.emplace_back(output_size);
|
||||
}
|
||||
|
||||
// Server-side entry point: delegates to the inherited CPU Launch, which
// gathers rows from this server's shard of the embedding table.
bool EmbeddingLookUpPSKernel::Execute(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                                      const std::vector<AddressPtr> &outputs) {
  return Launch(inputs, workspace, outputs);
}
|
||||
|
||||
// Size accessors consumed by the parameter-server framework.

// Returns the (sharded, 4D-expanded) shape of this server's table slice.
const std::vector<size_t> &EmbeddingLookUpPSKernel::input_sizes() const { return input_shape_; }

// Returns the per-output byte sizes computed in InitKernel/ReInit.
const std::vector<size_t> &EmbeddingLookUpPSKernel::output_sizes() const { return GetOutputSizeList(); }

// Returns the workspace byte sizes tracked by the base CPU kernel.
const std::vector<size_t> &EmbeddingLookUpPSKernel::workspace_sizes() const { return GetWorkspaceSizeList(); }
|
||||
} // namespace ps
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,46 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_PS_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_PS_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "kernel/cpu/embedding_look_up_cpu_kernel.h"
|
||||
#include "kernel/cpu/ps/pserver_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
namespace ps {
|
||||
class EmbeddingLookUpPSKernel : public EmbeddingLookUpCPUKernel, public PServerKernel {
|
||||
public:
|
||||
EmbeddingLookUpPSKernel(size_t rank_id, size_t pserver_num) : PServerKernel(rank_id, pserver_num) {}
|
||||
~EmbeddingLookUpPSKernel() override = default;
|
||||
|
||||
void InitKernel(const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &) override;
|
||||
void ReInit(const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &) override;
|
||||
|
||||
bool Execute(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
const std::vector<size_t> &input_sizes() const override;
|
||||
const std::vector<size_t> &output_sizes() const override;
|
||||
const std::vector<size_t> &workspace_sizes() const override;
|
||||
};
|
||||
} // namespace ps
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_PS_KERNEL_H_
|
Loading…
Reference in New Issue