add EmbeddingLookup kernel

This commit is contained in:
TFBunny 2021-05-20 14:51:59 -04:00
parent d233eeb87c
commit 250a402e22
8 changed files with 487 additions and 1 deletions

View File

@ -0,0 +1,144 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/arrays/embedding_lookup_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_TWO(
EmbeddingLookup,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
EmbeddingLookupKernel, double, int)
MS_REG_GPU_KERNEL_TWO(
EmbeddingLookup,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
EmbeddingLookupKernel, double, int64_t)
MS_REG_GPU_KERNEL_TWO(
EmbeddingLookup,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
EmbeddingLookupKernel, float, int)
MS_REG_GPU_KERNEL_TWO(
EmbeddingLookup,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
EmbeddingLookupKernel, float, int64_t)
MS_REG_GPU_KERNEL_TWO(
EmbeddingLookup,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
EmbeddingLookupKernel, half, int)
MS_REG_GPU_KERNEL_TWO(
EmbeddingLookup,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
EmbeddingLookupKernel, half, int64_t)
MS_REG_GPU_KERNEL_TWO(
EmbeddingLookup,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
EmbeddingLookupKernel, int, int)
MS_REG_GPU_KERNEL_TWO(
EmbeddingLookup,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
EmbeddingLookupKernel, int, int64_t)
MS_REG_GPU_KERNEL_TWO(
EmbeddingLookup,
KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
EmbeddingLookupKernel, int16_t, int)
MS_REG_GPU_KERNEL_TWO(
EmbeddingLookup,
KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
EmbeddingLookupKernel, int16_t, int64_t)
MS_REG_GPU_KERNEL_TWO(
EmbeddingLookup,
KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
EmbeddingLookupKernel, int8_t, int)
MS_REG_GPU_KERNEL_TWO(
EmbeddingLookup,
KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
EmbeddingLookupKernel, int8_t, int64_t)
MS_REG_GPU_KERNEL_TWO(
EmbeddingLookup,
KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
EmbeddingLookupKernel, uint8_t, int)
MS_REG_GPU_KERNEL_TWO(
EmbeddingLookup,
KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
EmbeddingLookupKernel, uint8_t, int64_t)
MS_REG_GPU_KERNEL_TWO(
EmbeddingLookup,
KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
EmbeddingLookupKernel, bool, int)
MS_REG_GPU_KERNEL_TWO(
EmbeddingLookup,
KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
EmbeddingLookupKernel, bool, int64_t)
// dynamic shape
MS_REG_GPU_KERNEL_TWO(EmbeddingLookup,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat64),
EmbeddingLookupKernel, double, int)
MS_REG_GPU_KERNEL_TWO(EmbeddingLookup,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat64),
EmbeddingLookupKernel, double, int64_t)
MS_REG_GPU_KERNEL_TWO(EmbeddingLookup,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
EmbeddingLookupKernel, float, int)
MS_REG_GPU_KERNEL_TWO(EmbeddingLookup,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
EmbeddingLookupKernel, float, int64_t)
MS_REG_GPU_KERNEL_TWO(EmbeddingLookup,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat16),
EmbeddingLookupKernel, half, int)
MS_REG_GPU_KERNEL_TWO(EmbeddingLookup,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat16),
EmbeddingLookupKernel, half, int64_t)
MS_REG_GPU_KERNEL_TWO(EmbeddingLookup,
KernelAttr()
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeBool),
EmbeddingLookupKernel, bool, int)
MS_REG_GPU_KERNEL_TWO(EmbeddingLookup,
KernelAttr()
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeBool),
EmbeddingLookupKernel, bool, int64_t)
// dynamic shape ends
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,148 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_EMBEDDING_LOOKUP_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_EMBEDDING_LOOKUP_GPU_KERNEL_H_
#include <vector>
#include <algorithm>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/embedding_lookup_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T, typename S>
class EmbeddingLookupKernel : public GpuKernel {
public:
EmbeddingLookupKernel() { ResetResource(); }
~EmbeddingLookupKernel() = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
VARIABLE_NOT_USED(workspace);
T *input_addr = GetDeviceAddress<T>(inputs, 0);
S *indices_addr = GetDeviceAddress<S>(inputs, 1);
T *output_addr = GetDeviceAddress<T>(outputs, 0);
if (is_dynamic_shape_) {
int64_t *offset_device_address = GetDeviceAddress<int64_t>(inputs, 2); // only get this if in dynamic mode
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(&offset_, offset_device_address, sizeof(int64_t),
cudaMemcpyDeviceToHost, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync offset_ failed");
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaDeviceSynchronize(),
"cudaDeviceSyncFailed - EmbeddingLookup - in dynamic mode");
}
auto input_dim1 = input_shapes_[0];
CalEmbeddingLookup(input_addr, indices_addr, output_addr, dims_[0], dims_[1], dims_[2], input_dim1, offset_,
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
kernel_node_ = kernel_node;
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num == 3) {
is_dynamic_shape_ = true;
MS_LOG(INFO) << " EmbeddingLookup running in Dynamic Mode.";
} else if (input_num == 2) {
MS_LOG(INFO) << " EmbeddingLookup running in Normal Mode.";
} else {
MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but EmbeddingLookup needs 2 or 3.";
}
input_shapes_ = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0);
indices_shapes_ = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 1);
output_shapes_ = AnfAlgo::GetOutputRealDeviceShapeIfExist(kernel_node, 0);
if (!is_dynamic_shape_) {
offset_ = GetAttr<int64_t>(kernel_node, "offset");
}
Reshape();
InitSizeLists();
return true;
}
void ResetResource() noexcept override {
is_dynamic_shape_ = false;
input_shapes_.clear();
indices_shapes_.clear();
output_shapes_.clear();
std::fill(dims_, dims_ + 3, 0);
offset_ = 0;
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
protected:
void InitSizeLists() override {
size_t size = GetSize(input_shapes_);
input_size_list_.push_back(size);
size = GetSize(indices_shapes_);
input_size_list_.push_back(size);
if (is_dynamic_shape_) {
input_size_list_.push_back(sizeof(int64_t));
}
size = GetSize(output_shapes_);
output_size_list_.push_back(size);
}
private:
void Reshape() {
int64_t axis = 0;
size_t dim_before_axis = 1;
for (size_t i = 0; i < LongToSize(axis); i++) {
dim_before_axis *= output_shapes_[i];
}
size_t dim_of_indices = 1;
for (size_t i = 0; i < indices_shapes_.size(); i++) {
dim_of_indices *= indices_shapes_[i];
}
size_t dim_after_indices = 1;
for (size_t i = LongToSize(axis) + indices_shapes_.size(); i < output_shapes_.size(); i++) {
dim_after_indices *= output_shapes_[i];
}
dims_[0] = dim_before_axis;
dims_[1] = dim_of_indices;
dims_[2] = dim_after_indices;
return;
}
size_t GetSize(const std::vector<size_t> &shape) const {
if (shape.size() == 0) {
return 0;
}
size_t result = sizeof(T);
for (size_t i = 0; i < shape.size(); i++) {
result *= shape[i];
}
return result;
}
std::vector<size_t> input_shapes_;
std::vector<size_t> indices_shapes_;
std::vector<size_t> output_shapes_;
size_t dims_[3] = {};
int64_t offset_;
bool is_dynamic_shape_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_EMBEDDING_LOOKUP_GPU_KERNEL_H_

View File

@ -73,6 +73,20 @@ MS_REG_GPU_KERNEL_TWO(
Gather, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
GatherV2GpuFwdKernel, bool, int64_t)
// dynamic shape
MS_REG_GPU_KERNEL_TWO(Gather,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat64),
GatherV2GpuFwdKernel, double, int)
MS_REG_GPU_KERNEL_TWO(Gather,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat64),
GatherV2GpuFwdKernel, double, int64_t)
MS_REG_GPU_KERNEL_TWO(Gather,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)

View File

@ -0,0 +1,88 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/cuda_impl/gatherv2.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/embedding_lookup_impl.cuh"
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
__global__ void SubOffset(T *indices, size_t size, int64_t offset) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
indices[pos] -= static_cast<T>(offset);
}
return;
}
template <typename T, typename S>
void CalEmbeddingLookup(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1, size_t output_dim2,
size_t input_dim1, int64_t offset, cudaStream_t stream) {
size_t size = output_dim0 * output_dim1 * output_dim2;
SubOffset<<<GET_BLOCKS(output_dim1), GET_THREADS, 0, stream>>>(indices, output_dim1, offset);
GatherV2Kernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(input, indices, output, output_dim0, output_dim1,
output_dim2, input_dim1);
// restore indices
SubOffset<<<GET_BLOCKS(output_dim1), GET_THREADS, 0, stream>>>(indices, output_dim1, -offset);
return;
}
template void CalEmbeddingLookup<float, int>(float *input, int *indices, float *output, size_t output_dim0,
size_t output_dim1, size_t output_dim2, size_t input_dim1, int64_t offset,
cudaStream_t stream);
template void CalEmbeddingLookup<float, int64_t>(float *input, int64_t *indices, float *output, size_t output_dim0,
size_t output_dim1, size_t output_dim2, size_t input_dim1,
int64_t offset, cudaStream_t stream);
template void CalEmbeddingLookup<half, int>(half *input, int *indices, half *output, size_t output_dim0,
size_t output_dim1, size_t output_dim2, size_t input_dim1, int64_t offset,
cudaStream_t stream);
template void CalEmbeddingLookup<half, int64_t>(half *input, int64_t *indices, half *output, size_t output_dim0,
size_t output_dim1, size_t output_dim2, size_t input_dim1,
int64_t offset, cudaStream_t stream);
template void CalEmbeddingLookup<double, int>(double *input, int *indices, double *output, size_t output_dim0,
size_t output_dim1, size_t output_dim2, size_t input_dim1, int64_t offset,
cudaStream_t stream);
template void CalEmbeddingLookup<double, int64_t>(double *input, int64_t *indices, double *output, size_t output_dim0,
size_t output_dim1, size_t output_dim2, size_t input_dim1,
int64_t offset, cudaStream_t stream);
template void CalEmbeddingLookup<int, int>(int *input, int *indices, int *output, size_t output_dim0,
size_t output_dim1, size_t output_dim2, size_t input_dim1, int64_t offset,
cudaStream_t stream);
template void CalEmbeddingLookup<int, int64_t>(int *input, int64_t *indices, int *output, size_t output_dim0,
size_t output_dim1, size_t output_dim2, size_t input_dim1,
int64_t offset, cudaStream_t stream);
template void CalEmbeddingLookup<int16_t, int>(int16_t *input, int *indices, int16_t *output, size_t output_dim0,
size_t output_dim1, size_t output_dim2, size_t input_dim1,
int64_t offset, cudaStream_t stream);
template void CalEmbeddingLookup<int16_t, int64_t>(int16_t *input, int64_t *indices, int16_t *output,
size_t output_dim0, size_t output_dim1, size_t output_dim2,
size_t input_dim1, int64_t offset, cudaStream_t stream);
template void CalEmbeddingLookup<int8_t, int>(int8_t *input, int *indices, int8_t *output, size_t output_dim0,
size_t output_dim1, size_t output_dim2, size_t input_dim1, int64_t offset,
cudaStream_t stream);
template void CalEmbeddingLookup<int8_t, int64_t>(int8_t *input, int64_t *indices, int8_t *output, size_t output_dim0,
size_t output_dim1, size_t output_dim2, size_t input_dim1,
int64_t offset, cudaStream_t stream);
template void CalEmbeddingLookup<uint8_t, int>(uint8_t *input, int *indices, uint8_t *output, size_t output_dim0,
size_t output_dim1, size_t output_dim2, size_t input_dim1,
int64_t offset, cudaStream_t stream);
template void CalEmbeddingLookup<uint8_t, int64_t>(uint8_t *input, int64_t *indices, uint8_t *output,
size_t output_dim0, size_t output_dim1, size_t output_dim2,
size_t input_dim1, int64_t offset, cudaStream_t stream);
template void CalEmbeddingLookup<bool, int>(bool *input, int *indices, bool *output, size_t output_dim0,
size_t output_dim1, size_t output_dim2, size_t input_dim1, int64_t offset,
cudaStream_t stream);
template void CalEmbeddingLookup<bool, int64_t>(bool *input, int64_t *indices, bool *output, size_t output_dim0,
size_t output_dim1, size_t output_dim2, size_t input_dim1,
int64_t offset, cudaStream_t stream);

View File

@ -0,0 +1,24 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_EMBEDDING_LOOKUP_IMPL_CUH_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_EMBEDDING_LOOKUP_IMPL_CUH_
template <typename T, typename S>
void CalEmbeddingLookup(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1, size_t output_dim2,
size_t input_dim1, int64_t offset, cudaStream_t stream);
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_EMBEDDING_LOOKUP_IMPL_CUH_

View File

@ -16,8 +16,13 @@
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_GATHERV2_CUH_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_GATHERV2_CUH_
template <typename T, typename S>
void GatherV2(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1, size_t output_dim2,
size_t input_dim1, cudaStream_t stream);
template <typename T, typename S>
__global__ void GatherV2Kernel(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1,
size_t output_dim2, size_t input_dim1);
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_GATHERV2_CUH_

View File

@ -5060,7 +5060,7 @@ class EmbeddingLookup(PrimitiveWithCheck):
ValueError: If length of shape of `input_params` is greater than 2.
Supported Platforms:
``Ascend`` ``CPU``
``Ascend`` ``CPU`` ``GPU``
Examples:
>>> input_params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mindspore.float32)

View File

@ -0,0 +1,63 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
from mindspore import Tensor
from mindspore.ops import operations as P
import mindspore.nn as nn
import mindspore.context as context
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.embeddinglookup = P.EmbeddingLookup()
def construct(self, input_params, input_indices, offset):
return self.embeddinglookup(input_params, input_indices, offset)
def embeddinglookup_testcase(nptype):
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
input_params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]).astype(nptype))
input_indices = Tensor(np.array([[5, 2], [8, 5]]).astype(np.int32))
offset = 4
output = Net()(input_params, input_indices, offset)
expect = np.array([[[10, 11], [0, 0]], [[0, 0], [10, 11]]]).astype(nptype)
np.testing.assert_almost_equal(expect, output.asnumpy())
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
input_params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]).astype(nptype))
input_indices = Tensor(np.array([[5, 2], [8, 5]]).astype(np.int32))
offset = 4
output = Net()(input_params, input_indices, offset)
expect = np.array([[[10, 11], [0, 0]], [[0, 0], [10, 11]]]).astype(nptype)
np.testing.assert_almost_equal(expect, output.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_embeddinglookup_float32():
embeddinglookup_testcase(np.float32)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_embeddinglookup_float16():
embeddinglookup_testcase(np.float16)