add string_util & cpu op skip_gram

tao_yunhao 2020-09-29 16:27:56 +08:00
parent 11a04a889b
commit caf25aa4d3
13 changed files with 589 additions and 0 deletions

View File: nnacl/fp32/skip_gram.h

@@ -0,0 +1,29 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_FP32_SKIP_GRAM_H_
#define MINDSPORE_LITE_NNACL_FP32_SKIP_GRAM_H_
#include "nnacl/op_base.h"
typedef struct {
OpParameter op_parameter_;
int ngram_size;
int max_skip_size;
bool include_all_ngrams;
} SkipGramParameter;
#endif // MINDSPORE_LITE_NNACL_FP32_SKIP_GRAM_H_
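For orientation, a hedged sketch of what these fields control (the values and the standalone snippet are illustrative, not part of this commit): with ngram_size = 2 and max_skip_size = 1, the sentence "a b c" yields the skip-grams "a b", "a c" and "b c"; include_all_ngrams = true additionally emits every shorter gram, here the 1-grams "a", "b" and "c".

  // Illustrative only: how a caller might fill in the struct.
  SkipGramParameter param;
  param.ngram_size = 2;             // longest gram to emit
  param.max_skip_size = 1;          // at most one word skipped between consecutive chosen words
  param.include_all_ngrams = true;  // also emit every gram shorter than ngram_size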

View File: src/CMakeLists.txt

@@ -19,6 +19,7 @@ endif ()
set(LITE_SRC
${CMAKE_CURRENT_SOURCE_DIR}/common/graph_util.cc
${CMAKE_CURRENT_SOURCE_DIR}/common/log_adapter.cc
${CMAKE_CURRENT_SOURCE_DIR}/common/string_util.cc
${CMAKE_CURRENT_SOURCE_DIR}/runtime/allocator.cc
${CMAKE_CURRENT_SOURCE_DIR}/runtime/runtime_api.cc
${CMAKE_CURRENT_SOURCE_DIR}/runtime/thread_pool.c

View File: src/common/string_util.cc

@@ -0,0 +1,106 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/common/string_util.h"
namespace mindspore {
namespace lite {
std::vector<StringPack> ParseTensorBuffer(Tensor *tensor) {
if (tensor->MutableData() == nullptr) {
MS_LOG(ERROR) << "Tensor data is null, cannot be parsed";
return std::vector<StringPack>{};
}
return ParseStringBuffer(tensor->MutableData());
}

std::vector<StringPack> ParseStringBuffer(const void *data) {
const int32_t *offset = reinterpret_cast<const int32_t *>(data);
int32_t num = *offset;
std::vector<StringPack> buffer;
for (int i = 0; i < num; i++) {
offset += 1;
buffer.push_back(StringPack{(*(offset + 1)) - (*offset), reinterpret_cast<const char *>(data) + (*offset)});
}
return buffer;
}

int WriteStringsToTensor(Tensor *tensor, const std::vector<StringPack> &string_buffer) {
int32_t num = string_buffer.size();
std::vector<int32_t> offset(num + 1);
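  // Header: one int32 count followed by (num + 1) int32 offsets, 4 bytes each.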
offset[0] = 4 * (num + 2);
for (int i = 0; i < num; i++) {
offset[i + 1] = offset[i] + string_buffer[i].len;
}
std::vector<int> shape = {offset[num]};
tensor->set_shape(shape);
void *data = tensor->MutableData();
if (data == nullptr) {
return RET_ERROR;
}
auto *string_info = reinterpret_cast<int32_t *>(data);
auto *string_data = reinterpret_cast<char *>(data);
string_info[0] = num;
for (int i = 0; i <= num; i++) {
string_info[i + 1] = offset[i];
}
for (int i = 0; i < num; i++) {
memcpy(string_data + offset[i], string_buffer[i].data, string_buffer[i].len);
}
return RET_OK;
}

int WriteSeperatedStringsToTensor(Tensor *tensor, const std::vector<std::vector<StringPack>> &string_buffer) {
int32_t num = string_buffer.size();
std::vector<int32_t> offset(num + 1);
offset[0] = 4 * (num + 2);
std::vector<int> len(num);
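  // len[i]: total byte length of the i-th output string once its parts are concatenated.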
for (int i = 0; i < num; i++) {
len[i] = 0;
for (int j = 0; j < static_cast<int>(string_buffer[i].size()); j++) {
len[i] += string_buffer[i][j].len;
}
offset[i + 1] = offset[i] + len[i];
}
std::vector<int> shape = {offset[num]};
tensor->set_shape(shape);
void *data = tensor->MutableData();
if (data == nullptr) {
return RET_ERROR;
}
auto *string_info = reinterpret_cast<int32_t *>(data);
auto *string_data = reinterpret_cast<char *>(data);
string_info[0] = num;
for (int i = 0; i <= num; i++) {
string_info[i + 1] = offset[i];
}
for (int i = 0; i < num; i++) {
auto *dst = string_data + offset[i];
for (auto string_part : string_buffer[i]) {
memcpy(dst, string_part.data, string_part.len);
dst += string_part.len;
}
}
return RET_OK;
}
} // namespace lite
} // namespace mindspore
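The packed layout both writers emit and ParseStringBuffer walks is: an int32 string count, then count + 1 int32 byte offsets (the last marking the end of the buffer), then the concatenated string bytes. A minimal sketch, with the header values for the two strings "hi" and "you" computed by hand (the helper name is illustrative, not part of this commit):

  #include <cstdint>
  #include <cstring>

  // Illustrative only: hand-pack "hi" and "you" the way WriteStringsToTensor would.
  void BuildExampleBuffer(char *buf) {  // buf must hold at least 21 bytes
    const int32_t header[4] = {2, 16, 18, 21};  // count, offset of "hi", offset of "you", end
    memcpy(buf, header, sizeof(header));        // string bytes start at 4 * (2 + 2) = 16
    memcpy(buf + 16, "hi", 2);
    memcpy(buf + 18, "you", 3);
    // ParseStringBuffer(buf) now returns {{2, buf + 16}, {3, buf + 18}}.
  }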

View File: src/common/string_util.h

@@ -0,0 +1,44 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_COMMON_STRING_UTIL_H_
#define MINDSPORE_LITE_COMMON_STRING_UTIL_H_
#include <vector>
#include <string>
#include <utility>
#include "mindspore/lite/src/tensor.h"
#include "src/common/log_adapter.h"
#include "tools/common/option.h"
#include "include/errorcode.h"
namespace mindspore {
namespace lite {
typedef struct {
int len;
const char *data;
} StringPack;
std::vector<StringPack> ParseTensorBuffer(Tensor *tensor);
std::vector<StringPack> ParseStringBuffer(const void *data);
int WriteStringsToTensor(Tensor *tensor, const std::vector<StringPack> &string_buffer);
int WriteSeperatedStringsToTensor(Tensor *tensor, const std::vector<std::vector<StringPack>> &string_buffer);
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_COMMON_STRING_UTIL_H_
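A round-trip usage sketch (mirroring the unit test added later in this commit; the tensor constructor arguments are taken from that test):

  lite::Tensor t(kObjectTypeString, {}, schema::Format_NHWC, lite::Tensor::Category::CONST);
  std::vector<lite::StringPack> in = {{5, "hello"}, {5, "world"}};
  lite::WriteStringsToTensor(&t, in);
  std::vector<lite::StringPack> out = lite::ParseTensorBuffer(&t);  // out[0].len == 5, out[1].len == 5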

View File: src/ops/skip_gram.cc

@@ -0,0 +1,74 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/skip_gram.h"
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int SkipGram::GetNgramSize() const { return this->primitive_->value.AsSkipGram()->ngramSize; }
int SkipGram::GetMaxSkipSize() const { return this->primitive_->value.AsSkipGram()->maxSkipSize; }
bool SkipGram::GetIncludeAllNgrams() const { return this->primitive_->value.AsSkipGram()->includeAllGrams; }
void SkipGram::SetNgramSize(int ngram_size) { this->primitive_->value.AsSkipGram()->ngramSize = ngram_size; }
void SkipGram::SetMaxSkipSize(int max_skip_size) { this->primitive_->value.AsSkipGram()->maxSkipSize = max_skip_size; }
void SkipGram::SetIncludeAllNgrams(bool include_all_ngrams) {
this->primitive_->value.AsSkipGram()->includeAllGrams = include_all_ngrams;
}
#else
int SkipGram::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_SkipGram();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_SkipGram return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateSkipGram(*fbb, attr->ngramSize(), attr->maxSkipSize(), attr->includeAllGrams());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_SkipGram, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
int SkipGram::GetNgramSize() const { return this->primitive_->value_as_SkipGram()->ngramSize(); }
int SkipGram::GetMaxSkipSize() const { return this->primitive_->value_as_SkipGram()->maxSkipSize(); }
bool SkipGram::GetIncludeAllNgrams() const { return this->primitive_->value_as_SkipGram()->includeAllGrams(); }
#endif

int SkipGram::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
MS_ASSERT(this->primitive_ != nullptr);
if (inputs_.size() != kSingleNum) {
MS_LOG(ERROR) << "Skip Gram should have one input";
return RET_INPUT_TENSOR_ERROR;
}
if (outputs_.size() != kSingleNum) {
MS_LOG(ERROR) << "Skip Gram should have one outputs";
return RET_INPUT_TENSOR_ERROR;
}
auto input = inputs_.front();
auto output = outputs_.front();
MS_ASSERT(input != nullptr);
output->SetFormat(input->GetFormat());
output->set_data_type(input->data_type());
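  // The number and total byte length of the output strings depend on the input
  // content, so the final shape can only be resolved at run time.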
return RET_INFER_INVALID;
}
} // namespace lite
} // namespace mindspore

View File: src/ops/skip_gram.h

@@ -0,0 +1,50 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_C_OPS_SKIP_GRAM_H_
#define LITE_MINDSPORE_LITE_C_OPS_SKIP_GRAM_H_
#include <vector>
#include <set>
#include <cmath>
#include "src/ops/primitive_c.h"
namespace mindspore {
namespace lite {
class SkipGram : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(SkipGram, PrimitiveC);
SkipGram() = default;
explicit SkipGram(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetNgramSize(int ngram_size);
void SetMaxSkipSize(int max_skip_size);
void SetIncludeAllNgrams(bool include_all_ngrams);
#else
SkipGram() = default;
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
int GetNgramSize() const;
int GetMaxSkipSize() const;
bool GetIncludeAllNgrams() const;
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_SKIP_GRAM_H_

View File

@@ -115,6 +115,7 @@
#include "src/ops/l2_norm.h"
#include "src/ops/neg.h"
#include "src/ops/detection_post_process.h"
#include "src/ops/skip_gram.h"
#include "nnacl/op_base.h"
#include "nnacl/fp32/arg_min_max.h"
#include "nnacl/fp32/cast.h"
@@ -174,6 +175,7 @@
#include "nnacl/l2_norm_parameter.h"
#include "nnacl/detection_post_process_parameter.h"
#include "nnacl/fp32/exp.h"
#include "nnacl/fp32/skip_gram.h"
namespace mindspore::kernel {
@@ -1586,6 +1588,21 @@ OpParameter *PopulateExpParameter(const mindspore::lite::PrimitiveC *primitive)
return reinterpret_cast<OpParameter *>(exp_parameter);
}

OpParameter *PopulateSkipGramParameter(const mindspore::lite::PrimitiveC *primitive) {
SkipGramParameter *skipGramParameter = reinterpret_cast<SkipGramParameter *>(malloc(sizeof(SkipGramParameter)));
if (skipGramParameter == nullptr) {
MS_LOG(ERROR) << "malloc SkipGramParameter failed.";
return nullptr;
}
memset(skipGramParameter, 0, sizeof(SkipGramParameter));
skipGramParameter->op_parameter_.type_ = primitive->Type();
auto param = reinterpret_cast<mindspore::lite::SkipGram *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
skipGramParameter->ngram_size = param->GetNgramSize();
skipGramParameter->max_skip_size = param->GetMaxSkipSize();
skipGramParameter->include_all_ngrams = param->GetIncludeAllNgrams();
return reinterpret_cast<OpParameter *>(skipGramParameter);
}
PopulateParameterRegistry::PopulateParameterRegistry() {
populate_parameter_funcs_[schema::PrimitiveType_SparseToDense] = PopulateSparseToDenseParameter;
populate_parameter_funcs_[schema::PrimitiveType_SoftMax] = PopulateSoftmaxParameter;
@@ -1688,6 +1705,7 @@ PopulateParameterRegistry::PopulateParameterRegistry() {
populate_parameter_funcs_[schema::PrimitiveType_Elu] = PopulateEluParameter;
populate_parameter_funcs_[schema::PrimitiveType_L2Norm] = PopulateL2NormParameter;
populate_parameter_funcs_[schema::PrimitiveType_DetectionPostProcess] = PopulateDetectionPostProcessParameter;
populate_parameter_funcs_[schema::PrimitiveType_SkipGram] = PopulateSkipGramParameter;
}
PopulateParameterRegistry *PopulateParameterRegistry::GetInstance() {

View File: src/runtime/kernel/arm/fp32/skip_gram.cc

@@ -0,0 +1,134 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/fp32/skip_gram.h"
#include "include/errorcode.h"
#include "src/kernel_registry.h"
#include "src/runtime/runtime_api.h"
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::lite::StringPack;
using mindspore::schema::PrimitiveType_SkipGram;
namespace mindspore::kernel {

int SkipGramCPUKernel::Init() {
if (!InferShapeDone()) {
return RET_OK;
}
return ReSize();
}

int SkipGramCPUKernel::ReSize() { return RET_OK; }
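
// Tokenizes on spaces: maximal runs of non-space bytes become words, so leading,
// trailing and repeated spaces produce no empty tokens,
// e.g. "  The quick  brown " -> {"The", "quick", "brown"}.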
void ParseSentenceToWords(const StringPack &sentence, std::vector<StringPack> *words) {
int pre = 0;
int i;
for (i = 0; i < sentence.len; i++) {
if (sentence.data[i] != ' ') {
pre = i;
break;
}
}
for (; i < sentence.len; i++) {
if (sentence.data[i] == ' ') {
if (sentence.data[pre] != ' ') {
words->push_back({i - pre, sentence.data + pre});
}
pre = i + 1;
}
}
  if (sentence.len > 0 && sentence.data[sentence.len - 1] != ' ') {
words->push_back({sentence.len - pre, sentence.data + pre});
}
}

int SkipGramCPUKernel::Run() {
skip_gram_parameter_ = reinterpret_cast<SkipGramParameter *>(op_parameter_);
if (skip_gram_parameter_->ngram_size < 1) {
MS_LOG(ERROR) << "Skip Gram Parameter Error, NgramSize should be at least 1, get "
<< skip_gram_parameter_->ngram_size;
return RET_ERROR;
}
  auto string_buffer = mindspore::lite::ParseTensorBuffer(in_tensors_[0]);
  if (string_buffer.empty()) {
    MS_LOG(ERROR) << "Skip Gram input tensor contains no string";
    return RET_ERROR;
  }
  StringPack sentence = string_buffer.at(0);
std::vector<StringPack> words;
ParseSentenceToWords(sentence, &words);
std::vector<std::vector<StringPack>> result;
std::vector<int> stack(skip_gram_parameter_->ngram_size, 0);
int index = 1;
int size = words.size();
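  // Depth-first enumeration: stack[0 .. index - 1] holds the word indices of the
  // gram under construction; consecutive indices may differ by at most
  // max_skip_size + 1, i.e. at most max_skip_size words are skipped between them.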
while (index >= 0) {
if (index < skip_gram_parameter_->ngram_size && stack[index] + 1 < size &&
(index == 0 || stack[index] - stack[index - 1] <= skip_gram_parameter_->max_skip_size)) {
stack[index]++;
index++;
if (index < skip_gram_parameter_->ngram_size) {
stack[index] = stack[index - 1];
}
} else {
if (index > 0 && ((skip_gram_parameter_->include_all_ngrams && index <= skip_gram_parameter_->ngram_size) ||
(!skip_gram_parameter_->include_all_ngrams && index == skip_gram_parameter_->ngram_size))) {
std::vector<StringPack> gram(2 * index - 1);
char blank[1] = {' '};
StringPack blank_str = {1, blank};
for (int i = 0; i < 2 * index - 2; i += 2) {
gram[i] = words[stack[i / 2]];
gram[i + 1] = blank_str;
}
gram[2 * index - 2] = words[stack[index - 1]];
result.push_back(gram);
}
index--;
}
}
int ret = mindspore::lite::WriteSeperatedStringsToTensor(out_tensors_[0], result);
return ret;
}

kernel::LiteKernel *CpuSkipGramFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, OpParameter *parameter,
const lite::InnerContext *ctx, const KernelKey &desc,
const mindspore::lite::PrimitiveC *primitive) {
if (parameter == nullptr || ctx == nullptr) {
MS_LOG(ERROR) << "parameter or ctx is nullptr";
return nullptr;
}
MS_ASSERT(desc.type == PrimitiveType_SkipGram);
auto *kernel = new (std::nothrow) SkipGramCPUKernel(parameter, inputs, outputs, ctx, primitive);
if (kernel == nullptr) {
MS_LOG(ERROR) << "Create Kernel failed, name: " << parameter->name_;
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init Kernel failed, name: " << parameter->name_
<< ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(parameter->type_));
delete kernel;
return nullptr;
}
return kernel;
}
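
// REG_KERNEL keys the creator by (architecture, data type, primitive type); the
// runtime resolves the kernel through KernelRegistry when building the graph.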
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SkipGram, CpuSkipGramFp32KernelCreator)
} // namespace mindspore::kernel

View File: src/runtime/kernel/arm/fp32/skip_gram.h

@@ -0,0 +1,48 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SKIP_GRAM_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SKIP_GRAM_H_
#include <vector>
#include "src/lite_kernel.h"
#include "nnacl/fp32/skip_gram.h"
#include "src/common/string_util.h"
namespace mindspore::kernel {
class SkipGramCPUKernel : public LiteKernel {
public:
explicit SkipGramCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive), ctx_(ctx), thread_count_(ctx->thread_num_) {}
~SkipGramCPUKernel() override = default;
int Init() override;
int ReSize() override;
int Run() override;
int DoExcute(int task_id);
protected:
const lite::InnerContext *ctx_;
int thread_count_;
  SkipGramParameter *skip_gram_parameter_ = nullptr;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SKIP_GRAM_H_

View File: src/tensor.h

@@ -132,6 +132,9 @@ class Tensor : public mindspore::tensor::MSTensor {
case kNumberTypeBool:
size = sizeof(bool);
break;
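      // string tensors store packed bytes (see src/common/string_util.cc), so
      // the element size is a single char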
case kObjectTypeString:
size = sizeof(char);
break;
default:
MS_LOG(ERROR) << "Not support the type: " << this->data_type_;
return 0;

View File

@@ -124,6 +124,7 @@ set(TEST_LITE_SRC
${LITE_DIR}/src/common/file_utils.cc
${LITE_DIR}/src/common/file_utils_ext.cc
${LITE_DIR}/src/common/utils.cc
${LITE_DIR}/src/common/string_util.cc
${LITE_DIR}/tools/common/graph_util.cc
${LITE_DIR}/tools/common/tensor_util.cc
${LITE_DIR}/tools/common/node_util.cc

View File

@@ -0,0 +1,78 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include "src/runtime/kernel/arm/fp32/skip_gram.h"
#include "nnacl/fp32/skip_gram.h"
#include "src/common/file_utils.h"
#include "common/common_test.h"
#include "src/common/log_adapter.h"
#include "src/common/string_util.h"
namespace mindspore {
using mindspore::lite::StringPack;
using mindspore::lite::Tensor;
class TestSkipGramFp32 : public mindspore::CommonTest {
public:
TestSkipGramFp32() {}
};
void SkipGramTestInit(std::vector<Tensor *> *inputs_, std::vector<Tensor *> *outputs_,
SkipGramParameter *skip_gram_param) {
Tensor *in_t_first = new Tensor(kObjectTypeString, {}, schema::Format_NHWC, lite::Tensor::Category::CONST);
char sentence[] = "The quick brown fox jumps over the lazy dog";
std::vector<StringPack> str;
str.push_back({43, sentence});
mindspore::lite::WriteStringsToTensor(in_t_first, str);
inputs_->push_back(in_t_first);
Tensor *output = new Tensor(kObjectTypeString, {}, schema::Format_NHWC, lite::Tensor::Category::CONST);
outputs_->push_back(output);
skip_gram_param->ngram_size = 3;
skip_gram_param->max_skip_size = 2;
skip_gram_param->include_all_ngrams = true;
skip_gram_param->op_parameter_.type_ = mindspore::schema::PrimitiveType_SkipGram;
skip_gram_param->op_parameter_.thread_num_ = 2;
}

TEST_F(TestSkipGramFp32, ElTest) {
std::vector<Tensor *> inputs_;
std::vector<Tensor *> outputs_;
auto skip_gram_param_ = new SkipGramParameter();
SkipGramTestInit(&inputs_, &outputs_, skip_gram_param_);
lite::InnerContext *ctx = new lite::InnerContext;
ctx->thread_num_ = 2;
ASSERT_EQ(lite::RET_OK, ctx->Init());
kernel::SkipGramCPUKernel *el =
new kernel::SkipGramCPUKernel(reinterpret_cast<OpParameter *>(skip_gram_param_), inputs_, outputs_, ctx, nullptr);
  ASSERT_EQ(lite::RET_OK, el->Init());
  ASSERT_EQ(lite::RET_OK, el->Run());
std::vector<StringPack> output = mindspore::lite::ParseTensorBuffer(outputs_[0]);
  for (size_t i = 0; i < output.size(); i++) {
for (int j = 0; j < output[i].len; j++) {
printf("%c", output[i].data[j]);
}
printf("\n");
}
}
} // namespace mindspore
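For reference, with ngram_size = 3, max_skip_size = 2 and include_all_ngrams = true, the printed grams are every 1-, 2- and 3-gram whose consecutive chosen words skip at most two intermediate words; "The", "The quick", "The fox" and "The quick brown" are all among them. A hedged sketch of a content assertion the test could add (the Contains helper is illustrative, not part of this commit):

  #include <string>

  // Illustrative only: check that one expected gram appears in the parsed output.
  static bool Contains(const std::vector<StringPack> &out, const std::string &s) {
    for (const auto &p : out) {
      if (std::string(p.data, p.len) == s) return true;
    }
    return false;
  }
  // Inside the test body, after ParseTensorBuffer:
  //   EXPECT_TRUE(Contains(output, "The quick brown"));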

View File

@@ -17,6 +17,7 @@ if (WIN32)
${CMAKE_CURRENT_SOURCE_DIR}/../../src/inner_context.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../src/kernel_registry.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/graph_util.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/string_util.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../src/runtime/runtime_api.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../src/runtime/thread_pool.c
${CMAKE_CURRENT_SOURCE_DIR}/../../src/runtime/workspace_pool.cc
@@ -42,6 +43,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/../common/graph_util.cc
${CMAKE_CURRENT_SOURCE_DIR}/../common/node_util.cc
${CMAKE_CURRENT_SOURCE_DIR}/../common/tensor_util.cc
${CMAKE_CURRENT_SOURCE_DIR}/../common/string_util.cc
${CMAKE_CURRENT_SOURCE_DIR}/../common/protobuf_utils.cc
${CMAKE_CURRENT_SOURCE_DIR}/../common/flag_parser.cc
${CMAKE_CURRENT_SOURCE_DIR}/../common/storage.cc
@@ -75,6 +77,7 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/../../../core mindspore_core)
set(SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../src)
set(LITE_SRC
${SRC_DIR}/common/graph_util.cc
${SRC_DIR}/common/string_util.cc
${SRC_DIR}/runtime/allocator.cc
${SRC_DIR}/runtime/runtime_api.cc
${SRC_DIR}/runtime/thread_pool.c