forked from mindspore-Ecosystem/mindspore
add string_util & cpu op skip_gram
This commit is contained in:
parent
11a04a889b
commit
caf25aa4d3
|
@ -0,0 +1,29 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_NNACL_FP32_SKIP_GRAM_H_
|
||||
#define MINDSPORE_LITE_NNACL_FP32_SKIP_GRAM_H_
|
||||
|
||||
#include "nnacl/op_base.h"
|
||||
|
||||
typedef struct {
|
||||
OpParameter op_parameter_;
|
||||
int ngram_size;
|
||||
int max_skip_size;
|
||||
bool include_all_ngrams;
|
||||
} SkipGramParameter;
|
||||
|
||||
#endif // MINDSPORE_LITE_NNACL_FP32_SKIP_GRAM_H_
|
|
@ -19,6 +19,7 @@ endif ()
|
|||
set(LITE_SRC
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/common/graph_util.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/common/log_adapter.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/common/string_util.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/runtime/allocator.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/runtime/runtime_api.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/runtime/thread_pool.c
|
||||
|
|
|
@ -0,0 +1,106 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/common/string_util.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
||||
std::vector<StringPack> ParseTensorBuffer(Tensor *tensor) {
|
||||
if (tensor->MutableData() == nullptr) {
|
||||
MS_LOG(ERROR) << "Tensor data is null, cannot be parsed";
|
||||
return std::vector<StringPack>{};
|
||||
}
|
||||
return ParseStringBuffer(tensor->MutableData());
|
||||
}
|
||||
|
||||
std::vector<StringPack> ParseStringBuffer(const void *data) {
|
||||
const int32_t *offset = reinterpret_cast<const int32_t *>(data);
|
||||
int32_t num = *offset;
|
||||
std::vector<StringPack> buffer;
|
||||
for (int i = 0; i < num; i++) {
|
||||
offset += 1;
|
||||
buffer.push_back(StringPack{(*(offset + 1)) - (*offset), reinterpret_cast<const char *>(data) + (*offset)});
|
||||
}
|
||||
return buffer;
|
||||
}
|
||||
|
||||
int WriteStringsToTensor(Tensor *tensor, const std::vector<StringPack> &string_buffer) {
|
||||
int32_t num = string_buffer.size();
|
||||
std::vector<int32_t> offset(num + 1);
|
||||
offset[0] = 4 * (num + 2);
|
||||
for (int i = 0; i < num; i++) {
|
||||
offset[i + 1] = offset[i] + string_buffer[i].len;
|
||||
}
|
||||
std::vector<int> shape = {offset[num]};
|
||||
tensor->set_shape(shape);
|
||||
void *data = tensor->MutableData();
|
||||
if (data == nullptr) {
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
auto *string_info = reinterpret_cast<int32_t *>(data);
|
||||
auto *string_data = reinterpret_cast<char *>(data);
|
||||
|
||||
string_info[0] = num;
|
||||
for (int i = 0; i <= num; i++) {
|
||||
string_info[i + 1] = offset[i];
|
||||
}
|
||||
for (int i = 0; i < num; i++) {
|
||||
memcpy(string_data + offset[i], string_buffer[i].data, string_buffer[i].len);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int WriteSeperatedStringsToTensor(Tensor *tensor, const std::vector<std::vector<StringPack>> &string_buffer) {
|
||||
int32_t num = string_buffer.size();
|
||||
std::vector<int32_t> offset(num + 1);
|
||||
offset[0] = 4 * (num + 2);
|
||||
std::vector<int> len(num);
|
||||
for (int i = 0; i < num; i++) {
|
||||
len[i] = 0;
|
||||
for (int j = 0; j < static_cast<int>(string_buffer[i].size()); j++) {
|
||||
len[i] += string_buffer[i][j].len;
|
||||
}
|
||||
offset[i + 1] = offset[i] + len[i];
|
||||
}
|
||||
|
||||
std::vector<int> shape = {offset[num]};
|
||||
tensor->set_shape(shape);
|
||||
void *data = tensor->MutableData();
|
||||
if (data == nullptr) {
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
auto *string_info = reinterpret_cast<int32_t *>(data);
|
||||
auto *string_data = reinterpret_cast<char *>(data);
|
||||
|
||||
string_info[0] = num;
|
||||
for (int i = 0; i <= num; i++) {
|
||||
string_info[i + 1] = offset[i];
|
||||
}
|
||||
for (int i = 0; i < num; i++) {
|
||||
auto *dst = string_data + offset[i];
|
||||
for (auto string_part : string_buffer[i]) {
|
||||
memcpy(dst, string_part.data, string_part.len);
|
||||
dst += string_part.len;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,44 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_COMMON_STRING_UTIL_H_
|
||||
#define MINDSPORE_LITE_COMMON_STRING_UTIL_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include "mindspore/lite/src/tensor.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "tools/common/option.h"
|
||||
#include "include/errorcode.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
typedef struct {
|
||||
int len;
|
||||
const char *data;
|
||||
} StringPack;
|
||||
|
||||
std::vector<StringPack> ParseTensorBuffer(Tensor *tensor);
|
||||
std::vector<StringPack> ParseStringBuffer(const void *data);
|
||||
|
||||
int WriteStringsToTensor(Tensor *tensor, const std::vector<StringPack> &string_buffer);
|
||||
int WriteSeperatedStringsToTensor(Tensor *tensor, const std::vector<std::vector<StringPack>> &string_buffer);
|
||||
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_COMMON_STRING_UTIL_H_
|
|
@ -0,0 +1,74 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/ops/skip_gram.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
int SkipGram::GetNgramSize() const { return this->primitive_->value.AsSkipGram()->ngramSize; }
|
||||
int SkipGram::GetMaxSkipSize() const { return this->primitive_->value.AsSkipGram()->maxSkipSize; }
|
||||
bool SkipGram::GetIncludeAllNgrams() const { return this->primitive_->value.AsSkipGram()->includeAllGrams; }
|
||||
|
||||
void SkipGram::SetNgramSize(int ngram_size) { this->primitive_->value.AsSkipGram()->ngramSize = ngram_size; }
|
||||
void SkipGram::SetMaxSkipSize(int max_skip_size) { this->primitive_->value.AsSkipGram()->maxSkipSize = max_skip_size; }
|
||||
void SkipGram::SetIncludeAllNgrams(bool include_all_ngrams) {
|
||||
this->primitive_->value.AsSkipGram()->includeAllGrams = include_all_ngrams;
|
||||
}
|
||||
|
||||
#else
|
||||
int SkipGram::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
|
||||
auto attr = primitive->value_as_SkipGram();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "value_as_SkipGram return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
auto val_offset = schema::CreateSkipGram(*fbb, attr->ngramSize(), attr->maxSkipSize(), attr->includeAllGrams());
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_SkipGram, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int SkipGram::GetNgramSize() const { return this->primitive_->value_as_SkipGram()->ngramSize(); }
|
||||
int SkipGram::GetMaxSkipSize() const { return this->primitive_->value_as_SkipGram()->maxSkipSize(); }
|
||||
bool SkipGram::GetIncludeAllNgrams() const { return this->primitive_->value_as_SkipGram()->includeAllGrams(); }
|
||||
|
||||
#endif
|
||||
|
||||
int SkipGram::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
|
||||
MS_ASSERT(this->primitive_ != nullptr);
|
||||
if (inputs_.size() != kSingleNum) {
|
||||
MS_LOG(ERROR) << "Skip Gram should have one input";
|
||||
return RET_INPUT_TENSOR_ERROR;
|
||||
}
|
||||
if (outputs_.size() != kSingleNum) {
|
||||
MS_LOG(ERROR) << "Skip Gram should have one outputs";
|
||||
return RET_INPUT_TENSOR_ERROR;
|
||||
}
|
||||
auto input = inputs_.front();
|
||||
auto output = outputs_.front();
|
||||
MS_ASSERT(input != nullptr);
|
||||
output->SetFormat(input->GetFormat());
|
||||
output->set_data_type(input->data_type());
|
||||
|
||||
return RET_INFER_INVALID;
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,50 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef LITE_MINDSPORE_LITE_C_OPS_SKIP_GRAM_H_
|
||||
#define LITE_MINDSPORE_LITE_C_OPS_SKIP_GRAM_H_
|
||||
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include <cmath>
|
||||
#include "src/ops/primitive_c.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class SkipGram : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(SkipGram, PrimitiveC);
|
||||
SkipGram() = default;
|
||||
explicit SkipGram(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
void SetNgramSize(int ngram_size);
|
||||
void SetMaxSkipSize(int max_skip_size);
|
||||
void SetIncludeAllNgrams(bool include_all_ngrams);
|
||||
|
||||
#else
|
||||
SkipGram() = default;
|
||||
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
|
||||
int GetNgramSize() const;
|
||||
int GetMaxSkipSize() const;
|
||||
bool GetIncludeAllNgrams() const;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_MINDSPORE_LITE_C_OPS_SKIP_GRAM_H_
|
|
@ -115,6 +115,7 @@
|
|||
#include "src/ops/l2_norm.h"
|
||||
#include "src/ops/neg.h"
|
||||
#include "src/ops/detection_post_process.h"
|
||||
#include "src/ops/skip_gram.h"
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/fp32/arg_min_max.h"
|
||||
#include "nnacl/fp32/cast.h"
|
||||
|
@ -174,6 +175,7 @@
|
|||
#include "nnacl/l2_norm_parameter.h"
|
||||
#include "nnacl/detection_post_process_parameter.h"
|
||||
#include "nnacl/fp32/exp.h"
|
||||
#include "nnacl/fp32/skip_gram.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
|
||||
|
@ -1586,6 +1588,21 @@ OpParameter *PopulateExpParameter(const mindspore::lite::PrimitiveC *primitive)
|
|||
return reinterpret_cast<OpParameter *>(exp_parameter);
|
||||
}
|
||||
|
||||
OpParameter *PopulateSkipGramParameter(const mindspore::lite::PrimitiveC *primitive) {
|
||||
SkipGramParameter *skipGramParameter = reinterpret_cast<SkipGramParameter *>(malloc(sizeof(SkipGramParameter)));
|
||||
if (skipGramParameter == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc SkipGramParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(skipGramParameter, 0, sizeof(SkipGramParameter));
|
||||
skipGramParameter->op_parameter_.type_ = primitive->Type();
|
||||
auto param = reinterpret_cast<mindspore::lite::SkipGram *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
|
||||
skipGramParameter->ngram_size = param->GetNgramSize();
|
||||
skipGramParameter->max_skip_size = param->GetMaxSkipSize();
|
||||
skipGramParameter->include_all_ngrams = param->GetIncludeAllNgrams();
|
||||
return reinterpret_cast<OpParameter *>(skipGramParameter);
|
||||
}
|
||||
|
||||
PopulateParameterRegistry::PopulateParameterRegistry() {
|
||||
populate_parameter_funcs_[schema::PrimitiveType_SparseToDense] = PopulateSparseToDenseParameter;
|
||||
populate_parameter_funcs_[schema::PrimitiveType_SoftMax] = PopulateSoftmaxParameter;
|
||||
|
@ -1688,6 +1705,7 @@ PopulateParameterRegistry::PopulateParameterRegistry() {
|
|||
populate_parameter_funcs_[schema::PrimitiveType_Elu] = PopulateEluParameter;
|
||||
populate_parameter_funcs_[schema::PrimitiveType_L2Norm] = PopulateL2NormParameter;
|
||||
populate_parameter_funcs_[schema::PrimitiveType_DetectionPostProcess] = PopulateDetectionPostProcessParameter;
|
||||
populate_parameter_funcs_[schema::PrimitiveType_SkipGram] = PopulateSkipGramParameter;
|
||||
}
|
||||
|
||||
PopulateParameterRegistry *PopulateParameterRegistry::GetInstance() {
|
||||
|
|
|
@ -0,0 +1,134 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/runtime/kernel/arm/fp32/skip_gram.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "src/kernel_registry.h"
|
||||
#include "src/runtime/runtime_api.h"
|
||||
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
using mindspore::lite::RET_ERROR;
|
||||
using mindspore::lite::RET_OK;
|
||||
using mindspore::lite::StringPack;
|
||||
using mindspore::schema::PrimitiveType_SkipGram;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
|
||||
int SkipGramCPUKernel::Init() {
|
||||
if (!InferShapeDone()) {
|
||||
return RET_OK;
|
||||
}
|
||||
return ReSize();
|
||||
}
|
||||
|
||||
int SkipGramCPUKernel::ReSize() { return RET_OK; }
|
||||
|
||||
void ParseSentenceToWords(const StringPack &sentence, std::vector<StringPack> *words) {
|
||||
int pre = 0;
|
||||
int i;
|
||||
for (i = 0; i < sentence.len; i++) {
|
||||
if (sentence.data[i] != ' ') {
|
||||
pre = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
for (; i < sentence.len; i++) {
|
||||
if (sentence.data[i] == ' ') {
|
||||
if (sentence.data[pre] != ' ') {
|
||||
words->push_back({i - pre, sentence.data + pre});
|
||||
}
|
||||
pre = i + 1;
|
||||
}
|
||||
}
|
||||
if (sentence.data[sentence.len - 1] != ' ') {
|
||||
words->push_back({sentence.len - pre, sentence.data + pre});
|
||||
}
|
||||
}
|
||||
|
||||
int SkipGramCPUKernel::Run() {
|
||||
skip_gram_parameter_ = reinterpret_cast<SkipGramParameter *>(op_parameter_);
|
||||
if (skip_gram_parameter_->ngram_size < 1) {
|
||||
MS_LOG(ERROR) << "Skip Gram Parameter Error, NgramSize should be at least 1, get "
|
||||
<< skip_gram_parameter_->ngram_size;
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
StringPack sentence = mindspore::lite::ParseTensorBuffer(in_tensors_[0]).at(0);
|
||||
std::vector<StringPack> words;
|
||||
ParseSentenceToWords(sentence, &words);
|
||||
|
||||
std::vector<std::vector<StringPack>> result;
|
||||
std::vector<int> stack(skip_gram_parameter_->ngram_size, 0);
|
||||
|
||||
int index = 1;
|
||||
int size = words.size();
|
||||
while (index >= 0) {
|
||||
if (index < skip_gram_parameter_->ngram_size && stack[index] + 1 < size &&
|
||||
(index == 0 || stack[index] - stack[index - 1] <= skip_gram_parameter_->max_skip_size)) {
|
||||
stack[index]++;
|
||||
index++;
|
||||
if (index < skip_gram_parameter_->ngram_size) {
|
||||
stack[index] = stack[index - 1];
|
||||
}
|
||||
} else {
|
||||
if (index > 0 && ((skip_gram_parameter_->include_all_ngrams && index <= skip_gram_parameter_->ngram_size) ||
|
||||
(!skip_gram_parameter_->include_all_ngrams && index == skip_gram_parameter_->ngram_size))) {
|
||||
std::vector<StringPack> gram(2 * index - 1);
|
||||
char blank[1] = {' '};
|
||||
StringPack blank_str = {1, blank};
|
||||
for (int i = 0; i < 2 * index - 2; i += 2) {
|
||||
gram[i] = words[stack[i / 2]];
|
||||
gram[i + 1] = blank_str;
|
||||
}
|
||||
gram[2 * index - 2] = words[stack[index - 1]];
|
||||
result.push_back(gram);
|
||||
}
|
||||
index--;
|
||||
}
|
||||
}
|
||||
|
||||
int ret = mindspore::lite::WriteSeperatedStringsToTensor(out_tensors_[0], result);
|
||||
return ret;
|
||||
}
|
||||
|
||||
kernel::LiteKernel *CpuSkipGramFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
|
||||
const std::vector<lite::Tensor *> &outputs, OpParameter *parameter,
|
||||
const lite::InnerContext *ctx, const KernelKey &desc,
|
||||
const mindspore::lite::PrimitiveC *primitive) {
|
||||
if (parameter == nullptr || ctx == nullptr) {
|
||||
MS_LOG(ERROR) << "parameter or ctx is nullptr";
|
||||
return nullptr;
|
||||
}
|
||||
MS_ASSERT(desc.type == PrimitiveType_SkipGram);
|
||||
auto *kernel = new (std::nothrow) SkipGramCPUKernel(parameter, inputs, outputs, ctx, primitive);
|
||||
if (kernel == nullptr) {
|
||||
MS_LOG(ERROR) << "Create Kernel failed, name: " << parameter->name_;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto ret = kernel->Init();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Init Kernel failed, name: " << parameter->name_
|
||||
<< ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(parameter->type_));
|
||||
delete kernel;
|
||||
return nullptr;
|
||||
}
|
||||
return kernel;
|
||||
}
|
||||
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SkipGram, CpuSkipGramFp32KernelCreator)
|
||||
|
||||
} // namespace mindspore::kernel
|
|
@ -0,0 +1,48 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SKIP_GRAM_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SKIP_GRAM_H_
|
||||
|
||||
#include <vector>
|
||||
#include "src/lite_kernel.h"
|
||||
#include "nnacl/fp32/skip_gram.h"
|
||||
#include "src/common/string_util.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
|
||||
class SkipGramCPUKernel : public LiteKernel {
|
||||
public:
|
||||
explicit SkipGramCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
|
||||
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
|
||||
const mindspore::lite::PrimitiveC *primitive)
|
||||
: LiteKernel(parameter, inputs, outputs, ctx, primitive), ctx_(ctx), thread_count_(ctx->thread_num_) {}
|
||||
~SkipGramCPUKernel() override = default;
|
||||
|
||||
int Init() override;
|
||||
int ReSize() override;
|
||||
int Run() override;
|
||||
int DoExcute(int task_id);
|
||||
|
||||
protected:
|
||||
const lite::InnerContext *ctx_;
|
||||
int thread_count_;
|
||||
SkipGramParameter *skip_gram_parameter_;
|
||||
};
|
||||
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SKIP_GRAM_H_
|
|
@ -132,6 +132,9 @@ class Tensor : public mindspore::tensor::MSTensor {
|
|||
case kNumberTypeBool:
|
||||
size = sizeof(bool);
|
||||
break;
|
||||
case kObjectTypeString:
|
||||
size = sizeof(char);
|
||||
break;
|
||||
default:
|
||||
MS_LOG(ERROR) << "Not support the type: " << this->data_type_;
|
||||
return 0;
|
||||
|
|
|
@ -124,6 +124,7 @@ set(TEST_LITE_SRC
|
|||
${LITE_DIR}/src/common/file_utils.cc
|
||||
${LITE_DIR}/src/common/file_utils_ext.cc
|
||||
${LITE_DIR}/src/common/utils.cc
|
||||
${LITE_DIR}/src/common/string_util.cc
|
||||
${LITE_DIR}/tools/common/graph_util.cc
|
||||
${LITE_DIR}/tools/common/tensor_util.cc
|
||||
${LITE_DIR}/tools/common/node_util.cc
|
||||
|
|
|
@ -0,0 +1,78 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
#include "src/runtime/kernel/arm/fp32/skip_gram.h"
|
||||
#include "nnacl/fp32/skip_gram.h"
|
||||
#include "src/common/file_utils.h"
|
||||
#include "common/common_test.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "src/common/string_util.h"
|
||||
|
||||
namespace mindspore {
|
||||
using mindspore::lite::StringPack;
|
||||
using mindspore::lite::Tensor;
|
||||
|
||||
class TestSkipGramFp32 : public mindspore::CommonTest {
|
||||
public:
|
||||
TestSkipGramFp32() {}
|
||||
};
|
||||
|
||||
void SkipGramTestInit(std::vector<Tensor *> *inputs_, std::vector<Tensor *> *outputs_,
|
||||
SkipGramParameter *skip_gram_param) {
|
||||
Tensor *in_t_first = new Tensor(kObjectTypeString, {}, schema::Format_NHWC, lite::Tensor::Category::CONST);
|
||||
char sentence[] = "The quick brown fox jumps over the lazy dog";
|
||||
std::vector<StringPack> str;
|
||||
str.push_back({43, sentence});
|
||||
mindspore::lite::WriteStringsToTensor(in_t_first, str);
|
||||
inputs_->push_back(in_t_first);
|
||||
|
||||
Tensor *output = new Tensor(kObjectTypeString, {}, schema::Format_NHWC, lite::Tensor::Category::CONST);
|
||||
outputs_->push_back(output);
|
||||
|
||||
skip_gram_param->ngram_size = 3;
|
||||
skip_gram_param->max_skip_size = 2;
|
||||
skip_gram_param->include_all_ngrams = true;
|
||||
skip_gram_param->op_parameter_.type_ = mindspore::schema::PrimitiveType_SkipGram;
|
||||
skip_gram_param->op_parameter_.thread_num_ = 2;
|
||||
}
|
||||
|
||||
TEST_F(TestSkipGramFp32, ElTest) {
|
||||
std::vector<Tensor *> inputs_;
|
||||
std::vector<Tensor *> outputs_;
|
||||
|
||||
auto skip_gram_param_ = new SkipGramParameter();
|
||||
SkipGramTestInit(&inputs_, &outputs_, skip_gram_param_);
|
||||
|
||||
lite::InnerContext *ctx = new lite::InnerContext;
|
||||
ctx->thread_num_ = 2;
|
||||
ASSERT_EQ(lite::RET_OK, ctx->Init());
|
||||
kernel::SkipGramCPUKernel *el =
|
||||
new kernel::SkipGramCPUKernel(reinterpret_cast<OpParameter *>(skip_gram_param_), inputs_, outputs_, ctx, nullptr);
|
||||
|
||||
el->Init();
|
||||
el->Run();
|
||||
|
||||
std::vector<StringPack> output = mindspore::lite::ParseTensorBuffer(outputs_[0]);
|
||||
for (int i = 0; i < output.size(); i++) {
|
||||
for (int j = 0; j < output[i].len; j++) {
|
||||
printf("%c", output[i].data[j]);
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mindspore
|
|
@ -17,6 +17,7 @@ if (WIN32)
|
|||
${CMAKE_CURRENT_SOURCE_DIR}/../../src/inner_context.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../src/kernel_registry.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/graph_util.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/string_util.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../src/runtime/runtime_api.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../src/runtime/thread_pool.c
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../src/runtime/workspace_pool.cc
|
||||
|
@ -42,6 +43,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
|||
${CMAKE_CURRENT_SOURCE_DIR}/../common/graph_util.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../common/node_util.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../common/tensor_util.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../common/string_util.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../common/protobuf_utils.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../common/flag_parser.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../common/storage.cc
|
||||
|
@ -75,6 +77,7 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/../../../core mindspore_core)
|
|||
set(SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../src)
|
||||
set(LITE_SRC
|
||||
${SRC_DIR}/common/graph_util.cc
|
||||
${SRC_DIR}/common/string_util.cc
|
||||
${SRC_DIR}/runtime/allocator.cc
|
||||
${SRC_DIR}/runtime/runtime_api.cc
|
||||
${SRC_DIR}/runtime/thread_pool.c
|
||||
|
|
Loading…
Reference in New Issue