Performance optimization of embedding_lookup

This commit is contained in:
yujianfeng 2020-07-24 09:48:43 +08:00
parent 657b547116
commit 57cb1eeb14
11 changed files with 189 additions and 230 deletions

View File

@ -17,143 +17,21 @@
#include <string>
#include "backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include "runtime/device/cpu/mpi/mpi_adapter.h"
#include "ir/primitive.h"
namespace mindspore {
namespace kernel {
void EmbeddingLookUpCPUKernel::InitKernel(const CNodePtr &kernel_node) {
CheckParam(kernel_node);
input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
input_lens_ = 1;
for (auto shape : input_shape_) {
input_lens_ = input_lens_ * shape;
}
indices_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
indices_lens_ = 1;
for (auto shape : indices_shape_) {
indices_lens_ = indices_lens_ * shape;
}
output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
axis_ = 4 - input_shape_.size();
if (AnfAlgo::HasNodeAttr(kAttrReduceScatterFlag, kernel_node)) {
reduce_scatter_flag_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, kAttrReduceScatterFlag);
}
#ifdef ENABLE_MPI
if (reduce_scatter_flag_) {
size_t gatherv2_out_lens = 1;
for (int i = 0; i < SizeToInt(input_shape_.size()); i++) {
if (i == 0) {
for (int j = 0; j < SizeToInt(indices_shape_.size()); j++) {
gatherv2_out_lens = gatherv2_out_lens * indices_shape_[j];
}
} else {
gatherv2_out_lens = gatherv2_out_lens * input_shape_[i];
}
}
gatherv2_out_lens_ = gatherv2_out_lens * sizeof(float);
gather_v2_out_ = malloc(gatherv2_out_lens_);
if (gather_v2_out_ == nullptr) {
MS_LOG(EXCEPTION) << "EmbeddingLookUpCPUKernel malloc failed, malloc lens: " << gatherv2_out_lens_;
}
auto ret = memset_s(gather_v2_out_, gatherv2_out_lens_, 0, gatherv2_out_lens_);
if (ret != 0) {
MS_LOG(EXCEPTION) << "EmbeddingLookUpCPUKernel memset gatherv2 out buff failed";
}
split_num_ = AnfAlgo::GetNodeAttr<int>(kernel_node, "split_num");
}
#else
if (reduce_scatter_flag_) {
MS_LOG(EXCEPTION) << "Not Enable MPI, please build version with -M on when set reduce_scatter_flag true";
}
#endif
if (AnfAlgo::HasNodeAttr(kAttrOffset, kernel_node)) {
offset_ = AnfAlgo::GetNodeAttr<int>(kernel_node, kAttrOffset);
}
CPUKernelUtils::ExpandDimsTo4(&input_shape_);
CPUKernelUtils::ExpandDimsTo4(&output_shape_);
}
bool EmbeddingLookUpCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /*workspace*/,
const std::vector<kernel::AddressPtr> &outputs) {
auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
float *gather_out_addr = reduce_scatter_flag_ ? reinterpret_cast<float *>(gather_v2_out_) : output_addr;
size_t dim0 = input_shape_[0];
size_t dim1 = input_shape_[1];
size_t dim2 = input_shape_[2];
if (axis_ == 3) {
for (size_t i = 0; i < dim0; ++i) {
for (size_t j = 0; j < dim1; ++j) {
for (size_t k = 0; k < dim2; ++k) {
LookUpTable(inputs, i, j, k, &gather_out_addr);
}
}
}
} else if (axis_ == 2) {
for (size_t i = 0; i < dim0; ++i) {
for (size_t j = 0; j < dim1; ++j) {
LookUpTable(inputs, i, j, 0, &gather_out_addr);
}
}
} else if (axis_ == 1) {
for (size_t i = 0; i < dim0; ++i) {
LookUpTable(inputs, i, 0, 0, &gather_out_addr);
}
} else if (axis_ == 0) {
LookUpTable(inputs, 0, 0, 0, &gather_out_addr);
}
#ifdef ENABLE_MPI
if (reduce_scatter_flag_) {
size_t one_split_lens = gatherv2_out_lens_ / split_num_ / sizeof(float);
size_t reduce_scatter_out_lens = one_split_lens / 8;
const std::vector<int> &group = {0, 1, 2, 3, 4, 5, 6, 7};
auto mpi_instance = device::cpu::MPIAdapter::Instance();
MS_EXCEPTION_IF_NULL(mpi_instance);
for (int i = 0; i < split_num_; i++) {
mpi_instance->ReduceScatter(reinterpret_cast<float *>(gather_v2_out_) + i * one_split_lens,
output_addr + i * reduce_scatter_out_lens, group, one_split_lens / 8, "sum");
}
}
#endif
return true;
}
void LookUpTable_task(const float *input_addr, float *output_addr, const int *indices_addr, size_t indices_lens,
size_t num, size_t dim0, size_t dim1, size_t dim2, int offset, size_t axis,
std::vector<size_t> input_shape, size_t input_lens) {
size_t lens = num * sizeof(float);
namespace {
void LookUpTableTask(const float *input_addr, const int *indices_addr, float *output_addr, size_t indices_lens,
size_t outer_dim_size, int offset, size_t first_dim_size) {
size_t lens = outer_dim_size * sizeof(float);
for (size_t i = 0; i < indices_lens; ++i) {
int indices = indices_addr[i] - offset;
if (indices >= 0) {
size_t index = IntToSize(indices);
if (index < input_shape[axis]) {
size_t pos = 0;
if (axis == 3) {
pos = CPUKernelUtils::CalcOffset(input_shape, dim0, dim1, dim2, index);
} else if (axis == 2) {
pos = CPUKernelUtils::CalcOffset(input_shape, dim0, dim1, index, 0);
} else if (axis == 1) {
pos = CPUKernelUtils::CalcOffset(input_shape, dim0, index, 0, 0);
} else if (axis == 0) {
pos = CPUKernelUtils::CalcOffset(input_shape, index, 0, 0, 0);
}
if (pos + num <= input_lens) {
auto ret = memcpy_s(output_addr, lens, input_addr + pos, lens);
if (ret != EOK) {
MS_LOG(EXCEPTION) << "LookUpTable task memcpy failed.";
}
} else {
auto ret = memset_s(output_addr, lens, 0, lens);
if (ret != EOK) {
MS_LOG(EXCEPTION) << "LookUpTable task memset failed.";
}
}
} else {
auto ret = memset_s(output_addr, lens, 0, lens);
if (ret != EOK) {
MS_LOG(EXCEPTION) << "LookUpTable task memset failed.";
}
int index = indices_addr[i] - offset;
if (index >= 0 && index < SizeToInt(first_dim_size)) {
size_t pos = index * outer_dim_size;
auto ret = memcpy_s(output_addr, lens, input_addr + pos, lens);
if (ret != EOK) {
MS_LOG(EXCEPTION) << "LookUpTable task memcpy failed.";
}
} else {
auto ret = memset_s(output_addr, lens, 0, lens);
@ -161,16 +39,36 @@ void LookUpTable_task(const float *input_addr, float *output_addr, const int *in
MS_LOG(EXCEPTION) << "LookUpTable task memset failed.";
}
}
output_addr += num;
output_addr += outer_dim_size;
}
}
} // namespace
void EmbeddingLookUpCPUKernel::InitKernel(const CNodePtr &kernel_node) {
CheckParam(kernel_node);
std::vector<size_t> input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
if (input_shape.empty()) {
MS_LOG(EXCEPTION) << "param must be at least 1D";
}
first_dim_size_ = input_shape[0];
for (size_t i = 1; i < input_shape.size(); ++i) {
outer_dim_size_ *= input_shape[i];
}
std::vector<size_t> indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
for (const auto &shape : indices_shape) {
indices_lens_ *= shape;
}
if (AnfAlgo::HasNodeAttr(kAttrOffset, kernel_node)) {
offset_ = AnfAlgo::GetNodeAttr<int>(kernel_node, kAttrOffset);
}
}
void EmbeddingLookUpCPUKernel::LookUpTable(const std::vector<kernel::AddressPtr> &inputs, size_t dim0, size_t dim1,
size_t dim2, float **output_addr) {
bool EmbeddingLookUpCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /*workspace*/,
const std::vector<kernel::AddressPtr> &outputs) {
auto input_addr = reinterpret_cast<float *>(inputs[0]->addr);
auto indices_addr = reinterpret_cast<int *>(inputs[1]->addr);
size_t num = CPUKernelUtils::GetElementNumOnAxis(input_shape_, axis_);
float *task_out_addr = *output_addr;
auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
const size_t thread_num = 8;
std::thread threads[8];
size_t task_proc_lens = (indices_lens_ + thread_num - 1) / thread_num;
@ -183,8 +81,8 @@ void EmbeddingLookUpCPUKernel::LookUpTable(const std::vector<kernel::AddressPtr>
}
MS_LOG(DEBUG) << "task_offset: " << task_offset << " task_proc_lenss:" << task_proc_lens;
threads[i] =
std::thread(LookUpTable_task, input_addr, task_out_addr + task_offset * num, indices_addr + task_offset,
task_proc_lens, num, dim0, dim1, dim2, offset_, axis_, input_shape_, input_lens_);
std::thread(LookUpTableTask, input_addr, indices_addr + task_offset, output_addr + task_offset * outer_dim_size_,
task_proc_lens, outer_dim_size_, offset_, first_dim_size_);
task_offset += task_proc_lens;
if (task_offset + task_proc_lens > indices_lens_) {
task_proc_lens = indices_lens_ - task_offset;
@ -193,14 +91,14 @@ void EmbeddingLookUpCPUKernel::LookUpTable(const std::vector<kernel::AddressPtr>
for (size_t j = 0; j < i; j++) {
threads[j].join();
}
*output_addr += num * indices_lens_;
return true;
}
void EmbeddingLookUpCPUKernel::CheckParam(const CNodePtr &kernel_node) {
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
if (input_shape.size() > 4) {
MS_LOG(EXCEPTION) << "Input dims is " << input_shape.size()
<< ", but EmbeddingLookUpCPUKernel olny support 4d or lower.";
<< ", but EmbeddingLookUpCPUKernel only support 4d or lower.";
}
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);

View File

@ -24,22 +24,8 @@ namespace mindspore {
namespace kernel {
class EmbeddingLookUpCPUKernel : public CPUKernel {
public:
EmbeddingLookUpCPUKernel() {
axis_ = 0;
offset_ = 0;
split_num_ = 0;
input_lens_ = 0;
indices_lens_ = 0;
gatherv2_out_lens_ = 0;
reduce_scatter_flag_ = false;
gather_v2_out_ = nullptr;
}
~EmbeddingLookUpCPUKernel() override {
if (gather_v2_out_ != nullptr) {
free(gather_v2_out_);
gather_v2_out_ = nullptr;
}
}
EmbeddingLookUpCPUKernel() {}
~EmbeddingLookUpCPUKernel() override {}
void InitKernel(const CNodePtr &kernel_node) override;
@ -47,21 +33,11 @@ class EmbeddingLookUpCPUKernel : public CPUKernel {
const std::vector<AddressPtr> &outputs) override;
protected:
void LookUpTable(const std::vector<kernel::AddressPtr> &inputs, size_t dim0, size_t dim1, size_t dim2,
float **output_addr);
void CheckParam(const CNodePtr &kernel_node);
std::vector<size_t> input_shape_;
std::vector<size_t> indices_shape_;
std::vector<size_t> output_shape_;
int axis_;
int offset_;
int split_num_;
size_t input_lens_;
size_t indices_lens_;
size_t gatherv2_out_lens_;
bool reduce_scatter_flag_;
void *gather_v2_out_;
int offset_{0};
size_t indices_lens_{1};
size_t first_dim_size_{1};
size_t outer_dim_size_{1};
};
MS_REG_CPU_KERNEL(

View File

@ -22,8 +22,13 @@ namespace kernel {
namespace ps {
void EmbeddingLookUpProxyKernel::InitKernel(const CNodePtr &kernel_node) {
EmbeddingLookUpCPUKernel::InitKernel(kernel_node);
for (auto dim : input_shape_) {
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
auto indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
size_t axis = kShape4dDims - input_shape.size();
CPUKernelUtils::ExpandDimsTo4(&input_shape);
CPUKernelUtils::ExpandDimsTo4(&output_shape);
for (auto dim : input_shape) {
input_dims_ *= dim;
}
@ -32,14 +37,13 @@ void EmbeddingLookUpProxyKernel::InitKernel(const CNodePtr &kernel_node) {
}
std::vector<size_t> keys{key_, key_, key_};
std::vector<size_t> values;
values.insert(values.end(), input_shape_.begin(), input_shape_.end());
values.insert(values.end(), indices_shape_.begin(), indices_shape_.end());
values.insert(values.end(), output_shape_.begin(), output_shape_.end());
std::vector<int> lens{SizeToInt(input_shape_.size()), SizeToInt(indices_shape_.size()),
SizeToInt(output_shape_.size())};
values.insert(values.end(), input_shape.begin(), input_shape.end());
values.insert(values.end(), indices_shape.begin(), indices_shape.end());
values.insert(values.end(), output_shape.begin(), output_shape.end());
std::vector<int> lens{SizeToInt(input_shape.size()), SizeToInt(indices_shape.size()), SizeToInt(output_shape.size())};
const char *env_role = getenv(mindspore::parallel::ps::kEnvRole);
if (env_role != nullptr && strcmp(env_role, mindspore::parallel::ps::kEnvRoleOfWorker) == 0) {
parallel::ps::Worker<float>::GetInstance().AddEmbeddingTable(key_, input_shape_[axis_]);
parallel::ps::Worker<float>::GetInstance().AddEmbeddingTable(key_, input_shape[axis]);
parallel::ps::Worker<float>::GetInstance().InitPSEmbeddingTable(keys, values, lens);
}
}

View File

@ -25,47 +25,40 @@ namespace mindspore {
namespace kernel {
namespace ps {
using mindspore::parallel::ps::Util;
constexpr int kAxis = 2;
void EmbeddingLookUpPSKernel::InitKernel(
const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes) {
const std::vector<std::shared_ptr<std::vector<size_t>>> &shape_vec = *shapes;
input_shape_ = *(shape_vec[0]);
input_lens_ = 1;
for (auto shape : input_shape_) {
input_lens_ = input_lens_ * shape;
}
indices_shape_ = *(shape_vec[1]);
auto indices_shape = *(shape_vec[1]);
indices_lens_ = 1;
for (auto shape : indices_shape_) {
for (auto shape : indices_shape) {
indices_lens_ = indices_lens_ * shape;
}
output_shape_ = *(shape_vec[2]);
axis_ = 2;
reduce_scatter_flag_ = false;
auto output_shape = *(shape_vec[2]);
size_t offset = 0;
for (size_t i = 0; i < rank_id_; i++) {
offset += Util::LocalShard(input_shape_[axis_], i, pserver_num_);
offset += Util::LocalShard(input_shape_[kAxis], i, pserver_num_);
}
offset_ = offset;
split_num_ = pserver_num_;
// input shape should be sharded after computing offset_;
Shard(&input_shape_, axis_);
Shard(&input_shape_, kAxis);
size_t output_size =
std::accumulate(output_shape_.begin(), output_shape_.end(), sizeof(float), std::multiplies<size_t>());
std::accumulate(output_shape.begin(), output_shape.end(), sizeof(float), std::multiplies<size_t>());
output_size_list_.emplace_back(output_size);
CPUKernelUtils::ExpandDimsTo4(&input_shape_);
CPUKernelUtils::ExpandDimsTo4(&output_shape_);
}
void EmbeddingLookUpPSKernel::ReInit(const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes) {
const std::vector<std::shared_ptr<std::vector<size_t>>> &shape_vec = *shapes;
const auto &indices_shape_ = *(shape_vec[0]);
indices_lens_ = indices_shape_[0];
const auto &indices_shape = *(shape_vec[0]);
indices_lens_ = indices_shape[0];
size_t output_size = sizeof(float) * indices_lens_;
for (size_t i = axis_ + 1; i < input_shape_.size(); i++) {
for (size_t i = kAxis + 1; i < input_shape_.size(); i++) {
output_size *= input_shape_[i];
}
output_size_list_.clear();

View File

@ -38,6 +38,9 @@ class EmbeddingLookUpPSKernel : public EmbeddingLookUpCPUKernel, public PServerK
const std::vector<size_t> &input_sizes() const override;
const std::vector<size_t> &output_sizes() const override;
const std::vector<size_t> &workspace_sizes() const override;
private:
std::vector<size_t> input_shape_;
};
} // namespace ps
} // namespace kernel

View File

@ -27,12 +27,12 @@ void ComputeAdam(MultiThreadComputeParams *input_params, size_t start, size_t en
auto m = input_params->m_;
auto m_t = input_params->m_t_;
auto v = input_params->v_;
auto beta1 = input_params->beta1_;
auto beta2 = input_params->beta2_;
auto use_nesterov = input_params->use_nesterov_;
auto unique_sparse_grad = input_params->sparse_grad_;
auto var_first_dim_size = input_params->var_first_dim_size_;
auto var_outer_dim_size = input_params->var_outer_dim_size_;
const auto beta1 = input_params->beta1_;
const auto beta2 = input_params->beta2_;
const auto use_nesterov = input_params->use_nesterov_;
const auto unique_sparse_grad = input_params->sparse_grad_;
const auto var_first_dim_size = input_params->var_first_dim_size_;
const auto var_outer_dim_size = input_params->var_outer_dim_size_;
for (size_t i = start; i < end; ++i) {
int index = unique_sparse_grad.indices_[i];
if (index < 0 || IntToSize(index) >= var_first_dim_size) {
@ -55,8 +55,8 @@ void ComputeMomentum(MultiThreadComputeParams *input_params, size_t start, size_
MS_EXCEPTION_IF_NULL(input_params);
auto m = input_params->m_;
auto v = input_params->v_;
auto beta1 = input_params->beta1_;
auto beta2 = input_params->beta2_;
const auto beta1 = input_params->beta1_;
const auto beta2 = input_params->beta2_;
for (size_t i = start; i < end; ++i) {
m[i] *= beta1;
v[i] *= beta2;
@ -66,10 +66,10 @@ void ComputeMomentum(MultiThreadComputeParams *input_params, size_t start, size_
void ComputeWeight(MultiThreadComputeParams *input_params, size_t start, size_t end) {
MS_EXCEPTION_IF_NULL(input_params);
auto var = input_params->var_;
auto m = input_params->m_;
auto v = input_params->v_;
auto lr = input_params->lr_;
auto epsilon = input_params->epsilon_;
const auto *m = input_params->m_;
const auto *v = input_params->v_;
const auto lr = input_params->lr_;
const auto epsilon = input_params->epsilon_;
for (size_t i = start; i < end; ++i) {
var[i] -= lr * m[i] / (std::sqrt(v[i]) + epsilon);
}

View File

@ -27,13 +27,13 @@ void ComputeFtrl(MultiThreadComputeParams *input_params, size_t start, size_t en
auto var = input_params->var_;
auto accum = input_params->accum_;
auto linear = input_params->linear_;
auto lr = input_params->lr_;
auto l1 = input_params->l1_;
auto l2_plus = 2 * input_params->l2_;
auto lr_power = input_params->lr_power_;
auto unique_sparse_grad = input_params->sparse_grad_;
auto var_first_dim_size = input_params->var_first_dim_size_;
auto var_outer_dim_size = input_params->var_outer_dim_size_;
const auto lr = input_params->lr_;
const auto l1 = input_params->l1_;
const auto l2_plus = 2 * input_params->l2_;
const auto lr_power = input_params->lr_power_;
const auto unique_sparse_grad = input_params->sparse_grad_;
const auto var_first_dim_size = input_params->var_first_dim_size_;
const auto var_outer_dim_size = input_params->var_outer_dim_size_;
for (size_t i = start; i < end; ++i) {
int index = unique_sparse_grad.indices_[i];
if (index < 0 || IntToSize(index) >= var_first_dim_size) {

View File

@ -27,14 +27,14 @@ void ComputeLazyAdam(MultiThreadComputeParams *input_params, size_t start, size_
auto var = input_params->var_;
auto m = input_params->m_;
auto v = input_params->v_;
auto lr = input_params->lr_;
auto beta1 = input_params->beta1_;
auto beta2 = input_params->beta2_;
auto epsilon = input_params->epsilon_;
auto use_nesterov = input_params->use_nesterov_;
auto unique_sparse_grad = input_params->sparse_grad_;
auto var_first_dim_size = input_params->var_first_dim_size_;
auto var_outer_dim_size = input_params->var_outer_dim_size_;
const auto lr = input_params->lr_;
const auto beta1 = input_params->beta1_;
const auto beta2 = input_params->beta2_;
const auto epsilon = input_params->epsilon_;
const auto use_nesterov = input_params->use_nesterov_;
const auto unique_sparse_grad = input_params->sparse_grad_;
const auto var_first_dim_size = input_params->var_first_dim_size_;
const auto var_outer_dim_size = input_params->var_outer_dim_size_;
for (size_t i = start; i < end; ++i) {
int index = unique_sparse_grad.indices_[i];
if (index < 0 || IntToSize(index) >= var_first_dim_size) {

View File

@ -26,12 +26,12 @@ void ComputeProximalAdagrad(MultiThreadComputeParams *input_params, size_t start
MS_EXCEPTION_IF_NULL(input_params);
auto var = input_params->var_;
auto accum = input_params->accum_;
auto lr = input_params->lr_;
auto l1 = input_params->l1_;
auto l2 = input_params->l2_;
auto unique_sparse_grad = input_params->sparse_grad_;
auto var_first_dim_size = input_params->var_first_dim_size_;
auto var_outer_dim_size = input_params->var_outer_dim_size_;
const auto lr = input_params->lr_;
const auto l1 = input_params->l1_;
const auto l2 = input_params->l2_;
const auto unique_sparse_grad = input_params->sparse_grad_;
const auto var_first_dim_size = input_params->var_first_dim_size_;
const auto var_outer_dim_size = input_params->var_outer_dim_size_;
for (size_t i = start; i < end; ++i) {
int index = unique_sparse_grad.indices_[i];
if (index < 0 || IntToSize(index) >= var_first_dim_size) {

View File

@ -54,7 +54,7 @@ AnfNodePtr CreateNewPack(const FuncGraphPtr &func_graph, const CNodePtr &origin_
} else if (i == IntToSize(axis)) {
new_shape.push_back(offset);
} else {
new_shape.push_back(output_shape[i - 1]);
new_shape.push_back(output_shape[SizeToInt(i) - 1]);
}
}
new_shape.erase(new_shape.begin() + axis + 1);

View File

@ -0,0 +1,85 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor
from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
class Net(nn.Cell):
def __init__(self, offset):
super(Net, self).__init__()
self.embedding = P.EmbeddingLookup()
self.offset = offset
def construct(self, param, index):
return self.embedding(param, index, self.offset)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_embedding_look_up0():
params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mstype.float32)
indices = Tensor(np.array([5, 2, 8, 5]), mstype.int32)
offset = 4
embedding = Net(offset)
out = embedding(params, indices)
expect = np.array([[10, 11], [0, 0], [0, 0], [10, 11]]).astype(np.float32)
assert (out.asnumpy() == expect).all()
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_embedding_look_up1():
params = Tensor(np.array([[8, 9], [10, 11]]), mstype.float32)
indices = Tensor(np.array([2, 2, 1, 0]), mstype.int32)
offset = 0
embedding = Net(offset)
out = embedding(params, indices)
expect = np.array([[0, 0], [0, 0], [10, 11], [8, 9]]).astype(np.float32)
assert (out.asnumpy() == expect).all()
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_embedding_look_up2():
params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mstype.float32)
indices = Tensor(np.array([[5, 2], [8, 5]]), mstype.int32)
offset = 4
embedding = Net(offset)
out = embedding(params, indices)
expect = np.array([[[10, 11], [0, 0]], [[0, 0], [10, 11]]]).astype(np.float32)
assert (out.asnumpy() == expect).all()
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_embedding_look_up3():
params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mstype.float32)
indices = Tensor(np.array([[[5], [2]], [[8], [5]]]), mstype.int32)
offset = 4
embedding = Net(offset)
out = embedding(params, indices)
expect = np.array([[[[10, 11]], [[0, 0]]], [[[0, 0]], [[10, 11]]]]).astype(np.float32)
assert (out.asnumpy() == expect).all()