!9446 [MS][LITE]assert op

From: @YeFeng_24
Reviewed-by: @zhang_xue_tong,@zhanghaibo5
Signed-off-by: @zhang_xue_tong
This commit is contained in:
mindspore-ci-bot 2020-12-04 16:44:02 +08:00 committed by Gitee
commit 053749239a
7 changed files with 381 additions and 0 deletions

View File

@ -0,0 +1,72 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/assert_op.h"
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int AssertOP::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_Assert;
}
if (this->primitive_->value.type != schema::PrimitiveType_Assert) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
this->primitive_->value.value = new (std::nothrow) schema::AssertT();
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
}
PopulaterQuantParam(prim, inputs);
return RET_OK;
}
#else
int AssertOP::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_Assert();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_Assert return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateAssert(*fbb);
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Assert, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
PrimitiveC *AssertCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<AssertOP>(primitive); }
Registry AssertRegistry(schema::PrimitiveType_Assert, AssertCreator);
#endif
int AssertOP::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { return RET_OK; }
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,43 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_SRC_OPS_ASSERT_OP_H_
#define LITE_MINDSPORE_LITE_SRC_OPS_ASSERT_OP_H_
#include <vector>
#include <set>
#include <cmath>
#include "src/ops/primitive_c.h"
namespace mindspore {
namespace lite {
class AssertOP : public PrimitiveC {
public:
AssertOP() = default;
~AssertOP() = default;
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(AssertOP, PrimitiveC);
explicit AssertOP(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_SRC_OPS_ASSERT_OP_H_

View File

@ -0,0 +1,37 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/assert_op.h"
#include "src/ops/primitive_c.h"
#include "src/ops/populate/populate_register.h"
namespace mindspore {
namespace lite {
OpParameter *PopulateAssertParameter(const mindspore::lite::PrimitiveC *primitive) {
OpParameter *assert_parameter = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
if (assert_parameter == nullptr) {
MS_LOG(ERROR) << "malloc AssertParameter failed.";
return nullptr;
}
memset(assert_parameter, 0, sizeof(OpParameter));
assert_parameter->type_ = primitive->Type();
return reinterpret_cast<OpParameter *>(assert_parameter);
}
Registry AssertParameterRegistry(schema::PrimitiveType_Assert, PopulateAssertParameter);
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,77 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/base/assert.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Assert;
namespace mindspore::kernel {
int AssertCPUKernel::Init() { return RET_OK; }
int AssertCPUKernel::ReSize() { return RET_OK; }
int AssertCPUKernel::Run() {
auto cond = reinterpret_cast<bool *>(in_tensors_.front()->data_c());
if (*cond) {
return RET_OK;
} else {
for (size_t i = 1; i < in_tensors_.size(); i++) {
MS_LOG(ERROR) << in_tensors_.at(i)->ToString();
}
return RET_ERROR;
}
}
kernel::LiteKernel *CpuAssertKernelCreator(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, OpParameter *parameter,
const lite::InnerContext *ctx, const KernelKey &desc,
const mindspore::lite::PrimitiveC *primitive) {
if (parameter == nullptr) {
MS_LOG(ERROR) << "parameter is nullptr";
return nullptr;
}
if (ctx == nullptr) {
MS_LOG(ERROR) << "ctx is nullptr";
free(parameter);
return nullptr;
}
MS_ASSERT(desc.type == PrimitiveType_Assert);
auto *kernel = new (std::nothrow) AssertCPUKernel(parameter, inputs, outputs, ctx, primitive);
if (kernel == nullptr) {
MS_LOG(ERROR) << "Create kernel failed, name: " << parameter->name_;
free(parameter);
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_
<< ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(parameter->type_));
delete kernel;
return nullptr;
}
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Assert, CpuAssertKernelCreator)
REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_Assert, CpuAssertKernelCreator)
} // namespace mindspore::kernel

View File

@ -0,0 +1,47 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_ASSERT_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_ASSERT_H_
#include <vector>
#include "src/lite_kernel.h"
namespace mindspore::kernel {
typedef struct AssertParameter {
OpParameter op_parameter_;
} AssertParameter;
class AssertCPUKernel : public LiteKernel {
public:
AssertCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {
assert_param_ = reinterpret_cast<AssertParameter *>(op_parameter_);
}
~AssertCPUKernel() override {}
int Init() override;
int ReSize() override;
int Run() override;
private:
AssertParameter *assert_param_ = nullptr;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_ASSERT_H_

View File

@ -0,0 +1,68 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_assert_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
namespace mindspore {
namespace lite {
STATUS TFAssertParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC,
std::vector<std::string> *inputs, int *output_size) {
MS_LOG(INFO) << "TF AssertParser";
if (primitiveC == nullptr || output_size == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_NULL_PTR;
}
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "New PrimitiveT failed";
return RET_NULL_PTR;
}
auto attr = std::make_unique<schema::AssertT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new attr failed";
return RET_NULL_PTR;
}
tensorflow::AttrValue attr_value;
if (!TensorFlowUtils::FindAttrValue(tf_op, "summarize", &attr_value)) {
MS_LOG(ERROR) << "The keep_dims attr should be specified";
return RET_ERROR;
}
attr->summarize = attr_value.i();
primitive->value.type = schema::PrimitiveType_Assert;
primitive->value.value = attr.release();
*primitiveC = PrimitiveC::Create(primitive.release());
if (*primitiveC == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_ERROR;
}
*output_size = 1;
auto status = AddOpInput(tf_op, 0, inputs);
if (status != RET_OK) {
return status;
}
return status;
}
TFNodeRegistrar g_tfAssertParser("Assert", new TFAssertParser());
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,37 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ASSERT_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ASSERT_PARSER_H_
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"
namespace mindspore {
namespace lite {
class TFAssertParser : public TFNodeParser {
public:
TFAssertParser() = default;
~TFAssertParser() override = default;
STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ASSERT_PARSER_H_