forked from mindspore-Ecosystem/mindspore
!17924 [MS][LITE][DEVELOP]custom kernel ut
Merge pull request !17924 from chenjianping/reg_kernel_dev
This commit is contained in:
commit
880435bd71
|
@ -125,9 +125,9 @@ class MS_API KernelInterfaceReg {
|
|||
/// \param[in] provider Define the identification of user.
|
||||
/// \param[in] op_type Define the ordinary op type.
|
||||
/// \param[in] creator Define the KernelInterface create function.
|
||||
#define REGISTER_KERNEL_INTERFACE(provider, op_type, creator) \
|
||||
namespace { \
|
||||
static KernelInterfaceReg g_##provider##op_type##_inter_reg(#provider, op_type, creator); \
|
||||
#define REGISTER_KERNEL_INTERFACE(provider, op_type, creator) \
|
||||
namespace { \
|
||||
static mindspore::kernel::KernelInterfaceReg g_##provider##op_type##_inter_reg(#provider, op_type, creator); \
|
||||
} // namespace
|
||||
|
||||
/// \brief Defined registering macro to register custom op, which called by user directly.
|
||||
|
@ -135,9 +135,9 @@ class MS_API KernelInterfaceReg {
|
|||
/// \param[in] provider Define the identification of user.
|
||||
/// \param[in] op_type Define the concrete type of a custom op.
|
||||
/// \param[in] creator Define the KernelInterface create function.
|
||||
#define REGISTER_CUSTOM_KERNEL_INTERFACE(provider, op_type, creator) \
|
||||
namespace { \
|
||||
static KernelInterfaceReg g_##provider##op_type##_custom_inter_reg(#provider, #op_type, creator); \
|
||||
#define REGISTER_CUSTOM_KERNEL_INTERFACE(provider, op_type, creator) \
|
||||
namespace { \
|
||||
static mindspore::kernel::KernelInterfaceReg g_##provider##op_type##_custom_inter_reg(#provider, #op_type, creator); \
|
||||
} // namespace
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -135,7 +135,8 @@ class MS_API KernelReg {
|
|||
/// \param[in] creator Define a function pointer to create a kernel.
|
||||
#define REGISTER_KERNEL(arch, provider, data_type, op_type, creator) \
|
||||
namespace { \
|
||||
static KernelReg g_##arch##provider##data_type##op_type##kernelReg(#arch, #provider, data_type, op_type, creator); \
|
||||
static mindspore::kernel::KernelReg g_##arch##provider##data_type##op_type##kernelReg(#arch, #provider, data_type, \
|
||||
op_type, creator); \
|
||||
} // namespace
|
||||
|
||||
/// \brief Defined registering macro to register custom op kernel, which called by user directly.
|
||||
|
@ -145,9 +146,10 @@ class MS_API KernelReg {
|
|||
/// \param[in] data_type Define kernel's input data type.
|
||||
/// \param[in] op_type Define the concrete type of a custom op.
|
||||
/// \param[in] creator Define a function pointer to create a kernel.
|
||||
#define REGISTER_CUSTOM_KERNEL(arch, provider, data_type, op_type, creator) \
|
||||
namespace { \
|
||||
static KernelReg g_##arch##provider##data_type##op_type##kernelReg(#arch, #provider, data_type, #op_type, creator); \
|
||||
#define REGISTER_CUSTOM_KERNEL(arch, provider, data_type, op_type, creator) \
|
||||
namespace { \
|
||||
static mindspore::kernel::KernelReg g_##arch##provider##data_type##op_type##kernelReg(#arch, #provider, data_type, \
|
||||
#op_type, creator); \
|
||||
} // namespace
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -300,6 +300,8 @@ set(TEST_SRC
|
|||
${TEST_DIR}/ut/src/dynamic_library_loader_test.cc
|
||||
${TEST_DIR}/ut/src/scheduler_test.cc
|
||||
${TEST_DIR}/ut/src/lite_mindrt_test.cc
|
||||
${TEST_DIR}/ut/src/registry/registry_test.cc
|
||||
${TEST_DIR}/ut/src/registry/registry_custom_op_test.cc
|
||||
)
|
||||
|
||||
if(MSLITE_ENABLE_CONVERTER)
|
||||
|
|
|
@ -0,0 +1,231 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
#include <memory>
|
||||
#include "schema/inner/model_generated.h"
|
||||
#include "mindspore/lite/include/model.h"
|
||||
#include "common/common_test.h"
|
||||
#include "include/lite_session.h"
|
||||
#include "include/context.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "src/lite_session.h"
|
||||
#include "src/runtime/parallel_executor.h"
|
||||
#include "include/registry/kernel_interface.h"
|
||||
#include "include/registry/register_kernel.h"
|
||||
|
||||
using mindspore::kernel::Kernel;
|
||||
using mindspore::kernel::KernelInterface;
|
||||
using mindspore::lite::RET_ERROR;
|
||||
using mindspore::lite::RET_OK;
|
||||
using mindspore::lite::RET_PARAM_INVALID;
|
||||
using mindspore::schema::PrimitiveType_AddFusion;
|
||||
|
||||
namespace mindspore {
|
||||
namespace {
|
||||
const char *const kKeyName = "test_key";
|
||||
const char *const kTestData = "test_data";
|
||||
} // namespace
|
||||
|
||||
class TestData {
|
||||
public:
|
||||
static TestData *GetInstance() {
|
||||
static TestData instance;
|
||||
return &instance;
|
||||
}
|
||||
std::string data_;
|
||||
};
|
||||
|
||||
class TestCustomOp : public Kernel {
|
||||
public:
|
||||
TestCustomOp(const std::vector<tensor::MSTensor *> &inputs, const std::vector<tensor::MSTensor *> &outputs,
|
||||
const schema::Primitive *primitive, const lite::Context *ctx)
|
||||
: Kernel(inputs, outputs, primitive, ctx) {}
|
||||
int Prepare() override { return 0; }
|
||||
|
||||
int Execute() override;
|
||||
|
||||
int ReSize() override { return 0; }
|
||||
|
||||
private:
|
||||
int PreProcess() {
|
||||
for (auto *output : outputs_) {
|
||||
auto data = output->MutableData();
|
||||
if (data == nullptr) {
|
||||
MS_LOG(ERROR) << "Get data failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
void GetAttrData() {
|
||||
auto prim = primitive_->value_as_Custom();
|
||||
if (prim->attr()->size() < 1) {
|
||||
return;
|
||||
}
|
||||
auto data_bytes = prim->attr()->Get(0)->data();
|
||||
auto data_size = data_bytes->size();
|
||||
char buf[100];
|
||||
for (size_t i = 0; i < data_size; ++i) {
|
||||
buf[i] = static_cast<char>(data_bytes->Get(i));
|
||||
}
|
||||
buf[data_size] = 0;
|
||||
TestData::GetInstance()->data_ = std::string(buf);
|
||||
}
|
||||
};
|
||||
|
||||
int TestCustomOp::Execute() {
|
||||
if (inputs_.size() != 2) {
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
PreProcess();
|
||||
GetAttrData();
|
||||
float *in0 = static_cast<float *>(inputs_[0]->data());
|
||||
float *in1 = static_cast<float *>(inputs_[1]->data());
|
||||
float *out = static_cast<float *>(outputs_[0]->data());
|
||||
auto num = outputs_[0]->ElementsNum();
|
||||
for (int i = 0; i < num; ++i) {
|
||||
out[i] = in0[i] + in1[i];
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
class TestCustomOpInfer : public KernelInterface {
|
||||
public:
|
||||
TestCustomOpInfer() = default;
|
||||
~TestCustomOpInfer() = default;
|
||||
int Infer(const std::vector<tensor::MSTensor *> &inputs, const std::vector<tensor::MSTensor *> &outputs,
|
||||
const schema::Primitive *primitive) override {
|
||||
outputs[0]->set_data_type(inputs[0]->data_type());
|
||||
outputs[0]->set_shape(inputs[0]->shape());
|
||||
return RET_OK;
|
||||
}
|
||||
};
|
||||
|
||||
namespace {
|
||||
std::shared_ptr<Kernel> TestCustomAddCreator(const std::vector<tensor::MSTensor *> &inputs,
|
||||
const std::vector<tensor::MSTensor *> &outputs,
|
||||
const schema::Primitive *primitive, const lite::Context *ctx) {
|
||||
return std::make_shared<TestCustomOp>(inputs, outputs, primitive, ctx);
|
||||
}
|
||||
|
||||
std::shared_ptr<KernelInterface> CustomAddInferCreator() { return std::make_shared<TestCustomOpInfer>(); }
|
||||
} // namespace
|
||||
|
||||
REGISTER_CUSTOM_KERNEL(CPU, BuiltInTest, kNumberTypeFloat32, Add, TestCustomAddCreator)
|
||||
REGISTER_CUSTOM_KERNEL_INTERFACE(BuiltInTest, Add, CustomAddInferCreator)
|
||||
|
||||
class TestRegistryCustomOp : public mindspore::CommonTest {
|
||||
public:
|
||||
TestRegistryCustomOp() = default;
|
||||
};
|
||||
|
||||
TEST_F(TestRegistryCustomOp, TestCustomAdd) {
|
||||
auto meta_graph = std::make_shared<schema::MetaGraphT>();
|
||||
meta_graph->name = "graph";
|
||||
|
||||
auto node = std::make_unique<schema::CNodeT>();
|
||||
node->inputIndex = {0, 1};
|
||||
node->outputIndex = {2};
|
||||
node->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
node->primitive->value.type = schema::PrimitiveType_Custom;
|
||||
auto primitive = new schema::CustomT;
|
||||
primitive->type = "Add";
|
||||
auto attr = std::make_unique<schema::AttributeT>();
|
||||
attr->name = kKeyName;
|
||||
std::string test_data(kTestData);
|
||||
std::vector<uint8_t> attr_data(test_data.begin(), test_data.end());
|
||||
attr->data = attr_data;
|
||||
primitive->attr.emplace_back(std::move(attr));
|
||||
node->primitive->value.value = primitive;
|
||||
node->name = "Add";
|
||||
meta_graph->nodes.emplace_back(std::move(node));
|
||||
meta_graph->inputIndex = {0, 1};
|
||||
meta_graph->outputIndex = {2};
|
||||
|
||||
auto input0 = std::make_unique<schema::TensorT>();
|
||||
input0->nodeType = lite::NodeType_ValueNode;
|
||||
input0->format = schema::Format_NHWC;
|
||||
input0->dataType = TypeId::kNumberTypeFloat32;
|
||||
input0->dims = {1, 28, 28, 3};
|
||||
input0->offset = -1;
|
||||
meta_graph->allTensors.emplace_back(std::move(input0));
|
||||
|
||||
auto weight = std::make_unique<schema::TensorT>();
|
||||
weight->nodeType = lite::NodeType_ValueNode;
|
||||
weight->format = schema::Format_NHWC;
|
||||
weight->dataType = TypeId::kNumberTypeFloat32;
|
||||
weight->dims = {1, 28, 28, 3};
|
||||
|
||||
weight->offset = -1;
|
||||
meta_graph->allTensors.emplace_back(std::move(weight));
|
||||
|
||||
auto output = std::make_unique<schema::TensorT>();
|
||||
output->nodeType = lite::NodeType_Parameter;
|
||||
output->format = schema::Format_NHWC;
|
||||
output->dataType = TypeId::kNumberTypeFloat32;
|
||||
output->offset = -1;
|
||||
meta_graph->allTensors.emplace_back(std::move(output));
|
||||
|
||||
flatbuffers::FlatBufferBuilder builder(1024);
|
||||
auto offset = schema::MetaGraph::Pack(builder, meta_graph.get());
|
||||
builder.Finish(offset);
|
||||
schema::FinishMetaGraphBuffer(builder, offset);
|
||||
size_t size = builder.GetSize();
|
||||
const char *content = reinterpret_cast<char *>(builder.GetBufferPointer());
|
||||
|
||||
auto model = lite::Model::Import(content, size);
|
||||
ASSERT_NE(nullptr, model);
|
||||
meta_graph.reset();
|
||||
content = nullptr;
|
||||
auto context = new lite::InnerContext;
|
||||
auto &device_list = context->device_list_;
|
||||
std::shared_ptr<DefaultAllocator> allocator = std::make_shared<DefaultAllocator>();
|
||||
lite::DeviceContext device_ctx = {lite::DT_CPU, {false, lite::NO_BIND}, "BuiltInTest", "CPU", allocator};
|
||||
device_list.push_back(device_ctx);
|
||||
context->thread_num_ = 1;
|
||||
ASSERT_EQ(lite::RET_OK, context->Init());
|
||||
auto session = session::LiteSession::CreateSession(context);
|
||||
ASSERT_NE(nullptr, session);
|
||||
auto ret = session->CompileGraph(model);
|
||||
ASSERT_EQ(lite::RET_OK, ret);
|
||||
auto inputs = session->GetInputs();
|
||||
ASSERT_EQ(inputs.size(), 2);
|
||||
auto inTensor = inputs.front();
|
||||
ASSERT_NE(nullptr, inTensor);
|
||||
float *in0_data = static_cast<float *>(inTensor->MutableData());
|
||||
in0_data[0] = 10.0f;
|
||||
auto inTensor1 = inputs.back();
|
||||
ASSERT_NE(nullptr, inTensor1);
|
||||
float *in1_data = static_cast<float *>(inTensor1->MutableData());
|
||||
in1_data[0] = 20.0f;
|
||||
ret = session->RunGraph();
|
||||
ASSERT_EQ(lite::RET_OK, ret);
|
||||
auto outputs = session->GetOutputs();
|
||||
ASSERT_EQ(outputs.size(), 1);
|
||||
auto outTensor = outputs.begin()->second;
|
||||
ASSERT_NE(nullptr, outTensor);
|
||||
ASSERT_EQ(28 * 28 * 3, outTensor->ElementsNum());
|
||||
ASSERT_EQ(TypeId::kNumberTypeFloat32, outTensor->data_type());
|
||||
auto *outData = reinterpret_cast<float *>(outTensor->MutableData());
|
||||
ASSERT_NE(nullptr, outData);
|
||||
ASSERT_EQ(30.0f, outData[0]);
|
||||
ASSERT_EQ(TestData::GetInstance()->data_, kTestData);
|
||||
MS_LOG(INFO) << "Register add op test pass.";
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,194 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <cmath>
|
||||
#include <memory>
|
||||
#include "schema/inner/model_generated.h"
|
||||
#include "mindspore/lite/include/model.h"
|
||||
#include "common/common_test.h"
|
||||
#include "include/lite_session.h"
|
||||
#include "include/context.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "src/lite_session.h"
|
||||
#include "src/runtime/parallel_executor.h"
|
||||
#include "src/runtime/inner_allocator.h"
|
||||
#include "include/registry/kernel_interface.h"
|
||||
#include "include/registry/register_kernel.h"
|
||||
|
||||
using mindspore::kernel::Kernel;
|
||||
using mindspore::kernel::KernelInterface;
|
||||
using mindspore::lite::RET_ERROR;
|
||||
using mindspore::lite::RET_OK;
|
||||
using mindspore::lite::RET_PARAM_INVALID;
|
||||
using mindspore::schema::PrimitiveType_AddFusion;
|
||||
|
||||
namespace mindspore {
|
||||
class TestCustomAdd : public Kernel {
|
||||
public:
|
||||
TestCustomAdd(const std::vector<tensor::MSTensor *> &inputs, const std::vector<tensor::MSTensor *> &outputs,
|
||||
const schema::Primitive *primitive, const lite::Context *ctx)
|
||||
: Kernel(inputs, outputs, primitive, ctx) {}
|
||||
int Prepare() override { return 0; }
|
||||
|
||||
int Execute() override;
|
||||
|
||||
int ReSize() override { return 0; }
|
||||
|
||||
private:
|
||||
int PreProcess() {
|
||||
for (auto *output : outputs_) {
|
||||
auto data = output->MutableData();
|
||||
if (data == nullptr) {
|
||||
MS_LOG(ERROR) << "Get data failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
};
|
||||
|
||||
int TestCustomAdd::Execute() {
|
||||
if (inputs_.size() != 2) {
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
PreProcess();
|
||||
float *in0 = static_cast<float *>(inputs_[0]->data());
|
||||
float *in1 = static_cast<float *>(inputs_[1]->data());
|
||||
float *out = static_cast<float *>(outputs_[0]->data());
|
||||
auto num = outputs_[0]->ElementsNum();
|
||||
for (int i = 0; i < num; ++i) {
|
||||
out[i] = in0[i] + in1[i];
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
class TestCustomAddInfer : public KernelInterface {
|
||||
public:
|
||||
TestCustomAddInfer() = default;
|
||||
~TestCustomAddInfer() = default;
|
||||
int Infer(const std::vector<tensor::MSTensor *> &inputs, const std::vector<tensor::MSTensor *> &outputs,
|
||||
const schema::Primitive *primitive) override {
|
||||
outputs[0]->set_data_type(inputs[0]->data_type());
|
||||
outputs[0]->set_shape(inputs[0]->shape());
|
||||
return RET_OK;
|
||||
}
|
||||
};
|
||||
|
||||
namespace {
|
||||
std::shared_ptr<Kernel> TestCustomAddCreator(const std::vector<tensor::MSTensor *> &inputs,
|
||||
const std::vector<tensor::MSTensor *> &outputs,
|
||||
const schema::Primitive *primitive, const lite::Context *ctx) {
|
||||
return std::make_shared<TestCustomAdd>(inputs, outputs, primitive, ctx);
|
||||
}
|
||||
|
||||
std::shared_ptr<KernelInterface> CustomAddInferCreator() { return std::make_shared<TestCustomAddInfer>(); }
|
||||
} // namespace
|
||||
|
||||
REGISTER_KERNEL(CPU, BuiltInTest, kNumberTypeFloat32, PrimitiveType_AddFusion, TestCustomAddCreator)
|
||||
REGISTER_KERNEL_INTERFACE(BuiltInTest, PrimitiveType_AddFusion, CustomAddInferCreator)
|
||||
|
||||
class TestRegistry : public mindspore::CommonTest {
|
||||
public:
|
||||
TestRegistry() = default;
|
||||
};
|
||||
|
||||
TEST_F(TestRegistry, TestAdd) {
|
||||
auto meta_graph = std::make_shared<schema::MetaGraphT>();
|
||||
meta_graph->name = "graph";
|
||||
|
||||
auto node = std::make_unique<schema::CNodeT>();
|
||||
node->inputIndex = {0, 1};
|
||||
node->outputIndex = {2};
|
||||
node->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
node->primitive->value.type = schema::PrimitiveType_AddFusion;
|
||||
auto primitive = new schema::AddFusionT;
|
||||
node->primitive->value.value = primitive;
|
||||
node->name = "Add";
|
||||
meta_graph->nodes.emplace_back(std::move(node));
|
||||
meta_graph->inputIndex = {0, 1};
|
||||
meta_graph->outputIndex = {2};
|
||||
|
||||
auto input0 = std::make_unique<schema::TensorT>();
|
||||
input0->nodeType = lite::NodeType_ValueNode;
|
||||
input0->format = schema::Format_NHWC;
|
||||
input0->dataType = TypeId::kNumberTypeFloat32;
|
||||
input0->dims = {1, 28, 28, 3};
|
||||
input0->offset = -1;
|
||||
meta_graph->allTensors.emplace_back(std::move(input0));
|
||||
|
||||
auto weight = std::make_unique<schema::TensorT>();
|
||||
weight->nodeType = lite::NodeType_ValueNode;
|
||||
weight->format = schema::Format_NHWC;
|
||||
weight->dataType = TypeId::kNumberTypeFloat32;
|
||||
weight->dims = {1, 28, 28, 3};
|
||||
|
||||
weight->offset = -1;
|
||||
meta_graph->allTensors.emplace_back(std::move(weight));
|
||||
|
||||
auto output = std::make_unique<schema::TensorT>();
|
||||
output->nodeType = lite::NodeType_Parameter;
|
||||
output->format = schema::Format_NHWC;
|
||||
output->dataType = TypeId::kNumberTypeFloat32;
|
||||
output->offset = -1;
|
||||
meta_graph->allTensors.emplace_back(std::move(output));
|
||||
|
||||
flatbuffers::FlatBufferBuilder builder(1024);
|
||||
auto offset = schema::MetaGraph::Pack(builder, meta_graph.get());
|
||||
builder.Finish(offset);
|
||||
schema::FinishMetaGraphBuffer(builder, offset);
|
||||
size_t size = builder.GetSize();
|
||||
const char *content = reinterpret_cast<char *>(builder.GetBufferPointer());
|
||||
|
||||
auto model = lite::Model::Import(content, size);
|
||||
ASSERT_NE(nullptr, model);
|
||||
meta_graph.reset();
|
||||
content = nullptr;
|
||||
auto context = new lite::InnerContext;
|
||||
auto &device_list = context->device_list_;
|
||||
std::shared_ptr<DefaultAllocator> allocator = std::make_shared<DefaultAllocator>();
|
||||
lite::DeviceContext device_ctx = {lite::DT_CPU, {false, lite::NO_BIND}, "BuiltInTest", "CPU", allocator};
|
||||
device_list.push_back(device_ctx);
|
||||
context->thread_num_ = 1;
|
||||
ASSERT_EQ(lite::RET_OK, context->Init());
|
||||
auto session = session::LiteSession::CreateSession(context);
|
||||
ASSERT_NE(nullptr, session);
|
||||
auto ret = session->CompileGraph(model);
|
||||
ASSERT_EQ(lite::RET_OK, ret);
|
||||
auto inputs = session->GetInputs();
|
||||
ASSERT_EQ(inputs.size(), 2);
|
||||
auto inTensor = inputs.front();
|
||||
ASSERT_NE(nullptr, inTensor);
|
||||
float *in0_data = static_cast<float *>(inTensor->MutableData());
|
||||
in0_data[0] = 10.0f;
|
||||
auto inTensor1 = inputs.back();
|
||||
ASSERT_NE(nullptr, inTensor1);
|
||||
float *in1_data = static_cast<float *>(inTensor1->MutableData());
|
||||
in1_data[0] = 20.0f;
|
||||
ret = session->RunGraph();
|
||||
ASSERT_EQ(lite::RET_OK, ret);
|
||||
auto outputs = session->GetOutputs();
|
||||
ASSERT_EQ(outputs.size(), 1);
|
||||
auto outTensor = outputs.begin()->second;
|
||||
ASSERT_NE(nullptr, outTensor);
|
||||
ASSERT_EQ(28 * 28 * 3, outTensor->ElementsNum());
|
||||
ASSERT_EQ(TypeId::kNumberTypeFloat32, outTensor->data_type());
|
||||
auto *outData = reinterpret_cast<float *>(outTensor->MutableData());
|
||||
ASSERT_NE(nullptr, outData);
|
||||
ASSERT_EQ(30.0f, outData[0]);
|
||||
MS_LOG(INFO) << "Register add op test pass.";
|
||||
}
|
||||
|
||||
} // namespace mindspore
|
Loading…
Reference in New Issue