forked from mindspore-Ecosystem/mindspore
Add LshProjection lite ops.
This commit is contained in:
parent
7425e1699e
commit
501de1b6f6
|
@ -0,0 +1,35 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_NNACL_LSH_PROJECTION_PARAMETER_H_
|
||||
#define MINDSPORE_LITE_NNACL_LSH_PROJECTION_PARAMETER_H_
|
||||
|
||||
#include "nnacl/op_base.h"
|
||||
|
||||
typedef struct LshProjectionParameter {
|
||||
OpParameter op_parameter_;
|
||||
int lsh_type_;
|
||||
int hash_shape_[2];
|
||||
int in_item_num_;
|
||||
size_t in_item_size_;
|
||||
size_t seed_size_;
|
||||
size_t key_size_;
|
||||
int64_t real_dst_count;
|
||||
int task_id_;
|
||||
int64_t count_unit_;
|
||||
} LshProjectionParameter;
|
||||
|
||||
#endif // MINDSPORE_LITE_NNACL_LSH_PROJECTION_PARAMETER_H_
|
|
@ -14,12 +14,16 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#include "src/ops/lsh_projection.h"
|
||||
#include "nnacl/lsh_projection_parameter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
int LshProjection::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { return RET_OK; }
|
||||
int LshProjection::GetLshType() const { return this->primitive_->value.AsLshProjection()->type; }
|
||||
#else
|
||||
int LshProjection::GetLshType() const { return this->primitive_->value_as_LshProjection()->type(); }
|
||||
|
||||
int LshProjection::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
|
@ -29,9 +33,51 @@ int LshProjection::UnPackToFlatBuilder(const schema::Primitive *primitive, flatb
|
|||
return RET_OK;
|
||||
}
|
||||
#endif
|
||||
namespace {
|
||||
constexpr int kSparseType = 1;
|
||||
constexpr int kDenseType = 2;
|
||||
} // namespace
|
||||
int LshProjection::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
|
||||
PrimitiveC::InferShape(inputs_, outputs_);
|
||||
return RET_INFER_INVALID;
|
||||
if (inputs_.size() != kDoubleNum || inputs_.size() != kMultiNum) {
|
||||
MS_LOG(ERROR) << "inputs to LshProjection operator should be 2 or 3, but " << inputs_.size() << " is given.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (outputs_.size() != kSingleNum) {
|
||||
MS_LOG(ERROR) << "outputs to Shape operator should be 1, but " << outputs_.size() << " is given.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
auto in_hash = inputs_.at(kSingleNum);
|
||||
MS_ASSERT(in_hash->shape().size() == 2);
|
||||
MS_ASSERT(in_hash->DimensionSize(1) <= 32);
|
||||
MS_ASSERT(inputs_.at(kDoubleNum)->shape().size() >= 1);
|
||||
|
||||
if (inputs_.size() == kMultiNum) {
|
||||
MS_ASSERT(inputs_.at(kMultiNum)->shape().size() == 1);
|
||||
MS_ASSERT(inputs_.at(kMultiNum)->DimensionSize(0) == in_value->DimensionSize(0));
|
||||
}
|
||||
|
||||
auto out_tensor = outputs_.front();
|
||||
out_tensor->set_data_type(kNumberTypeInt32);
|
||||
out_tensor->SetFormat(schema::Format::Format_NHWC);
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
std::vector<int> out_shape;
|
||||
switch (GetLshType()) {
|
||||
case kSparseType:
|
||||
out_shape.push_back(in_hash->DimensionSize(0));
|
||||
break;
|
||||
case kDenseType:
|
||||
out_shape.push_back(in_hash->DimensionSize(0) * in_hash->DimensionSize(1));
|
||||
break;
|
||||
default:
|
||||
return RET_ERROR;
|
||||
}
|
||||
out_tensor->set_shape(out_shape);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -33,6 +33,7 @@ class LshProjection : public PrimitiveC {
|
|||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
int InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) override;
|
||||
int GetLshType() const;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -54,6 +54,7 @@
|
|||
#include "src/ops/resize.h"
|
||||
#include "src/ops/tile.h"
|
||||
#include "src/ops/one_hot.h"
|
||||
#include "src/ops/lsh_projection.h"
|
||||
#include "src/ops/space_to_depth.h"
|
||||
#include "src/ops/split.h"
|
||||
#include "src/ops/argmax.h"
|
||||
|
@ -131,6 +132,7 @@
|
|||
#include "nnacl/unstack.h"
|
||||
#include "nnacl/depth_to_space.h"
|
||||
#include "nnacl/conv_parameter.h"
|
||||
#include "nnacl/lsh_projection_parameter.h"
|
||||
#include "nnacl/fp32/pooling.h"
|
||||
#include "nnacl/matmul_parameter.h"
|
||||
#include "nnacl/fp32/roi_pooling.h"
|
||||
|
@ -1323,6 +1325,20 @@ OpParameter *PopulateCropParameter(const mindspore::lite::PrimitiveC *primitive)
|
|||
return reinterpret_cast<OpParameter *>(crop_param);
|
||||
}
|
||||
|
||||
OpParameter *PopulateLshProjectionParameter(const mindspore::lite::PrimitiveC *primitive) {
|
||||
LshProjectionParameter *lsh_project_param =
|
||||
reinterpret_cast<LshProjectionParameter *>(malloc(sizeof(LshProjectionParameter)));
|
||||
if (lsh_project_param == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc LshProjectionParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(lsh_project_param, 0, sizeof(LshProjectionParameter));
|
||||
lsh_project_param->op_parameter_.type_ = primitive->Type();
|
||||
auto param = reinterpret_cast<mindspore::lite::LshProjection *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
|
||||
lsh_project_param->lsh_type_ = param->GetLshType();
|
||||
return reinterpret_cast<OpParameter *>(lsh_project_param);
|
||||
}
|
||||
|
||||
OpParameter *PopulateOneHotParameter(const mindspore::lite::PrimitiveC *primitive) {
|
||||
OneHotParameter *one_hot_param = reinterpret_cast<OneHotParameter *>(malloc(sizeof(OneHotParameter)));
|
||||
if (one_hot_param == nullptr) {
|
||||
|
@ -1747,6 +1763,7 @@ PopulateParameterRegistry::PopulateParameterRegistry() {
|
|||
populate_parameter_funcs_[schema::PrimitiveType_CustomExtractFeatures] = PopulateCommonOpParameter;
|
||||
populate_parameter_funcs_[schema::PrimitiveType_CustomPredict] = PopulateCustomPredictParameter;
|
||||
populate_parameter_funcs_[schema::PrimitiveType_HashtableLookup] = PopulateCommonOpParameter;
|
||||
populate_parameter_funcs_[schema::PrimitiveType_LshProjection] = PopulateLshProjectionParameter;
|
||||
}
|
||||
|
||||
PopulateParameterRegistry *PopulateParameterRegistry::GetInstance() {
|
||||
|
|
|
@ -0,0 +1,184 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/runtime/kernel/arm/fp32/lsh_projection.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "src/kernel_registry.h"
|
||||
#include "src/runtime/runtime_api.h"
|
||||
#include "src/common/string_util.h"
|
||||
|
||||
using mindspore::kernel::KERNEL_ARCH::kCPU;
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
using mindspore::lite::RET_ERROR;
|
||||
using mindspore::lite::RET_MEMORY_FAILED;
|
||||
using mindspore::lite::RET_OK;
|
||||
using mindspore::schema::PrimitiveType_LshProjection;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
|
||||
namespace {
|
||||
constexpr int kSparseType = 1;
|
||||
constexpr int kDenseType = 2;
|
||||
} // namespace
|
||||
|
||||
int LshProjectionCPUKernel::Init() {
|
||||
if (!InferShapeDone()) {
|
||||
return RET_OK;
|
||||
}
|
||||
return ReSize();
|
||||
}
|
||||
|
||||
int LshProjectionCPUKernel::ReSize() { return RET_OK; }
|
||||
|
||||
int LshProjectionCPUKernel::Run() {
|
||||
auto ret = Prepare();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Prepare fail!ret: " << ret;
|
||||
return ret;
|
||||
}
|
||||
|
||||
auto input_tensor0 = in_tensors_.at(0);
|
||||
auto input_tensor1 = in_tensors_.at(1);
|
||||
auto out_tensor0 = out_tensors_.at(0);
|
||||
|
||||
hash = reinterpret_cast<float *>(input_tensor0->MutableData());
|
||||
in_data = reinterpret_cast<char *>(input_tensor1->MutableData());
|
||||
weight = in_tensors_.size() == 2 ? nullptr : reinterpret_cast<float *>(in_tensors_.at(2)->MutableData());
|
||||
output = reinterpret_cast<int32_t *>(out_tensor0->MutableData());
|
||||
|
||||
const size_t seed_size = sizeof(float);
|
||||
const size_t input_item_size =
|
||||
input_tensor1->ElementsNum() * sizeof(input_tensor1->data_type()) / input_tensor1->DimensionSize(0);
|
||||
const size_t key_size = seed_size + input_item_size;
|
||||
lsh_param_->seed_size_ = seed_size;
|
||||
lsh_param_->in_item_size_ = input_item_size;
|
||||
lsh_param_->key_size_ = key_size;
|
||||
lsh_param_->in_item_num_ = input_tensor1->DimensionSize(0);
|
||||
memcpy(lsh_param_->hash_shape_, input_tensor0->shape().data(), sizeof(int) * input_tensor0->shape().size());
|
||||
|
||||
elements_num_ = input_tensor0->DimensionSize(0);
|
||||
count_unit_ = thread_num_ > 1 ? UP_DIV(elements_num_, thread_num_) : elements_num_;
|
||||
ret = ParallelLaunch(this->context_->thread_pool_, LshProjectionRun, this, thread_num_);
|
||||
return ret;
|
||||
}
|
||||
|
||||
int LshProjectionRun(void *cdata, int task_id) {
|
||||
auto lsh_projection = reinterpret_cast<LshProjectionCPUKernel *>(cdata);
|
||||
lsh_projection->DoExecute(task_id);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int LshProjectionCPUKernel::DoExecute(int task_id) {
|
||||
int64_t real_dst_count = MSMIN(elements_num_ - task_id * count_unit_, count_unit_);
|
||||
lsh_param_->real_dst_count = real_dst_count;
|
||||
lsh_param_->task_id_ = task_id;
|
||||
lsh_param_->count_unit_ = count_unit_;
|
||||
if (real_dst_count <= 0) {
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
switch (lsh_param_->lsh_type_) {
|
||||
case kSparseType:
|
||||
LshProjectionSparse(hash, in_data, weight, output, lsh_param_);
|
||||
break;
|
||||
case kDenseType:
|
||||
LshProjectionDense(hash, in_data, weight, output, lsh_param_);
|
||||
break;
|
||||
default:
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int LshProjectionCPUKernel::GetSignBit(char *in_data, float *weight, float seed, LshProjectionParameter *para) {
|
||||
double score = 0.0;
|
||||
for (int i = 0; i < para->in_item_num_; i++) {
|
||||
char *key = static_cast<char *>(ctx_->allocator->Malloc(lsh_param_->key_size_));
|
||||
if (key == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc key failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
memcpy(key, &seed, para->seed_size_);
|
||||
memcpy(key + para->seed_size_, in_data, para->in_item_size_);
|
||||
in_data += para->in_item_size_;
|
||||
double hash_sign = static_cast<double>(mindspore::lite::StringHash64(key, para->key_size_));
|
||||
if (weight == nullptr) {
|
||||
score += hash_sign;
|
||||
} else {
|
||||
score += weight[i] * hash_sign;
|
||||
}
|
||||
ctx_->allocator->Free(key);
|
||||
}
|
||||
return (score > 0) ? 1 : 0;
|
||||
}
|
||||
|
||||
void LshProjectionCPUKernel::LshProjectionSparse(float *hash, char *in_data, float *weight, int32_t *output,
|
||||
LshProjectionParameter *para) {
|
||||
int start = para->task_id_ * para->count_unit_;
|
||||
int end = start + para->real_dst_count;
|
||||
for (int i = start; i < end; i++) {
|
||||
int32_t hash_sign = 0;
|
||||
for (int j = 0; j < para->hash_shape_[1]; j++) {
|
||||
int bit = GetSignBit(in_data, weight, hash[i * para->hash_shape_[1] + j], para);
|
||||
hash_sign = (hash_sign << 1) | bit;
|
||||
}
|
||||
output[i] = hash_sign + i * (1 << para->hash_shape_[1]);
|
||||
}
|
||||
}
|
||||
|
||||
void LshProjectionCPUKernel::LshProjectionDense(float *hash, char *in_data, float *weight, int32_t *output,
|
||||
LshProjectionParameter *para) {
|
||||
int start = para->task_id_ * para->count_unit_;
|
||||
int end = start + para->real_dst_count;
|
||||
for (int i = start; i < end; i++) {
|
||||
for (int j = 0; j < para->hash_shape_[1]; j++) {
|
||||
output[i * para->hash_shape_[1] + j] = GetSignBit(in_data, weight, hash[i * para->hash_shape_[1] + j], para);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
kernel::LiteKernel *CpuLshProjectionFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
|
||||
const std::vector<lite::Tensor *> &outputs,
|
||||
OpParameter *op_parameter, const lite::InnerContext *ctx,
|
||||
const kernel::KernelKey &desc,
|
||||
const mindspore::lite::PrimitiveC *primitive) {
|
||||
if (op_parameter == nullptr) {
|
||||
MS_LOG(ERROR) << "Input op_parameter is nullptr!";
|
||||
return nullptr;
|
||||
}
|
||||
if (ctx == nullptr) {
|
||||
MS_LOG(ERROR) << "Input context is nullptr!";
|
||||
return nullptr;
|
||||
}
|
||||
MS_ASSERT(desc.type == schema::PrimitiveType_LshProjection);
|
||||
auto *kernel = new (std::nothrow) LshProjectionCPUKernel(op_parameter, inputs, outputs, ctx, primitive);
|
||||
if (kernel == nullptr) {
|
||||
MS_LOG(ERROR) << "new LshProjectionCPUKernel fail!";
|
||||
return nullptr;
|
||||
}
|
||||
auto ret = kernel->Init();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Init kernel failed! name: " << op_parameter->name_ << ", type: "
|
||||
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(op_parameter->type_));
|
||||
delete kernel;
|
||||
return nullptr;
|
||||
}
|
||||
return kernel;
|
||||
}
|
||||
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LshProjection, CpuLshProjectionFp32KernelCreator)
|
||||
|
||||
} // namespace mindspore::kernel
|
|
@ -0,0 +1,61 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_LSH_PROJECTION_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_LSH_PROJECTION_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "nnacl/lsh_projection_parameter.h"
|
||||
#include "src/lite_kernel.h"
|
||||
#include "schema/model_generated.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
class LshProjectionCPUKernel : public LiteKernel {
|
||||
public:
|
||||
LshProjectionCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
|
||||
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
|
||||
const mindspore::lite::PrimitiveC *primitive)
|
||||
: LiteKernel(parameter, inputs, outputs, ctx, primitive), thread_num_(ctx->thread_num_) {
|
||||
lsh_param_ = reinterpret_cast<LshProjectionParameter *>(op_parameter_);
|
||||
}
|
||||
~LshProjectionCPUKernel() = default;
|
||||
|
||||
int Init() override;
|
||||
int ReSize() override;
|
||||
int Run() override;
|
||||
int DoExecute(int task_id);
|
||||
int GetSignBit(char *in_data, float *weight, float seed, LshProjectionParameter *para);
|
||||
void LshProjectionSparse(float *hash, char *in_data, float *weight, int32_t *output, LshProjectionParameter *param);
|
||||
void LshProjectionDense(float *hash, char *in_data, float *weight, int32_t *output, LshProjectionParameter *param);
|
||||
|
||||
private:
|
||||
LshProjectionParameter *lsh_param_ = nullptr;
|
||||
const lite::InnerContext *ctx_;
|
||||
int thread_num_;
|
||||
int64_t elements_num_;
|
||||
int64_t count_unit_;
|
||||
float *hash = nullptr;
|
||||
char *in_data = nullptr;
|
||||
float *weight = nullptr;
|
||||
int32_t *output = nullptr;
|
||||
};
|
||||
|
||||
int LshProjectionRun(void *cdata, int task_id);
|
||||
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_LSH_PROJECTION_H_
|
|
@ -0,0 +1,164 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
#include "schema/inner/model_generated.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "common/common_test.h"
|
||||
#include "mindspore/lite/nnacl/lsh_projection_parameter.h"
|
||||
#include "mindspore/lite/src/kernel_registry.h"
|
||||
#include "mindspore/lite/src/lite_kernel.h"
|
||||
#include "mindspore/lite/src/tensor.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
||||
namespace {
|
||||
constexpr int kSparseType = 1;
|
||||
constexpr int kDenseType = 2;
|
||||
} // namespace
|
||||
|
||||
class TestLshProjectionFp32 : public mindspore::CommonTest {
|
||||
public:
|
||||
TestLshProjectionFp32() {}
|
||||
};
|
||||
|
||||
TEST_F(TestLshProjectionFp32, Dense1DInputs) {
|
||||
lite::Tensor in_tensor0(kNumberTypeFloat, {3, 2});
|
||||
lite::Tensor in_tensor1(kNumberTypeInt32, {5});
|
||||
lite::Tensor in_tensor2(kNumberTypeFloat, {5});
|
||||
lite::Tensor out_tensor(kNumberTypeInt32, {6});
|
||||
|
||||
float input_data0[] = {0.123, 0.456, -0.321, 1.234, 5.678, -4.321};
|
||||
int32_t input_data1[] = {12345, 54321, 67890, 9876, -12345678};
|
||||
float input_data2[] = {1.0, 1.0, 1.0, 1.0, 1.0};
|
||||
int32_t output_data[6] = {0};
|
||||
in_tensor0.SetData(input_data0);
|
||||
in_tensor1.SetData(input_data1);
|
||||
in_tensor2.SetData(input_data2);
|
||||
out_tensor.SetData(output_data);
|
||||
|
||||
std::vector<lite::Tensor *> inputs = {&in_tensor0, &in_tensor1, &in_tensor2};
|
||||
std::vector<lite::Tensor *> outputs = {&out_tensor};
|
||||
|
||||
LshProjectionParameter parameter = {};
|
||||
parameter.lsh_type_ = kDenseType;
|
||||
kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_LshProjection};
|
||||
|
||||
auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
|
||||
ASSERT_NE(creator, nullptr);
|
||||
|
||||
auto ctx = std::make_shared<lite::InnerContext>();
|
||||
ctx->thread_num_ = 3;
|
||||
ASSERT_EQ(lite::RET_OK, ctx->Init());
|
||||
auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(¶meter), ctx.get(), desc, nullptr);
|
||||
ASSERT_NE(kernel, nullptr);
|
||||
|
||||
auto ret = kernel->Run();
|
||||
EXPECT_EQ(0, ret);
|
||||
|
||||
std::vector<int32_t> except_result = {0, 0, 0, 1, 0, 0};
|
||||
PrintData("output data", output_data, 6);
|
||||
CompareOutputData(output_data, except_result.data(), 6, 0.000001);
|
||||
|
||||
in_tensor0.SetData(nullptr);
|
||||
in_tensor1.SetData(nullptr);
|
||||
out_tensor.SetData(nullptr);
|
||||
}
|
||||
|
||||
TEST_F(TestLshProjectionFp32, Sparse1DInputs) {
|
||||
lite::Tensor in_tensor0(kNumberTypeFloat, {3, 2});
|
||||
lite::Tensor in_tensor1(kNumberTypeInt32, {5});
|
||||
lite::Tensor out_tensor(kNumberTypeInt32, {3});
|
||||
|
||||
float input_data0[] = {0.123, 0.456, -0.321, 1.234, 5.678, -4.321};
|
||||
int32_t input_data1[] = {12345, 54321, 67890, 9876, -12345678};
|
||||
int32_t output_data[3] = {0};
|
||||
in_tensor0.SetData(input_data0);
|
||||
in_tensor1.SetData(input_data1);
|
||||
out_tensor.SetData(output_data);
|
||||
|
||||
std::vector<lite::Tensor *> inputs = {&in_tensor0, &in_tensor1};
|
||||
std::vector<lite::Tensor *> outputs = {&out_tensor};
|
||||
|
||||
LshProjectionParameter parameter = {};
|
||||
parameter.lsh_type_ = kSparseType;
|
||||
kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_LshProjection};
|
||||
|
||||
auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
|
||||
ASSERT_NE(creator, nullptr);
|
||||
|
||||
auto ctx = std::make_shared<lite::InnerContext>();
|
||||
ctx->thread_num_ = 1;
|
||||
ASSERT_EQ(lite::RET_OK, ctx->Init());
|
||||
auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(¶meter), ctx.get(), desc, nullptr);
|
||||
ASSERT_NE(kernel, nullptr);
|
||||
|
||||
auto ret = kernel->Run();
|
||||
EXPECT_EQ(0, ret);
|
||||
|
||||
std::vector<int32_t> except_result = {0, 5, 8};
|
||||
PrintData("output data", output_data, 3);
|
||||
CompareOutputData(output_data, except_result.data(), 3, 0.000001);
|
||||
|
||||
in_tensor0.SetData(nullptr);
|
||||
in_tensor1.SetData(nullptr);
|
||||
out_tensor.SetData(nullptr);
|
||||
}
|
||||
|
||||
TEST_F(TestLshProjectionFp32, Sparse3DInputs) {
|
||||
lite::Tensor in_tensor0(kNumberTypeFloat, {3, 2});
|
||||
lite::Tensor in_tensor1(kNumberTypeInt32, {5, 2, 2});
|
||||
lite::Tensor in_tensor2(kNumberTypeFloat, {5});
|
||||
lite::Tensor out_tensor(kNumberTypeInt32, {3});
|
||||
|
||||
float input_data0[] = {0.123, 0.456, -0.321, 1.234, 5.678, -4.321};
|
||||
int32_t input_data1[] = {1234, 2345, 3456, 1234, 4567, 5678, 6789, 4567, 7891, 8912,
|
||||
9123, 7890, -987, -876, -765, -987, -543, -432, -321, -543};
|
||||
float input_data2[] = {0.12, 0.34, 0.56, 0.67, 0.78};
|
||||
int32_t output_data[3] = {0};
|
||||
in_tensor0.SetData(input_data0);
|
||||
in_tensor1.SetData(input_data1);
|
||||
in_tensor2.SetData(input_data2);
|
||||
out_tensor.SetData(output_data);
|
||||
|
||||
std::vector<lite::Tensor *> inputs = {&in_tensor0, &in_tensor1, &in_tensor2};
|
||||
std::vector<lite::Tensor *> outputs = {&out_tensor};
|
||||
|
||||
LshProjectionParameter parameter = {};
|
||||
parameter.lsh_type_ = kSparseType;
|
||||
kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_LshProjection};
|
||||
|
||||
auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
|
||||
ASSERT_NE(creator, nullptr);
|
||||
|
||||
auto ctx = std::make_shared<lite::InnerContext>();
|
||||
ctx->thread_num_ = 3;
|
||||
ASSERT_EQ(lite::RET_OK, ctx->Init());
|
||||
auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(¶meter), ctx.get(), desc, nullptr);
|
||||
ASSERT_NE(kernel, nullptr);
|
||||
|
||||
auto ret = kernel->Run();
|
||||
EXPECT_EQ(0, ret);
|
||||
|
||||
std::vector<int32_t> except_result = {2, 5, 9};
|
||||
PrintData("output data", output_data, 3);
|
||||
CompareOutputData(output_data, except_result.data(), 3, 0.000001);
|
||||
|
||||
in_tensor0.SetData(nullptr);
|
||||
in_tensor1.SetData(nullptr);
|
||||
out_tensor.SetData(nullptr);
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -290,7 +290,7 @@ TEST_F(TestMulInt8, Mul_quant1_thread1) {
|
|||
MulParameter op_param;
|
||||
op_param.op_parameter_.type_ = schema::PrimitiveType_Mul;
|
||||
lite::InnerContext *ctx = new lite::InnerContext;
|
||||
ctx->thread_num_ = 2;
|
||||
ctx->thread_num_ = 3;
|
||||
ASSERT_EQ(lite::RET_OK, ctx->Init());
|
||||
kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_Mul};
|
||||
auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
|
||||
|
|
Loading…
Reference in New Issue