forked from mindspore-Ecosystem/mindspore
!7252 add custom op normalize
Merge pull request !7252 from zhaozhenlong/lite/op/normalize
This commit is contained in:
commit
6954aa05a1
|
@ -52,8 +52,8 @@ int WriteStringsToTensor(Tensor *tensor, const std::vector<StringPack> &string_b
|
|||
return RET_ERROR;
|
||||
}
|
||||
|
||||
auto *string_info = reinterpret_cast<int32_t *>(data);
|
||||
auto *string_data = reinterpret_cast<char *>(data);
|
||||
int32_t *string_info = reinterpret_cast<int32_t *>(data);
|
||||
char *string_data = reinterpret_cast<char *>(data);
|
||||
|
||||
string_info[0] = num;
|
||||
for (int i = 0; i <= num; i++) {
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#include "src/ops/custom_normalize.h"
|
||||
#include "src/common/string_util.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -30,8 +31,24 @@ int CustomNormalize::UnPackToFlatBuilder(const schema::Primitive *primitive, fla
|
|||
}
|
||||
#endif
|
||||
int CustomNormalize::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
|
||||
PrimitiveC::InferShape(inputs_, outputs_);
|
||||
return RET_INFER_INVALID;
|
||||
auto input = inputs_.at(0);
|
||||
MS_ASSERT(input != nullptr);
|
||||
if (input->data_c() == nullptr) {
|
||||
MS_LOG(INFO) << "Do infer shape in runtime.";
|
||||
return RET_INFER_INVALID;
|
||||
}
|
||||
int string_num = lite::GetStringCount(input);
|
||||
auto output = outputs_.at(0);
|
||||
MS_ASSERT(output != nullptr);
|
||||
|
||||
std::vector<int> shape;
|
||||
shape.push_back(string_num == 0 ? 1 : string_num);
|
||||
|
||||
output->set_shape(shape);
|
||||
output->set_data_type(input->data_type());
|
||||
output->SetFormat(input->GetFormat());
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -0,0 +1,160 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "src/runtime/kernel/arm/string/normalize.h"
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include <regex>
|
||||
#include <algorithm>
|
||||
#include "src/kernel_registry.h"
|
||||
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
using mindspore::schema::PrimitiveType_CustomNormalize;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
namespace {
|
||||
const char kPunctuationsRegex[] = "[.*()\"]";
|
||||
const std::map<std::string, std::string> *kRegexTransforms = new (std::nothrow) std::map<std::string, std::string>({
|
||||
{"([\\S]+)n't", "$1 not"},
|
||||
{"([\\S]+)'nt", "$1 not"},
|
||||
{"([\\S]+)'ll", "$1 will"},
|
||||
{"([\\S]+)'re", "$1 are"},
|
||||
{"([\\S]+)'ve", "$1 have"},
|
||||
{"i'm", "i am"},
|
||||
});
|
||||
const int32_t kMaxStringLength = 300;
|
||||
|
||||
} // namespace
|
||||
|
||||
int NormalizeCPUKernel::Init() {
|
||||
if (!InferShapeDone()) {
|
||||
return RET_OK;
|
||||
}
|
||||
return ReSize();
|
||||
}
|
||||
|
||||
int NormalizeCPUKernel::ReSize() { return RET_OK; }
|
||||
|
||||
std::string NormalizeCPUKernel::Trim(const std::string &str, const std::string &whitespace /*= " \t\n\v\f\r"*/) {
|
||||
auto begin = str.find_first_not_of(whitespace);
|
||||
auto end = str.find_last_not_of(whitespace);
|
||||
const auto range = end - begin + 1;
|
||||
return str.substr(begin, range);
|
||||
}
|
||||
|
||||
std::string NormalizeCPUKernel::GlobalReplace(const std::string &str, const std::string ®,
|
||||
const std::string &replace) {
|
||||
std::regex e(reg);
|
||||
return std::regex_replace(str, e, replace);
|
||||
}
|
||||
|
||||
std::string NormalizeCPUKernel::Normalize(const std::string &str) {
|
||||
std::string result;
|
||||
std::transform(str.begin(), str.end(), back_inserter(result), [](unsigned char c) { return std::tolower(c); });
|
||||
result = Trim(result);
|
||||
result = GlobalReplace(result, kPunctuationsRegex, "");
|
||||
result = GlobalReplace(result, "\\s('t|'nt|n't|'d|'ll|'s|'m|'ve|'re)([\\s,;:/])", "$1$2");
|
||||
result = GlobalReplace(result, "\\s('t|'nt|n't|'d|'ll|'s|'m|'ve|'re)$", "$1");
|
||||
// transform shortening to full
|
||||
MS_ASSERT(kRegexTransforms != nullptr);
|
||||
for (auto iter = kRegexTransforms->begin(); iter != kRegexTransforms->end(); iter++) {
|
||||
result = GlobalReplace(result, iter->first, iter->second);
|
||||
}
|
||||
result = GlobalReplace(result, "([?])+", "$1");
|
||||
result = GlobalReplace(result, "([!])+", "$1");
|
||||
result = GlobalReplace(result, "([^?!]+)([?!])", "$1 $2 ");
|
||||
result = GlobalReplace(result, "([?!])([?!])", "$1 $2");
|
||||
|
||||
result = GlobalReplace(result, "[\\s,:;\\-&'\"]+$", "");
|
||||
result = GlobalReplace(result, "^[\\s,:;\\-&'\"]+", "");
|
||||
|
||||
result = Trim(result);
|
||||
if (result.size() > kMaxStringLength) {
|
||||
result = result.substr(0, kMaxStringLength);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
void NormalizeCPUKernel::FreeBuffer() {
|
||||
for (size_t j = 0; j < normalized_strs.size(); ++j) {
|
||||
if (normalized_strs[j] != nullptr) {
|
||||
context_->allocator->Free(normalized_strs[j]);
|
||||
normalized_strs[j] = nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int NormalizeCPUKernel::Run() {
|
||||
auto ret = Prepare();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Prepare fail! Ret error code: " << ret;
|
||||
return ret;
|
||||
}
|
||||
auto input_tensor = in_tensors_.at(0);
|
||||
int string_num = lite::GetStringCount(input_tensor);
|
||||
std::vector<lite::StringPack> all_string_pack = ParseTensorBuffer(input_tensor);
|
||||
|
||||
std::vector<lite::StringPack> out_string_pack;
|
||||
normalized_strs.resize(string_num, nullptr);
|
||||
|
||||
for (int i = 0; i < string_num; ++i) {
|
||||
auto chars = all_string_pack[i];
|
||||
std::string str(chars.data);
|
||||
std::string result = Normalize(str);
|
||||
int str_length = result.size() + 1;
|
||||
|
||||
char *normalized_str = nullptr;
|
||||
normalized_str = reinterpret_cast<char *>(context_->allocator->Malloc(sizeof(char) * str_length));
|
||||
if (normalized_str == nullptr) {
|
||||
MS_LOG(ERROR) << "Malloc data failed!";
|
||||
FreeBuffer();
|
||||
return RET_ERROR;
|
||||
}
|
||||
normalized_strs[i] = normalized_str;
|
||||
|
||||
memcpy(normalized_str, result.data(), str_length);
|
||||
out_string_pack.push_back({str_length, normalized_str});
|
||||
}
|
||||
if (string_num == 0) {
|
||||
out_string_pack.push_back({1, ""});
|
||||
}
|
||||
auto out_tensor = out_tensors_.at(0);
|
||||
WriteStringsToTensor(out_tensor, out_string_pack);
|
||||
FreeBuffer();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
kernel::LiteKernel *CpuNormalizeKernelCreator(const std::vector<lite::Tensor *> &inputs,
|
||||
const std::vector<lite::Tensor *> &outputs, OpParameter *parameter,
|
||||
const lite::InnerContext *ctx, const kernel::KernelKey &desc,
|
||||
const mindspore::lite::PrimitiveC *primitive) {
|
||||
auto *kernel = new (std::nothrow) NormalizeCPUKernel(parameter, inputs, outputs, ctx, primitive);
|
||||
if (kernel == nullptr) {
|
||||
MS_LOG(ERROR) << "new NormalizeCPUKernel fail!";
|
||||
return nullptr;
|
||||
}
|
||||
auto ret = kernel->Init();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_
|
||||
<< ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(parameter->type_));
|
||||
delete kernel;
|
||||
return nullptr;
|
||||
}
|
||||
return kernel;
|
||||
}
|
||||
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_CustomNormalize, CpuNormalizeKernelCreator)
|
||||
} // namespace mindspore::kernel
|
|
@ -0,0 +1,47 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_STRING_NORMALIZE_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_STRING_NORMALIZE_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "src/lite_kernel.h"
|
||||
#include "include/context.h"
|
||||
#include "src/common/string_util.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
class NormalizeCPUKernel : public LiteKernel {
|
||||
public:
|
||||
NormalizeCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
|
||||
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
|
||||
const mindspore::lite::PrimitiveC *primitive)
|
||||
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
|
||||
~NormalizeCPUKernel() = default;
|
||||
|
||||
int Init() override;
|
||||
int ReSize() override;
|
||||
int Run() override;
|
||||
|
||||
private:
|
||||
std::string Trim(const std::string &str, const std::string &whitespace = " \t\n\v\f\r");
|
||||
std::string GlobalReplace(const std::string &str, const std::string ®, const std::string &replace);
|
||||
std::string Normalize(const std::string &str);
|
||||
std::vector<char *> normalized_strs;
|
||||
void FreeBuffer();
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_STRING_NORMALIZE_H_
|
|
@ -209,6 +209,7 @@ file(GLOB_RECURSE TEST_CASE_KERNEL_SRC
|
|||
${TEST_DIR}/ut/src/runtime/kernel/arm/common/*.cc
|
||||
${TEST_DIR}/ut/src/runtime/kernel/arm/fp32/*.cc
|
||||
${TEST_DIR}/ut/src/runtime/kernel/arm/int8/*.cc
|
||||
${TEST_DIR}/ut/src/runtime/kernel/arm/string/*.cc
|
||||
)
|
||||
|
||||
file(GLOB_RECURSE TEST_CASE_KERNEL_TRAIN_SRC
|
||||
|
|
|
@ -0,0 +1,87 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
#include "src/runtime/kernel/arm/fp32/skip_gram.h"
|
||||
#include "src/runtime/kernel/arm/string/normalize.h"
|
||||
#include "mindspore/lite/src/kernel_registry.h"
|
||||
#include "nnacl/fp32/skip_gram.h"
|
||||
#include "src/common/file_utils.h"
|
||||
#include "common/common_test.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "src/common/string_util.h"
|
||||
|
||||
namespace mindspore {
|
||||
using mindspore::lite::StringPack;
|
||||
using mindspore::lite::Tensor;
|
||||
|
||||
class TestNormalize : public mindspore::CommonTest {
|
||||
public:
|
||||
TestNormalize() {}
|
||||
void NormalizeTestInit();
|
||||
|
||||
public:
|
||||
Tensor input_tensor_;
|
||||
Tensor output_tensor_;
|
||||
std::vector<Tensor *> inputs_{&input_tensor_};
|
||||
std::vector<Tensor *> outputs_{&output_tensor_};
|
||||
OpParameter parameter_ = {};
|
||||
lite::InnerContext ctx_ = lite::InnerContext();
|
||||
kernel::KernelKey desc_ = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_CustomNormalize};
|
||||
kernel::KernelCreator creator_ = nullptr;
|
||||
kernel::LiteKernel *kernel_ = nullptr;
|
||||
};
|
||||
|
||||
void TestNormalize::NormalizeTestInit() {
|
||||
input_tensor_.set_data_type(kObjectTypeString);
|
||||
input_tensor_.SetFormat(schema::Format_NHWC);
|
||||
|
||||
std::vector<StringPack> str_pack;
|
||||
const char sentence1[] = " I don't know what happened\n";
|
||||
str_pack.push_back({static_cast<int>(strlen(sentence1) + 1), sentence1});
|
||||
const char sentence2[] = "She's not here when Alex arrived!!!";
|
||||
str_pack.push_back({static_cast<int>(strlen(sentence2) + 1), sentence2});
|
||||
mindspore::lite::WriteStringsToTensor(&input_tensor_, str_pack);
|
||||
|
||||
output_tensor_.set_data_type(kObjectTypeString);
|
||||
output_tensor_.SetFormat(schema::Format_NHWC);
|
||||
}
|
||||
|
||||
TEST_F(TestNormalize, TestSentence) {
|
||||
NormalizeTestInit();
|
||||
ASSERT_EQ(lite::RET_OK, ctx_.Init());
|
||||
creator_ = lite::KernelRegistry::GetInstance()->GetCreator(desc_);
|
||||
ASSERT_NE(creator_, nullptr);
|
||||
kernel_ = creator_(inputs_, outputs_, ¶meter_, &ctx_, desc_, nullptr);
|
||||
ASSERT_NE(kernel_, nullptr);
|
||||
auto ret = kernel_->Init();
|
||||
MS_ASSERT(ret == 0);
|
||||
ret = kernel_->Run();
|
||||
MS_ASSERT(ret == 0);
|
||||
|
||||
std::vector<StringPack> output = mindspore::lite::ParseTensorBuffer(outputs_[0]);
|
||||
for (int i = 0; i < output.size(); i++) {
|
||||
for (int j = 0; j < output[i].len; j++) {
|
||||
printf("%c", output[i].data[j]);
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
input_tensor_.SetData(nullptr);
|
||||
output_tensor_.SetData(nullptr);
|
||||
}
|
||||
|
||||
} // namespace mindspore
|
Loading…
Reference in New Issue