forked from mindspore-Ecosystem/mindspore
free meta_graph after compile graph
This commit is contained in:
parent
9d8fb786cb
commit
194253635d
mindspore/lite/src
common
lite_session.cclite_session.hmodel.ccops
abs.ccabs.hactivation.ccactivation.hactivation_grad.ccactivation_grad.hadd.ccadd.haddn.ccaddn.hargmax.ccargmax.hargmin.ccargmin.harithmetic.harithmetic_self.hbatch_norm.ccbatch_norm.hbatch_to_space.ccbatch_to_space.hbias_add.ccbias_add.hbias_grad.ccbias_grad.hbn_grad_input.ccbn_grad_input.hbroadcast_to.ccbroadcast_to.hcast.cccast.hceil.hclip.ccclip.hconcat.ccconcat.hconstant_of_shape.ccconstant_of_shape.hconv2d.ccconv2d.hconv2d_grad_filter.ccconv2d_grad_filter.hconv2d_grad_input.ccconv2d_grad_input.hcos.cccos.hcrop.cccrop.hdeconv2d.ccdeconv2d.hdedepthwise_conv2d.ccdedepthwise_conv2d.hdepth_to_space.ccdepth_to_space.hdepthwise_conv2d.ccdepthwise_conv2d.hdequant.hdetection_post_process.ccdetection_post_process.hdiv.ccdiv.hdropout.ccdropout.heltwise.cceltwise.helu.ccelu.hembedding_lookup.ccembedding_lookup.hembedding_lookup_sparse.ccembedding_lookup_sparse.hequal.ccequal.hexp.ccexp.hexpand_dims.ccexpand_dims.hfake_quant_with_min_max_vars.ccfake_quant_with_min_max_vars.hfill.ccfill.hflatten.ccflatten.hfloor.ccfloor.hfloor_div.ccfloor_div.hfloor_mod.ccfloor_mod.hfull_connection.ccfull_connection.hfused_batchnorm.ccfused_batchnorm.hgather.ccgather.hgather_nd.cc
|
@ -61,5 +61,22 @@ std::vector<size_t> GetGraphOutputNodes(const schema::MetaGraph *meta_graph) {
|
|||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
std::vector<size_t> GetLinkedPostNodeIdx(const schema::MetaGraph &graph, const size_t &tensor_idx) {
|
||||
std::vector<size_t> post_node_idxes;
|
||||
for (size_t i = 0; i < graph.nodes()->size(); i++) {
|
||||
auto node = graph.nodes()->GetAs<schema::CNode>(i);
|
||||
if (node == nullptr) {
|
||||
continue;
|
||||
}
|
||||
auto node_input_idxes = node->inputIndex();
|
||||
auto is_contain = std::any_of(node_input_idxes->begin(), node_input_idxes->end(),
|
||||
[&](const uint32_t &node_input_idx) { return node_input_idx == tensor_idx; });
|
||||
if (is_contain) {
|
||||
post_node_idxes.emplace_back(i);
|
||||
}
|
||||
}
|
||||
return post_node_idxes;
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -34,215 +34,8 @@ std::vector<size_t> GetGraphInputNodes(const schema::MetaGraph *meta_graph);
|
|||
|
||||
std::vector<size_t> GetGraphOutputNodes(const schema::MetaGraph *meta_graph);
|
||||
|
||||
class OpNode {
|
||||
public:
|
||||
explicit OpNode(const NODE_ID &nodeId) : id(nodeId) {}
|
||||
NODE_ID ID() { return id; };
|
||||
void AddInEdge(NODE_ID nodeId) { inEdges.insert(nodeId); }
|
||||
void AddOutEdge(NODE_ID nodeId) { outEdges.insert(nodeId); }
|
||||
std::unordered_set<NODE_ID> GetAllInEdges() { return inEdges; }
|
||||
std::unordered_set<NODE_ID> GetAllOutEdges() { return outEdges; }
|
||||
|
||||
protected:
|
||||
NODE_ID id;
|
||||
std::unordered_set<NODE_ID> inEdges;
|
||||
std::unordered_set<NODE_ID> outEdges;
|
||||
};
|
||||
|
||||
|
||||
template <typename NODE_T>
|
||||
class OpGraph {
|
||||
public:
|
||||
OpGraph() {}
|
||||
|
||||
~OpGraph();
|
||||
|
||||
int Build(const schema::MetaGraph *subGraphDef);
|
||||
NODE_T *GetNode(NODE_ID nodeId);
|
||||
NODE_T *AddNode(NODE_ID nodeId);
|
||||
std::unordered_set<NODE_T *> GetInputNode();
|
||||
std::unordered_set<NODE_T *> GetOutputNode();
|
||||
|
||||
void AddNodes(std::vector<NODE_T *> addNodes);
|
||||
void DeleteNodes(std::vector<NODE_T *> deleteNodes);
|
||||
|
||||
void AddEdge(NODE_ID nodeId);
|
||||
int AddEdge(NODE_ID srcId, NODE_ID dstId);
|
||||
int AddEdge(const schema::CNode *srcNodeDef, const flatbuffers::Vector<flatbuffers::Offset<schema::CNode>> *opDefs);
|
||||
std::unordered_map<NODE_T *, std::unordered_set<NODE_T *>> GetDepends();
|
||||
|
||||
protected:
|
||||
std::unordered_map<NODE_ID, NODE_T *> nodes;
|
||||
};
|
||||
|
||||
template <typename NODE_T>
|
||||
int OpGraph<NODE_T>::Build(const schema::MetaGraph *subGraphDef) {
|
||||
if (subGraphDef == nullptr) {
|
||||
// MS_LOGE("subGraphDef is nullptr");
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
auto opDefs = subGraphDef->nodes();
|
||||
|
||||
uint32_t opCount = opDefs->size();
|
||||
for (uint32_t i = 0; i < opCount; i++) {
|
||||
auto opDef = opDefs->GetAs<schema::CNode>(i);
|
||||
auto node = AddNode(std::string(opDef->name()->c_str()));
|
||||
if (node == nullptr) {
|
||||
// MS_LOGE("add srcNode failed,name %s", opDef->name()->c_str());
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto ret = AddEdge(opDef, opDefs);
|
||||
if (ret != RET_OK) {
|
||||
// MS_LOGE("%s add edge failed. ret:%d", opDef->name()->c_str(), ret);
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
template <typename NODE_T>
|
||||
int OpGraph<NODE_T>::AddEdge(const schema::CNode *srcNodeDef,
|
||||
const flatbuffers::Vector<flatbuffers::Offset<schema::CNode>> *nodeDefs) {
|
||||
MS_ASSERT(srcNodeDef != nullptr);
|
||||
MS_ASSERT(nodeDefs != nullptr);
|
||||
NODE_ID srcId = std::string(srcNodeDef->name()->c_str());
|
||||
uint32_t opCount = nodeDefs->size();
|
||||
// for single op condition
|
||||
AddNode(srcId);
|
||||
for (auto index : *(srcNodeDef->outputIndex())) {
|
||||
for (uint32_t i = 0; i < opCount; i++) {
|
||||
auto dstNodeDef = nodeDefs->GetAs<schema::CNode>(i);
|
||||
bool find = false;
|
||||
auto inputIndex = dstNodeDef->inputIndex();
|
||||
if (std::any_of(inputIndex->begin(), inputIndex->end(), [&index](int i) { return i == index; })) {
|
||||
find = true;
|
||||
}
|
||||
|
||||
if (!find) {
|
||||
continue;
|
||||
}
|
||||
NODE_ID dstId = std::string(dstNodeDef->name()->c_str());
|
||||
auto ret = AddEdge(srcId, dstId);
|
||||
if (ret != RET_OK) {
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
template <typename NODE_T>
|
||||
int OpGraph<NODE_T>::AddEdge(NODE_ID srcId, NODE_ID dstId) {
|
||||
auto srcNode = AddNode(srcId);
|
||||
if (srcNode == nullptr) {
|
||||
// MS_LOGE("add srcNode failed");
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto dstNode = AddNode(dstId);
|
||||
if (dstNode == nullptr) {
|
||||
// MS_LOGE("add dstNode failed");
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
srcNode->AddOutEdge(dstNode);
|
||||
|
||||
dstNode->AddInEdge(srcNode);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
template <typename NODE_T>
|
||||
NODE_T *OpGraph<NODE_T>::GetNode(NODE_ID nodeId) {
|
||||
auto node = nodes.find(nodeId);
|
||||
if (node == nodes.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
return node->second;
|
||||
}
|
||||
|
||||
template <typename NODE_T>
|
||||
NODE_T *OpGraph<NODE_T>::AddNode(NODE_ID nodeId) {
|
||||
auto node = GetNode(nodeId);
|
||||
if (node != nullptr) {
|
||||
return node;
|
||||
}
|
||||
node = new (std::nothrow) NODE_T(nodeId);
|
||||
if (node == nullptr) {
|
||||
// MS_LOGE("new node failed");
|
||||
return nullptr;
|
||||
}
|
||||
nodes[nodeId] = node;
|
||||
return node;
|
||||
}
|
||||
|
||||
template <typename NODE_T>
|
||||
void OpGraph<NODE_T>::AddNodes(std::vector<NODE_T *> addNodes) {
|
||||
for (auto node : addNodes) {
|
||||
if (node == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
nodes[node->ID()] = node;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename NODE_T>
|
||||
void OpGraph<NODE_T>::DeleteNodes(std::vector<NODE_T *> deleteNodes) {
|
||||
for (auto deletenode : deleteNodes) {
|
||||
if (deletenode == nullptr) {
|
||||
continue;
|
||||
}
|
||||
auto node = GetNode(deletenode->ID());
|
||||
if (node == nullptr) {
|
||||
continue;
|
||||
}
|
||||
nodes.erase(deletenode->ID());
|
||||
}
|
||||
}
|
||||
|
||||
template <typename NODE_T>
|
||||
std::unordered_set<NODE_T *> OpGraph<NODE_T>::GetInputNode() {
|
||||
std::unordered_set<NODE_T *> inputNodes;
|
||||
for (const auto &iter : nodes) {
|
||||
auto node = iter.second;
|
||||
if (node->GetAllInEdges().empty()) {
|
||||
inputNodes.insert(node);
|
||||
}
|
||||
}
|
||||
return inputNodes;
|
||||
}
|
||||
|
||||
template <typename NODE_T>
|
||||
std::unordered_set<NODE_T *> OpGraph<NODE_T>::GetOutputNode() {
|
||||
std::unordered_set<NODE_T *> outputNodes;
|
||||
for (const auto &iter : nodes) {
|
||||
auto node = iter.second;
|
||||
if (node->GetAllOutEdges().empty()) {
|
||||
outputNodes.insert(node);
|
||||
}
|
||||
}
|
||||
return outputNodes;
|
||||
}
|
||||
|
||||
template <typename NODE_T>
|
||||
std::unordered_map<NODE_T *, std::unordered_set<NODE_T *>> OpGraph<NODE_T>::GetDepends() {
|
||||
std::unordered_map<NODE_T *, std::unordered_set<NODE_T *>> depends;
|
||||
for (auto nodeIter : nodes) {
|
||||
depends[nodeIter.second] = nodeIter.second->GetAllInEdges();
|
||||
}
|
||||
return depends;
|
||||
}
|
||||
|
||||
template <typename NODE_T>
|
||||
OpGraph<NODE_T>::~OpGraph() {
|
||||
for (auto iter : nodes) {
|
||||
delete iter.second;
|
||||
}
|
||||
nodes.clear();
|
||||
}
|
||||
std::vector<size_t> GetLinkedPostNodeIdx(const schema::MetaGraph &graph, const size_t &tensor_idx);
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_COMMON_GRAPH_UTIL_H_
|
||||
|
||||
|
|
|
@ -32,10 +32,29 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
static std::vector<schema::PrimitiveType> packed_op = {
|
||||
schema::PrimitiveType_Conv2D, schema::PrimitiveType_DeConv2D,
|
||||
schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_DeDepthwiseConv2D,
|
||||
schema::PrimitiveType_MatMul};
|
||||
|
||||
// this method will not check whether tensor_idx is a weight tensor index, caller should ensure this.
|
||||
static bool WeightTensorNeedCopy(const lite::Model *model, const uint32_t tensor_idx) {
|
||||
MS_ASSERT(nullptr != model);
|
||||
auto meta_graph = model->GetMetaGraph();
|
||||
MS_ASSERT(nullptr != meta_graph);
|
||||
auto post_node_idxes = GetLinkedPostNodeIdx(*meta_graph, tensor_idx);
|
||||
return std::none_of(post_node_idxes.begin(), post_node_idxes.end(), [&](const size_t &post_node_idx) {
|
||||
auto cNode = meta_graph->nodes()->GetAs<schema::CNode>(post_node_idx);
|
||||
MS_ASSERT(cNode != nullptr);
|
||||
return IsContain(packed_op, cNode->primitive()->value_type());
|
||||
});
|
||||
}
|
||||
|
||||
int LiteSession::ConvertTensors(const lite::Model *model) {
|
||||
MS_ASSERT(nullptr != model);
|
||||
auto meta_graph = model->GetMetaGraph();
|
||||
MS_ASSERT(nullptr != meta_graph);
|
||||
copyed_tensor_idxes_.clear();
|
||||
uint32_t tensorCount = meta_graph->allTensors()->size();
|
||||
for (uint32_t i = 0; i < tensorCount; i++) {
|
||||
auto *srcTensor = meta_graph->allTensors()->GetAs<schema::Tensor>(i);
|
||||
|
@ -54,16 +73,30 @@ int LiteSession::ConvertTensors(const lite::Model *model) {
|
|||
}
|
||||
}
|
||||
int dataType = srcTensor->dataType();
|
||||
auto *dstTensor = new tensor::Tensor(TypeId(dataType), shape, srcTensor->format(), srcTensor->nodeType());
|
||||
auto *dstTensor =
|
||||
new (std::nothrow) tensor::Tensor(TypeId(dataType), shape, srcTensor->format(), srcTensor->nodeType());
|
||||
if (dstTensor == nullptr) {
|
||||
MS_LOG(ERROR) << "new " << i << "th tensor failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
if (srcTensor->nodeType() == schema::NodeType_ValueNode && srcTensor->data() != nullptr &&
|
||||
srcTensor->data()->size() > 0) {
|
||||
if (shape.empty()) {
|
||||
shape.push_back(1);
|
||||
dstTensor->set_shape(shape);
|
||||
}
|
||||
MS_ASSERT(dstTensor != nullptr);
|
||||
MS_ASSERT(dstTensor->Size() == srcTensor->data()->size());
|
||||
// no copy data, do copy when call LiteKernel::Init
|
||||
dstTensor->SetData(const_cast<unsigned char *>(srcTensor->data()->data()));
|
||||
if (WeightTensorNeedCopy(model, i)) {
|
||||
auto ret = dstTensor->MallocData();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Malloc data for " << i << "th tensor failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
memcpy(dstTensor->Data(), srcTensor->data()->data(), dstTensor->Size());
|
||||
copyed_tensor_idxes_.emplace_back(i);
|
||||
} else {
|
||||
dstTensor->SetData(const_cast<unsigned char *>(srcTensor->data()->data()));
|
||||
}
|
||||
}
|
||||
auto quant_params = srcTensor->quantParams();
|
||||
if (quant_params != nullptr) {
|
||||
|
@ -74,7 +107,6 @@ int LiteSession::ConvertTensors(const lite::Model *model) {
|
|||
dstTensor->AddQuantParam(quant_arg);
|
||||
}
|
||||
}
|
||||
|
||||
this->tensors_.emplace_back(dstTensor);
|
||||
}
|
||||
|
||||
|
@ -240,6 +272,7 @@ int LiteSession::CompileGraph(Model *model) {
|
|||
}
|
||||
|
||||
executor->Prepare(this->kernels_);
|
||||
model->FreeMetaGraph();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
@ -277,7 +310,10 @@ int LiteSession::Init(Context *context) {
|
|||
}
|
||||
#endif
|
||||
executor = new Executor();
|
||||
MS_ASSERT(nullptr != executor);
|
||||
if (nullptr == executor) {
|
||||
MS_LOG(ERROR) << "new Executor failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
@ -288,9 +324,12 @@ void LiteSession::BindThread(bool if_bind) {
|
|||
}
|
||||
|
||||
LiteSession::~LiteSession() {
|
||||
for (auto *tensor : tensors_) {
|
||||
// weight data can not be to free, we will free weight data when freeing meta_graph
|
||||
if (tensor->TensorType() == schema::NodeType_ValueNode && !IsContain(this->inputs_, tensor)) {
|
||||
for (size_t i = 0; i < tensors_.size(); i++) {
|
||||
auto *tensor = tensors_.at(i);
|
||||
MS_ASSERT(tensor != nullptr);
|
||||
// data of weight tensor of node in packed_op can not be to free, we will free weight data when freeing meta_graph
|
||||
if (tensor->TensorType() == schema::NodeType_ValueNode && !IsContain(this->inputs_, tensor) &&
|
||||
!IsContain(copyed_tensor_idxes_, i)) {
|
||||
tensor->SetData(nullptr);
|
||||
}
|
||||
delete tensor;
|
||||
|
|
|
@ -87,6 +87,7 @@ class LiteSession : public session::LiteSession {
|
|||
Context *context_ = nullptr;
|
||||
std::vector<kernel::LiteKernel *> kernels_;
|
||||
std::vector<tensor::Tensor *> tensors_;
|
||||
std::vector<size_t> copyed_tensor_idxes_;
|
||||
// graph input tensors
|
||||
std::vector<tensor::Tensor *> inputs_;
|
||||
// graph output tensors
|
||||
|
|
|
@ -135,7 +135,7 @@ mindspore::lite::PrimitiveC *Model::GetOp(const std::string &name) const {
|
|||
|
||||
void Model::FreeMetaGraph() {
|
||||
MS_ASSERT(nullptr != model_impl_);
|
||||
return model_impl_->FreeMetaGraph();
|
||||
model_impl_->FreeMetaGraph();
|
||||
}
|
||||
|
||||
const schema::MetaGraph *Model::GetMetaGraph() const {
|
||||
|
|
|
@ -0,0 +1,32 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/ops/abs.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifndef PRIMITIVE_WRITEABLE
|
||||
int Abs::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto val_offset = schema::CreateAbs(*fbb);
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Abs, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
#endif
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -32,27 +32,9 @@ class Abs : public ArithmeticSelf {
|
|||
Abs() = default;
|
||||
explicit Abs(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {}
|
||||
#else
|
||||
explicit Abs(schema::Primitive *primitive) : ArithmeticSelf(primitive) {}
|
||||
Abs() = default;
|
||||
|
||||
schema::Primitive *Init(schema::Primitive *primitive) {
|
||||
flatbuffers::FlatBufferBuilder fbb(1024);
|
||||
|
||||
auto val_offset = schema::CreateAbs(fbb);
|
||||
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Abs, val_offset.o);
|
||||
fbb.Finish(prim_offset);
|
||||
|
||||
auto buf = fbb.GetBufferPointer();
|
||||
MS_ASSERT(buf != nullptr);
|
||||
auto buf_bak = new char[fbb.GetSize()];
|
||||
memcpy(buf_bak, buf, fbb.GetSize());
|
||||
|
||||
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
|
||||
auto prim = const_cast<schema::Primitive *>(root);
|
||||
|
||||
delete[] buf_bak;
|
||||
fbb.Clear();
|
||||
return prim;
|
||||
}
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
};
|
||||
} // namespace lite
|
||||
|
|
|
@ -55,7 +55,19 @@ int Activation::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr>
|
|||
return RET_OK;
|
||||
}
|
||||
#else
|
||||
|
||||
int Activation::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto attr = primitive->value_as_Activation();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "value_as_Activation return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto val_offset = schema::CreateActivation(*fbb, attr->type(), attr->alpha());
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Activation, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
int Activation::GetType() const { return this->primitive_->value_as_Activation()->type(); }
|
||||
float Activation::GetAlpha() const { return this->primitive_->value_as_Activation()->alpha(); }
|
||||
#endif
|
||||
|
|
|
@ -30,34 +30,13 @@ class Activation : public PrimitiveC {
|
|||
MS_DECLARE_PARENT(Activation, PrimitiveC);
|
||||
Activation() = default;
|
||||
explicit Activation(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
|
||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
|
||||
void SetType(int type);
|
||||
void SetAlpha(float alpha);
|
||||
#else
|
||||
explicit Activation(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
Activation() = default;
|
||||
|
||||
schema::Primitive *Init(schema::Primitive *primitive) {
|
||||
flatbuffers::FlatBufferBuilder fbb(1024);
|
||||
|
||||
auto attr = primitive->value_as_Activation();
|
||||
MS_ASSERT(attr != nullptr);
|
||||
|
||||
auto val_offset = schema::CreateActivation(fbb, attr->type(), attr->alpha());
|
||||
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Activation, val_offset.o);
|
||||
fbb.Finish(prim_offset);
|
||||
|
||||
auto buf = fbb.GetBufferPointer();
|
||||
MS_ASSERT(buf != nullptr);
|
||||
auto buf_bak = new char[fbb.GetSize()];
|
||||
memcpy(buf_bak, buf, fbb.GetSize());
|
||||
|
||||
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
|
||||
auto prim = const_cast<schema::Primitive *>(root);
|
||||
|
||||
delete[] buf_bak;
|
||||
fbb.Clear();
|
||||
return prim;
|
||||
}
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
int GetType() const;
|
||||
float GetAlpha() const;
|
||||
|
|
|
@ -26,7 +26,19 @@ void ActivationGrad::SetType(int type) {
|
|||
}
|
||||
|
||||
#else
|
||||
|
||||
int ActivationGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto attr = primitive->value_as_ActivationGrad();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "value_as_ActivationGrad return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto val_offset = schema::CreateActivationGrad(*fbb, attr->type());
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_ActivationGrad, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
int ActivationGrad::GetType() const { return this->primitive_->value_as_ActivationGrad()->type(); }
|
||||
|
||||
#endif
|
||||
|
|
|
@ -33,30 +33,9 @@ class ActivationGrad : public PrimitiveC {
|
|||
explicit ActivationGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
void SetType(int type);
|
||||
#else
|
||||
explicit ActivationGrad(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
ActivationGrad() = default;
|
||||
|
||||
schema::Primitive *Init(schema::Primitive *primitive) {
|
||||
flatbuffers::FlatBufferBuilder fbb(1024);
|
||||
|
||||
auto attr = primitive->value_as_ActivationGrad();
|
||||
MS_ASSERT(attr != nullptr);
|
||||
|
||||
auto val_offset = schema::CreateActivationGrad(fbb, attr->type());
|
||||
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_ActivationGrad, val_offset.o);
|
||||
fbb.Finish(prim_offset);
|
||||
|
||||
auto buf = fbb.GetBufferPointer();
|
||||
MS_ASSERT(buf != nullptr);
|
||||
auto buf_bak = new char[fbb.GetSize()];
|
||||
memcpy(buf_bak, buf, fbb.GetSize());
|
||||
|
||||
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
|
||||
auto prim = const_cast<schema::Primitive *>(root);
|
||||
|
||||
delete[] buf_bak;
|
||||
fbb.Clear();
|
||||
return prim;
|
||||
}
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
int GetType() const;
|
||||
};
|
||||
|
|
|
@ -50,7 +50,19 @@ int Add::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs
|
|||
}
|
||||
|
||||
#else
|
||||
|
||||
int Add::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto attr = primitive->value_as_Add();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "value_as_Add return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto val_offset = schema::CreateAdd(*fbb, attr->activationType());
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Add, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
int Add::GetActivationType() const { return this->primitive_->value_as_Add()->activationType(); }
|
||||
|
||||
#endif
|
||||
|
|
|
@ -31,33 +31,12 @@ class Add : public Arithmetic {
|
|||
MS_DECLARE_PARENT(Add, Arithmetic);
|
||||
Add() = default;
|
||||
explicit Add(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
|
||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
|
||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
|
||||
void SetActivationType(int activation_type);
|
||||
#else
|
||||
explicit Add(schema::Primitive *primitive) : Arithmetic(primitive) {}
|
||||
Add() = default;
|
||||
|
||||
schema::Primitive *Init(schema::Primitive *primitive) {
|
||||
flatbuffers::FlatBufferBuilder fbb(1024);
|
||||
|
||||
auto attr = primitive->value_as_Add();
|
||||
MS_ASSERT(attr != nullptr);
|
||||
|
||||
auto val_offset = schema::CreateAdd(fbb, attr->activationType());
|
||||
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Add, val_offset.o);
|
||||
fbb.Finish(prim_offset);
|
||||
|
||||
auto buf = fbb.GetBufferPointer();
|
||||
MS_ASSERT(buf != nullptr);
|
||||
auto buf_bak = new char[fbb.GetSize()];
|
||||
memcpy(buf_bak, buf, fbb.GetSize());
|
||||
|
||||
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
|
||||
auto prim = const_cast<schema::Primitive *>(root);
|
||||
|
||||
delete[] buf_bak;
|
||||
fbb.Clear();
|
||||
return prim;
|
||||
}
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
int GetActivationType() const;
|
||||
};
|
||||
|
|
|
@ -24,7 +24,19 @@ int AddN::GetN() const { return this->primitive_->value.AsAddN()->N; }
|
|||
void AddN::SetN(int n) { this->primitive_->value.AsAddN()->N = n; }
|
||||
|
||||
#else
|
||||
|
||||
int AddN::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto attr = primitive->value_as_AddN();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "value_as_AddN return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto val_offset = schema::CreateAddN(*fbb, attr->N());
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_AddN, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
int AddN::GetN() const { return this->primitive_->value_as_AddN()->N(); }
|
||||
|
||||
#endif
|
||||
|
|
|
@ -33,30 +33,9 @@ class AddN : public PrimitiveC {
|
|||
explicit AddN(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
void SetN(int n);
|
||||
#else
|
||||
explicit AddN(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
AddN() = default;
|
||||
|
||||
schema::Primitive *Init(schema::Primitive *primitive) {
|
||||
flatbuffers::FlatBufferBuilder fbb(1024);
|
||||
|
||||
auto attr = primitive->value_as_AddN();
|
||||
MS_ASSERT(attr != nullptr);
|
||||
|
||||
auto val_offset = schema::CreateAddN(fbb, attr->N());
|
||||
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_AddN, val_offset.o);
|
||||
fbb.Finish(prim_offset);
|
||||
|
||||
auto buf = fbb.GetBufferPointer();
|
||||
MS_ASSERT(buf != nullptr);
|
||||
auto buf_bak = new char[fbb.GetSize()];
|
||||
memcpy(buf_bak, buf, fbb.GetSize());
|
||||
|
||||
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
|
||||
auto prim = const_cast<schema::Primitive *>(root);
|
||||
|
||||
delete[] buf_bak;
|
||||
fbb.Clear();
|
||||
return prim;
|
||||
}
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
int GetN() const;
|
||||
|
|
|
@ -32,7 +32,20 @@ void ArgMax::SetKeepDims(bool keep_dims) { this->primitive_->value.AsArgMax()->k
|
|||
void ArgMax::SetAxisType(int axis_type) { this->primitive_->value.AsArgMax()->axisType = axis_type; }
|
||||
|
||||
#else
|
||||
|
||||
int ArgMax::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto attr = primitive->value_as_ArgMax();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "value_as_ArgMax return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto val_offset =
|
||||
schema::CreateArgMax(*fbb, attr->axis(), attr->outMaxValue(), attr->topK(), attr->keepDims(), attr->axisType());
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_ArgMax, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
int ArgMax::GetAxis() const { return this->primitive_->value_as_ArgMax()->axis(); }
|
||||
bool ArgMax::GetOutMaxValue() const { return this->primitive_->value_as_ArgMax()->outMaxValue(); }
|
||||
int ArgMax::GetTopK() const { return this->primitive_->value_as_ArgMax()->topK(); }
|
||||
|
|
|
@ -37,31 +37,9 @@ class ArgMax : public PrimitiveC {
|
|||
void SetKeepDims(bool keep_dims);
|
||||
void SetAxisType(int axis_type);
|
||||
#else
|
||||
explicit ArgMax(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
ArgMax() = default;
|
||||
|
||||
schema::Primitive *Init(schema::Primitive *primitive) {
|
||||
flatbuffers::FlatBufferBuilder fbb(1024);
|
||||
|
||||
auto attr = primitive->value_as_ArgMax();
|
||||
MS_ASSERT(attr != nullptr);
|
||||
|
||||
auto val_offset = schema::CreateArgMax(fbb, attr->axis(), attr->outMaxValue(),
|
||||
attr->topK(), attr->keepDims(), attr->axisType());
|
||||
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_ArgMax, val_offset.o);
|
||||
fbb.Finish(prim_offset);
|
||||
|
||||
auto buf = fbb.GetBufferPointer();
|
||||
MS_ASSERT(buf != nullptr);
|
||||
auto buf_bak = new char[fbb.GetSize()];
|
||||
memcpy(buf_bak, buf, fbb.GetSize());
|
||||
|
||||
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
|
||||
auto prim = const_cast<schema::Primitive *>(root);
|
||||
|
||||
delete[] buf_bak;
|
||||
fbb.Clear();
|
||||
return prim;
|
||||
}
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
int GetAxis() const;
|
||||
|
|
|
@ -32,7 +32,20 @@ void ArgMin::SetKeepDims(bool keep_dims) { this->primitive_->value.AsArgMin()->k
|
|||
void ArgMin::SetAxisType(int axis_type) { this->primitive_->value.AsArgMin()->axisType = axis_type; }
|
||||
|
||||
#else
|
||||
|
||||
int ArgMin::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto attr = primitive->value_as_ArgMin();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "value_as_ArgMin return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto val_offset =
|
||||
schema::CreateArgMin(*fbb, attr->axis(), attr->outMaxValue(), attr->topK(), attr->keepDims(), attr->axisType());
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_ArgMin, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
int ArgMin::GetAxis() const { return this->primitive_->value_as_ArgMin()->axis(); }
|
||||
bool ArgMin::GetOutMaxValue() const { return this->primitive_->value_as_ArgMin()->outMaxValue(); }
|
||||
int ArgMin::GetTopK() const { return this->primitive_->value_as_ArgMin()->topK(); }
|
||||
|
|
|
@ -37,31 +37,9 @@ class ArgMin : public PrimitiveC {
|
|||
void SetKeepDims(bool keep_dims);
|
||||
void SetAxisType(int axis_type);
|
||||
#else
|
||||
explicit ArgMin(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
ArgMin() = default;
|
||||
|
||||
schema::Primitive *Init(schema::Primitive *primitive) {
|
||||
flatbuffers::FlatBufferBuilder fbb(1024);
|
||||
|
||||
auto attr = primitive->value_as_ArgMin();
|
||||
MS_ASSERT(attr != nullptr);
|
||||
|
||||
auto val_offset = schema::CreateArgMin(fbb, attr->axis(), attr->outMaxValue(),
|
||||
attr->topK(), attr->keepDims(), attr->axisType());
|
||||
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_ArgMin, val_offset.o);
|
||||
fbb.Finish(prim_offset);
|
||||
|
||||
auto buf = fbb.GetBufferPointer();
|
||||
MS_ASSERT(buf != nullptr);
|
||||
auto buf_bak = new char[fbb.GetSize()];
|
||||
memcpy(buf_bak, buf, fbb.GetSize());
|
||||
|
||||
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
|
||||
auto prim = const_cast<schema::Primitive *>(root);
|
||||
|
||||
delete[] buf_bak;
|
||||
fbb.Clear();
|
||||
return prim;
|
||||
}
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
int GetAxis() const;
|
||||
|
|
|
@ -32,7 +32,11 @@ class Arithmetic : public PrimitiveC {
|
|||
Arithmetic() = default;
|
||||
explicit Arithmetic(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#else
|
||||
explicit Arithmetic(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
// explicit Arithmetic(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
Arithmetic() = default;
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override {
|
||||
return RET_ERROR;
|
||||
}
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
bool Broadcasting() { return this->broadcasting_; }
|
||||
|
|
|
@ -29,7 +29,11 @@ class ArithmeticSelf : public PrimitiveC {
|
|||
ArithmeticSelf() = default;
|
||||
explicit ArithmeticSelf(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
#else
|
||||
explicit ArithmeticSelf(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
// explicit ArithmeticSelf(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
ArithmeticSelf() = default;
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override {
|
||||
return RET_ERROR;
|
||||
}
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
};
|
||||
|
|
|
@ -49,7 +49,14 @@ int BatchNorm::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &
|
|||
}
|
||||
|
||||
#else
|
||||
|
||||
int BatchNorm::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto val_offset = schema::CreateBatchNorm(*fbb);
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_BatchNorm, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
float BatchNorm::GetEpsilon() const { return this->primitive_->value_as_BatchNorm()->epsilon(); }
|
||||
|
||||
#endif
|
||||
|
|
|
@ -31,30 +31,12 @@ class BatchNorm : public PrimitiveC {
|
|||
MS_DECLARE_PARENT(BatchNorm, PrimitiveC);
|
||||
BatchNorm() = default;
|
||||
explicit BatchNorm(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
|
||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
|
||||
void SetEpsilon(float epsilon);
|
||||
#else
|
||||
explicit BatchNorm(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
BatchNorm() = default;
|
||||
|
||||
schema::Primitive *Init(schema::Primitive *primitive) {
|
||||
flatbuffers::FlatBufferBuilder fbb(1024);
|
||||
|
||||
auto val_offset = schema::CreateBatchNorm(fbb);
|
||||
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_BatchNorm, val_offset.o);
|
||||
fbb.Finish(prim_offset);
|
||||
|
||||
auto buf = fbb.GetBufferPointer();
|
||||
MS_ASSERT(buf != nullptr);
|
||||
auto buf_bak = new char[fbb.GetSize()];
|
||||
memcpy(buf_bak, buf, fbb.GetSize());
|
||||
|
||||
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
|
||||
auto prim = const_cast<schema::Primitive *>(root);
|
||||
|
||||
delete[] buf_bak;
|
||||
fbb.Clear();
|
||||
return prim;
|
||||
}
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
float GetEpsilon() const;
|
||||
};
|
||||
|
|
|
@ -32,7 +32,31 @@ void BatchToSpace::SetBlockShape(const std::vector<int> &block_shape) {
|
|||
void BatchToSpace::SetCrops(const std::vector<int> &crops) { this->primitive_->value.AsBatchToSpace()->crops = crops; }
|
||||
|
||||
#else
|
||||
|
||||
int BatchToSpace::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto attr = primitive->value_as_BatchToSpace();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "value_as_BatchToSpace return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
std::vector<int32_t> blockShape;
|
||||
if (attr->blockShape() != nullptr) {
|
||||
for (int i = 0; i < static_cast<int>(attr->blockShape()->size()); i++) {
|
||||
blockShape.push_back(attr->blockShape()->data()[i]);
|
||||
}
|
||||
}
|
||||
std::vector<int32_t> crops;
|
||||
if (attr->crops() != nullptr) {
|
||||
for (int i = 0; i < static_cast<int>(attr->crops()->size()); i++) {
|
||||
crops.push_back(attr->crops()->data()[i]);
|
||||
}
|
||||
}
|
||||
auto val_offset = schema::CreateBatchToSpaceDirect(*fbb, &blockShape, &crops);
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_BatchToSpace, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
std::vector<int> BatchToSpace::GetBlockShape() const {
|
||||
auto fb_vector = this->primitive_->value_as_BatchToSpace()->blockShape();
|
||||
return std::vector<int>(fb_vector->begin(), fb_vector->end());
|
||||
|
|
|
@ -35,39 +35,9 @@ class BatchToSpace : public PrimitiveC {
|
|||
void SetBlockShape(const std::vector<int> &block_shape);
|
||||
void SetCrops(const std::vector<int> &crops);
|
||||
#else
|
||||
explicit BatchToSpace(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
BatchToSpace() = default;
|
||||
|
||||
schema::Primitive *Init(schema::Primitive *primitive) {
|
||||
flatbuffers::FlatBufferBuilder fbb(1024);
|
||||
|
||||
auto attr = primitive->value_as_BatchToSpace();
|
||||
MS_ASSERT(attr != nullptr);
|
||||
|
||||
auto blockShape = std::make_unique<std::vector<int32_t>>();
|
||||
for (int i = 0; i < static_cast<int>(attr->blockShape()->size()); i++) {
|
||||
blockShape->push_back(attr->blockShape()->data()[i]);
|
||||
}
|
||||
auto crops = std::make_unique<std::vector<int32_t>>();
|
||||
for (int i = 0; i < static_cast<int>(attr->crops()->size()); i++) {
|
||||
crops->push_back(attr->crops()->data()[i]);
|
||||
}
|
||||
|
||||
auto val_offset = schema::CreateBatchToSpaceDirect(fbb, blockShape.release(), crops.release());
|
||||
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_BatchToSpace, val_offset.o);
|
||||
fbb.Finish(prim_offset);
|
||||
|
||||
auto buf = fbb.GetBufferPointer();
|
||||
MS_ASSERT(buf != nullptr);
|
||||
auto buf_bak = new char[fbb.GetSize()];
|
||||
memcpy(buf_bak, buf, fbb.GetSize());
|
||||
|
||||
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
|
||||
auto prim = const_cast<schema::Primitive *>(root);
|
||||
|
||||
delete[] buf_bak;
|
||||
fbb.Clear();
|
||||
return prim;
|
||||
}
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
std::vector<int> GetBlockShape() const;
|
||||
|
|
|
@ -54,7 +54,25 @@ int BiasAdd::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &in
|
|||
}
|
||||
|
||||
#else
|
||||
|
||||
int BiasAdd::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto attr = primitive->value_as_BiasAdd();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "value_as_BiasAdd return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
std::vector<int32_t> axis;
|
||||
if (attr->axis() != nullptr) {
|
||||
for (int i = 0; i < static_cast<int>(attr->axis()->size()); i++) {
|
||||
axis.push_back(attr->axis()->data()[i]);
|
||||
}
|
||||
}
|
||||
auto val_offset = schema::CreateBiasAddDirect(*fbb, &axis);
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_BiasAdd, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
std::vector<int> BiasAdd::GetAxis() const {
|
||||
auto fb_vector = this->primitive_->value_as_BiasAdd()->axis();
|
||||
return std::vector<int>(fb_vector->begin(), fb_vector->end());
|
||||
|
|
|
@ -32,38 +32,12 @@ class BiasAdd : public PrimitiveC {
|
|||
MS_DECLARE_PARENT(BiasAdd, PrimitiveC);
|
||||
BiasAdd() = default;
|
||||
explicit BiasAdd(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
|
||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
|
||||
void SetAxis(const std::vector<int> &axis);
|
||||
#else
|
||||
explicit BiasAdd(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
BiasAdd() = default;
|
||||
|
||||
schema::Primitive *Init(schema::Primitive *primitive) {
|
||||
flatbuffers::FlatBufferBuilder fbb(1024);
|
||||
|
||||
auto attr = primitive->value_as_BiasAdd();
|
||||
MS_ASSERT(attr != nullptr);
|
||||
|
||||
auto axis = std::make_unique<std::vector<int32_t>>();
|
||||
for (int i = 0; i < static_cast<int>(attr->axis()->size()); i++) {
|
||||
axis->push_back(attr->axis()->data()[i]);
|
||||
}
|
||||
|
||||
auto val_offset = schema::CreateBiasAddDirect(fbb, axis.release());
|
||||
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_BiasAdd, val_offset.o);
|
||||
fbb.Finish(prim_offset);
|
||||
|
||||
auto buf = fbb.GetBufferPointer();
|
||||
MS_ASSERT(buf != nullptr);
|
||||
auto buf_bak = new char[fbb.GetSize()];
|
||||
memcpy(buf_bak, buf, fbb.GetSize());
|
||||
|
||||
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
|
||||
auto prim = const_cast<schema::Primitive *>(root);
|
||||
|
||||
delete[] buf_bak;
|
||||
fbb.Clear();
|
||||
return prim;
|
||||
}
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
std::vector<int> GetAxis() const;
|
||||
};
|
||||
|
|
|
@ -24,7 +24,25 @@ std::vector<int> BiasGrad::GetAxis() const { return this->primitive_->value.AsBi
|
|||
void BiasGrad::SetAxis(const std::vector<int> &axis) { this->primitive_->value.AsBiasGrad()->axis = axis; }
|
||||
|
||||
#else
|
||||
|
||||
int BiasGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto attr = primitive->value_as_BiasGrad();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "value_as_BiasGrad return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
std::vector<int32_t> axis;
|
||||
if (attr->axis() != nullptr) {
|
||||
for (int i = 0; i < static_cast<int>(attr->axis()->size()); i++) {
|
||||
axis.push_back(attr->axis()->data()[i]);
|
||||
}
|
||||
}
|
||||
auto val_offset = schema::CreateBiasGradDirect(*fbb, &axis);
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_BiasGrad, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
std::vector<int> BiasGrad::GetAxis() const {
|
||||
auto fb_vector = this->primitive_->value_as_BiasGrad()->axis();
|
||||
return std::vector<int>(fb_vector->begin(), fb_vector->end());
|
||||
|
|
|
@ -35,35 +35,9 @@ class BiasGrad : public PrimitiveC {
|
|||
void SetAxis(const std::vector<int> &axis);
|
||||
|
||||
#else
|
||||
explicit BiasGrad(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
BiasGrad() = default;
|
||||
|
||||
schema::Primitive *Init(schema::Primitive *primitive) {
|
||||
flatbuffers::FlatBufferBuilder fbb(1024);
|
||||
|
||||
auto attr = primitive->value_as_BiasGrad();
|
||||
MS_ASSERT(attr != nullptr);
|
||||
|
||||
auto axis = std::make_unique<std::vector<int32_t>>();
|
||||
for (int i = 0; i < static_cast<int>(attr->axis()->size()); i++) {
|
||||
axis->push_back(attr->axis()->data()[i]);
|
||||
}
|
||||
|
||||
auto val_offset = schema::CreateBiasGradDirect(fbb, axis.release());
|
||||
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_BiasGrad, val_offset.o);
|
||||
fbb.Finish(prim_offset);
|
||||
|
||||
auto buf = fbb.GetBufferPointer();
|
||||
MS_ASSERT(buf != nullptr);
|
||||
auto buf_bak = new char[fbb.GetSize()];
|
||||
memcpy(buf_bak, buf, fbb.GetSize());
|
||||
|
||||
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
|
||||
auto prim = const_cast<schema::Primitive *>(root);
|
||||
|
||||
delete[] buf_bak;
|
||||
fbb.Clear();
|
||||
return prim;
|
||||
}
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
std::vector<int> GetAxis() const;
|
||||
};
|
||||
|
|
|
@ -26,7 +26,19 @@ void BNGradInput::SetEps(float eps) { this->primitive_->value.AsBNGradInput()->e
|
|||
void BNGradInput::SetChannels(int channels) { this->primitive_->value.AsBNGradInput()->channels = channels; }
|
||||
|
||||
#else
|
||||
|
||||
int BNGradInput::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto attr = primitive->value_as_BNGradInput();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "value_as_BNGradInput return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto val_offset = schema::CreateBNGradInput(*fbb, attr->eps(), attr->channels());
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_BNGradInput, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
float BNGradInput::GetEps() const { return this->primitive_->value_as_BNGradInput()->eps(); }
|
||||
int BNGradInput::GetChannels() const { return this->primitive_->value_as_BNGradInput()->channels(); }
|
||||
|
||||
|
|
|
@ -34,30 +34,9 @@ class BNGradInput : public PrimitiveC {
|
|||
void SetEps(float eps);
|
||||
void SetChannels(int channels);
|
||||
#else
|
||||
explicit BNGradInput(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
BNGradInput() = default;
|
||||
|
||||
schema::Primitive *Init(schema::Primitive *primitive) {
|
||||
flatbuffers::FlatBufferBuilder fbb(1024);
|
||||
|
||||
auto attr = primitive->value_as_BNGradInput();
|
||||
MS_ASSERT(attr != nullptr);
|
||||
|
||||
auto val_offset = schema::CreateBNGradInput(fbb, attr->eps(), attr->channels());
|
||||
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_BNGradInput, val_offset.o);
|
||||
fbb.Finish(prim_offset);
|
||||
|
||||
auto buf = fbb.GetBufferPointer();
|
||||
MS_ASSERT(buf != nullptr);
|
||||
auto buf_bak = new char[fbb.GetSize()];
|
||||
memcpy(buf_bak, buf, fbb.GetSize());
|
||||
|
||||
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
|
||||
auto prim = const_cast<schema::Primitive *>(root);
|
||||
|
||||
delete[] buf_bak;
|
||||
fbb.Clear();
|
||||
return prim;
|
||||
}
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
float GetEps() const;
|
||||
int GetChannels() const;
|
||||
|
|
|
@ -26,7 +26,25 @@ void BroadcastTo::SetDstShape(const std::vector<int> &dst_shape) {
|
|||
}
|
||||
|
||||
#else
|
||||
|
||||
int BroadcastTo::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto attr = primitive->value_as_BroadcastTo();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "value_as_BroadcastTo return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
std::vector<int32_t> dst_shape;
|
||||
if (attr->dst_shape() != nullptr) {
|
||||
for (int i = 0; i < static_cast<int>(attr->dst_shape()->size()); i++) {
|
||||
dst_shape.push_back(attr->dst_shape()->data()[i]);
|
||||
}
|
||||
}
|
||||
auto val_offset = schema::CreateBroadcastToDirect(*fbb, &dst_shape);
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_BroadcastTo, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
std::vector<int> BroadcastTo::GetDstShape() const {
|
||||
auto fb_vector = this->primitive_->value_as_BroadcastTo()->dst_shape();
|
||||
return std::vector<int>(fb_vector->begin(), fb_vector->end());
|
||||
|
|
|
@ -35,35 +35,9 @@ class BroadcastTo : public PrimitiveC {
|
|||
void SetDstShape(const std::vector<int> &dst_shape);
|
||||
|
||||
#else
|
||||
explicit BroadcastTo(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
BroadcastTo() = default;
|
||||
|
||||
schema::Primitive *Init(schema::Primitive *primitive) {
|
||||
flatbuffers::FlatBufferBuilder fbb(1024);
|
||||
|
||||
auto attr = primitive->value_as_BroadcastTo();
|
||||
MS_ASSERT(attr != nullptr);
|
||||
|
||||
auto dst_shape = std::make_unique<std::vector<int32_t>>();
|
||||
for (int i = 0; i < static_cast<int>(attr->dst_shape()->size()); i++) {
|
||||
dst_shape->push_back(attr->dst_shape()->data()[i]);
|
||||
}
|
||||
|
||||
auto val_offset = schema::CreateBroadcastToDirect(fbb, dst_shape.release());
|
||||
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_BroadcastTo, val_offset.o);
|
||||
fbb.Finish(prim_offset);
|
||||
|
||||
auto buf = fbb.GetBufferPointer();
|
||||
MS_ASSERT(buf != nullptr);
|
||||
auto buf_bak = new char[fbb.GetSize()];
|
||||
memcpy(buf_bak, buf, fbb.GetSize());
|
||||
|
||||
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
|
||||
auto prim = const_cast<schema::Primitive *>(root);
|
||||
|
||||
delete[] buf_bak;
|
||||
fbb.Clear();
|
||||
return prim;
|
||||
}
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
std::vector<int> GetDstShape() const;
|
||||
|
|
|
@ -26,7 +26,19 @@ void Cast::SetSrcT(int src_t) { this->primitive_->value.AsCast()->srcT = src_t;
|
|||
void Cast::SetDstT(int dst_t) { this->primitive_->value.AsCast()->dstT = dst_t; }
|
||||
|
||||
#else
|
||||
|
||||
int Cast::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto attr = primitive->value_as_Cast();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "value_as_Cast return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto val_offset = schema::CreateCast(*fbb, attr->srcT(), attr->dstT());
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Cast, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
int Cast::GetSrcT() const { return this->primitive_->value_as_Cast()->srcT(); }
|
||||
int Cast::GetDstT() const { return this->primitive_->value_as_Cast()->dstT(); }
|
||||
|
||||
|
|
|
@ -34,30 +34,9 @@ class Cast : public PrimitiveC {
|
|||
void SetSrcT(int src_t);
|
||||
void SetDstT(int dst_t);
|
||||
#else
|
||||
explicit Cast(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
Cast() = default;
|
||||
|
||||
schema::Primitive *Init(schema::Primitive *primitive) {
|
||||
flatbuffers::FlatBufferBuilder fbb(1024);
|
||||
|
||||
auto attr = primitive->value_as_Cast();
|
||||
MS_ASSERT(attr != nullptr);
|
||||
|
||||
auto val_offset = schema::CreateCast(fbb, attr->srcT(), attr->dstT());
|
||||
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Cast, val_offset.o);
|
||||
fbb.Finish(prim_offset);
|
||||
|
||||
auto buf = fbb.GetBufferPointer();
|
||||
MS_ASSERT(buf != nullptr);
|
||||
auto buf_bak = new char[fbb.GetSize()];
|
||||
memcpy(buf_bak, buf, fbb.GetSize());
|
||||
|
||||
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
|
||||
auto prim = const_cast<schema::Primitive *>(root);
|
||||
|
||||
delete[] buf_bak;
|
||||
fbb.Clear();
|
||||
return prim;
|
||||
}
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
int GetSrcT() const;
|
||||
|
|
|
@ -32,26 +32,15 @@ class Ceil : public ArithmeticSelf {
|
|||
Ceil() = default;
|
||||
explicit Ceil(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {}
|
||||
#else
|
||||
explicit Ceil(schema::Primitive *primitive) : ArithmeticSelf(primitive) {}
|
||||
Ceil() = default;
|
||||
|
||||
schema::Primitive *Init(schema::Primitive *primitive) {
|
||||
flatbuffers::FlatBufferBuilder fbb(1024);
|
||||
|
||||
auto val_offset = schema::CreateCeil(fbb);
|
||||
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Ceil, val_offset.o);
|
||||
fbb.Finish(prim_offset);
|
||||
|
||||
auto buf = fbb.GetBufferPointer();
|
||||
MS_ASSERT(buf != nullptr);
|
||||
auto buf_bak = new char[fbb.GetSize()];
|
||||
memcpy(buf_bak, buf, fbb.GetSize());
|
||||
|
||||
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
|
||||
auto prim = const_cast<schema::Primitive *>(root);
|
||||
|
||||
delete[] buf_bak;
|
||||
fbb.Clear();
|
||||
return prim;
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto val_offset = schema::CreateCeil(*fbb);
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Ceil, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
|
|
@ -26,7 +26,19 @@ void Clip::SetMax(float max) { this->primitive_->value.AsClip()->max = max; }
|
|||
void Clip::SetMin(float min) { this->primitive_->value.AsClip()->min = min; }
|
||||
|
||||
#else
|
||||
|
||||
int Clip::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto attr = primitive->value_as_Clip();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "value_as_Clip return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto val_offset = schema::CreateClip(*fbb, attr->max(), attr->min());
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Clip, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
float Clip::GetMax() const { return this->primitive_->value_as_Clip()->max(); }
|
||||
float Clip::GetMin() const { return this->primitive_->value_as_Clip()->min(); }
|
||||
|
||||
|
|
|
@ -34,30 +34,9 @@ class Clip : public PrimitiveC {
|
|||
void SetMax(float max);
|
||||
void SetMin(float min);
|
||||
#else
|
||||
explicit Clip(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
Clip() = default;
|
||||
|
||||
schema::Primitive *Init(schema::Primitive *primitive) {
|
||||
flatbuffers::FlatBufferBuilder fbb(1024);
|
||||
|
||||
auto attr = primitive->value_as_Clip();
|
||||
MS_ASSERT(attr != nullptr);
|
||||
|
||||
auto val_offset = schema::CreateClip(fbb, attr->max(), attr->min());
|
||||
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Clip, val_offset.o);
|
||||
fbb.Finish(prim_offset);
|
||||
|
||||
auto buf = fbb.GetBufferPointer();
|
||||
MS_ASSERT(buf != nullptr);
|
||||
auto buf_bak = new char[fbb.GetSize()];
|
||||
memcpy(buf_bak, buf, fbb.GetSize());
|
||||
|
||||
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
|
||||
auto prim = const_cast<schema::Primitive *>(root);
|
||||
|
||||
delete[] buf_bak;
|
||||
fbb.Clear();
|
||||
return prim;
|
||||
}
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
float GetMax() const;
|
||||
float GetMin() const;
|
||||
|
|
|
@ -60,7 +60,19 @@ int Concat::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp
|
|||
}
|
||||
|
||||
#else
|
||||
|
||||
int Concat::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto attr = primitive->value_as_Concat();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "value_as_Concat return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto val_offset = schema::CreateConcat(*fbb, attr->axis(), attr->n());
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Concat, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
int Concat::GetAxis() const { return this->primitive_->value_as_Concat()->axis(); }
|
||||
int Concat::GetN() const { return this->primitive_->value_as_Concat()->n(); }
|
||||
|
||||
|
|
|
@ -31,34 +31,13 @@ class Concat : public PrimitiveC {
|
|||
MS_DECLARE_PARENT(Concat, PrimitiveC);
|
||||
Concat() = default;
|
||||
explicit Concat(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
|
||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
|
||||
void SetAxis(int axis);
|
||||
void SetN(int n);
|
||||
#else
|
||||
explicit Concat(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
Concat() = default;
|
||||
|
||||
schema::Primitive *Init(schema::Primitive *primitive) {
|
||||
flatbuffers::FlatBufferBuilder fbb(1024);
|
||||
|
||||
auto attr = primitive->value_as_Concat();
|
||||
MS_ASSERT(attr != nullptr);
|
||||
|
||||
auto val_offset = schema::CreateConcat(fbb, attr->axis(), attr->n());
|
||||
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Concat, val_offset.o);
|
||||
fbb.Finish(prim_offset);
|
||||
|
||||
auto buf = fbb.GetBufferPointer();
|
||||
MS_ASSERT(buf != nullptr);
|
||||
auto buf_bak = new char[fbb.GetSize()];
|
||||
memcpy(buf_bak, buf, fbb.GetSize());
|
||||
|
||||
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
|
||||
auto prim = const_cast<schema::Primitive *>(root);
|
||||
|
||||
delete[] buf_bak;
|
||||
fbb.Clear();
|
||||
return prim;
|
||||
}
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
int GetAxis() const;
|
||||
|
|
|
@ -30,7 +30,19 @@ float ConstantOfShape::GetValue() const { return this->primitive_->value.AsConst
|
|||
void ConstantOfShape::SetValue(float value) { this->primitive_->value.AsConstantOfShape()->value = value; }
|
||||
|
||||
#else
|
||||
|
||||
int ConstantOfShape::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto attr = primitive->value_as_ConstantOfShape();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "value_as_ConstantOfShape return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto val_offset = schema::CreateConstantOfShape(*fbb, attr->value());
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_ConstantOfShape, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
float ConstantOfShape::GetValue() const { return this->primitive_->value_as_ConstantOfShape()->value(); }
|
||||
|
||||
#endif
|
||||
|
|
|
@ -33,30 +33,9 @@ class ConstantOfShape : public PrimitiveC {
|
|||
explicit ConstantOfShape(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
void SetValue(float value);
|
||||
#else
|
||||
explicit ConstantOfShape(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
ConstantOfShape() = default;
|
||||
|
||||
schema::Primitive *Init(schema::Primitive *primitive) {
|
||||
flatbuffers::FlatBufferBuilder fbb(1024);
|
||||
|
||||
auto attr = primitive->value_as_ConstantOfShape();
|
||||
MS_ASSERT(attr != nullptr);
|
||||
|
||||
auto val_offset = schema::CreateConstantOfShape(fbb, attr->value());
|
||||
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_ConstantOfShape, val_offset.o);
|
||||
fbb.Finish(prim_offset);
|
||||
|
||||
auto buf = fbb.GetBufferPointer();
|
||||
MS_ASSERT(buf != nullptr);
|
||||
auto buf_bak = new char[fbb.GetSize()];
|
||||
memcpy(buf_bak, buf, fbb.GetSize());
|
||||
|
||||
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
|
||||
auto prim = const_cast<schema::Primitive *>(root);
|
||||
|
||||
delete[] buf_bak;
|
||||
fbb.Clear();
|
||||
return prim;
|
||||
}
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
float GetValue() const;
|
||||
|
|
|
@ -338,7 +338,23 @@ int Conv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp
|
|||
}
|
||||
|
||||
#else
|
||||
int Conv2D::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto attr = primitive->value_as_Conv2D();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "value_as_Conv2D return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
auto val_offset = schema::CreateConv2D(
|
||||
*fbb, attr->format(), attr->group(), attr->channelIn(), attr->channelOut(), attr->kernelW(), attr->kernelH(),
|
||||
attr->strideW(), attr->strideH(), attr->padMode(), attr->padUp(), attr->padDown(), attr->padLeft(),
|
||||
attr->padRight(), attr->dilateW(), attr->dilateH(), attr->hasBias(), attr->activationType());
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Conv2D, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
int Conv2D::GetFormat() const { return this->primitive_->value_as_Conv2D()->format(); }
|
||||
int Conv2D::GetGroup() const { return this->primitive_->value_as_Conv2D()->group(); }
|
||||
int Conv2D::GetChannelIn() const { return this->primitive_->value_as_Conv2D()->channelIn(); }
|
||||
|
|
|
@ -34,7 +34,7 @@ class Conv2D : public PrimitiveC {
|
|||
Conv2D() = default;
|
||||
explicit Conv2D(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
|
||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
|
||||
void SetFormat(int format);
|
||||
void SetGroup(int group);
|
||||
void SetChannelIn(int channel_in);
|
||||
|
@ -63,34 +63,9 @@ class Conv2D : public PrimitiveC {
|
|||
#else
|
||||
|
||||
public:
|
||||
explicit Conv2D(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
Conv2D() = default;
|
||||
|
||||
schema::Primitive *Init(schema::Primitive *primitive) {
|
||||
flatbuffers::FlatBufferBuilder fbb(1024);
|
||||
|
||||
auto attr = primitive->value_as_Conv2D();
|
||||
MS_ASSERT(attr != nullptr);
|
||||
|
||||
auto val_offset = schema::CreateConv2D(fbb, attr->format(), attr->group(), attr->channelIn(), attr->channelOut(),
|
||||
attr->kernelW(), attr->kernelH(), attr->strideW(), attr->strideH(),
|
||||
attr->padMode(), attr->padUp(), attr->padDown(), attr->padLeft(),
|
||||
attr->padRight(), attr->dilateW(), attr->dilateH(),
|
||||
attr->hasBias(), attr->activationType());
|
||||
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Conv2D, val_offset.o);
|
||||
fbb.Finish(prim_offset);
|
||||
|
||||
auto buf = fbb.GetBufferPointer();
|
||||
MS_ASSERT(buf != nullptr);
|
||||
auto buf_bak = new char[fbb.GetSize()];
|
||||
memcpy(buf_bak, buf, fbb.GetSize());
|
||||
|
||||
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
|
||||
auto prim = const_cast<schema::Primitive *>(root);
|
||||
|
||||
delete[] buf_bak;
|
||||
fbb.Clear();
|
||||
return prim;
|
||||
}
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
|
||||
public:
|
||||
|
|
|
@ -68,7 +68,22 @@ void Conv2DGradFilter::SetActivationType(int activation_type) {
|
|||
}
|
||||
|
||||
#else
|
||||
|
||||
int Conv2DGradFilter::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto attr = primitive->value_as_Conv2DGradFilter();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "value_as_Conv2DGradFilter return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto val_offset = schema::CreateConv2DGradFilter(
|
||||
*fbb, attr->format(), attr->group(), attr->channelIn(), attr->channelOut(), attr->kernelW(), attr->kernelH(),
|
||||
attr->strideW(), attr->strideH(), attr->padMode(), attr->padUp(), attr->padDown(), attr->padLeft(),
|
||||
attr->padRight(), attr->dilateW(), attr->dilateH(), attr->hasBias(), attr->activationType());
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Conv2DGradFilter, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
int Conv2DGradFilter::GetFormat() const { return this->primitive_->value_as_Conv2DGradFilter()->format(); }
|
||||
int Conv2DGradFilter::GetGroup() const { return this->primitive_->value_as_Conv2DGradFilter()->group(); }
|
||||
int Conv2DGradFilter::GetChannelIn() const { return this->primitive_->value_as_Conv2DGradFilter()->channelIn(); }
|
||||
|
|
|
@ -49,35 +49,9 @@ class Conv2DGradFilter : public PrimitiveC {
|
|||
void SetHasBias(bool has_bias);
|
||||
void SetActivationType(int activation_type);
|
||||
#else
|
||||
explicit Conv2DGradFilter(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
Conv2DGradFilter() = default;
|
||||
|
||||
schema::Primitive *Init(schema::Primitive *primitive) {
|
||||
flatbuffers::FlatBufferBuilder fbb(1024);
|
||||
|
||||
auto attr = primitive->value_as_Conv2DGradFilter();
|
||||
MS_ASSERT(attr != nullptr);
|
||||
|
||||
auto val_offset = schema::CreateConv2DGradFilter(fbb, attr->format(), attr->group(),
|
||||
attr->channelIn(), attr->channelOut(),
|
||||
attr->kernelW(), attr->kernelH(), attr->strideW(), attr->strideH(),
|
||||
attr->padMode(), attr->padUp(), attr->padDown(), attr->padLeft(),
|
||||
attr->padRight(), attr->dilateW(), attr->dilateH(),
|
||||
attr->hasBias(), attr->activationType());
|
||||
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Conv2DGradFilter, val_offset.o);
|
||||
fbb.Finish(prim_offset);
|
||||
|
||||
auto buf = fbb.GetBufferPointer();
|
||||
MS_ASSERT(buf != nullptr);
|
||||
auto buf_bak = new char[fbb.GetSize()];
|
||||
memcpy(buf_bak, buf, fbb.GetSize());
|
||||
|
||||
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
|
||||
auto prim = const_cast<schema::Primitive *>(root);
|
||||
|
||||
delete[] buf_bak;
|
||||
fbb.Clear();
|
||||
return prim;
|
||||
}
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
int GetFormat() const;
|
||||
int GetGroup() const;
|
||||
|
|
|
@ -66,7 +66,22 @@ void Conv2DGradInput::SetActivationType(int activation_type) {
|
|||
}
|
||||
|
||||
#else
|
||||
|
||||
int Conv2DGradInput::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto attr = primitive->value_as_Conv2DGradInput();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "value_as_Conv2DGradInput return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto val_offset = schema::CreateConv2DGradInput(
|
||||
*fbb, attr->format(), attr->group(), attr->channelIn(), attr->channelOut(), attr->kernelW(), attr->kernelH(),
|
||||
attr->strideW(), attr->strideH(), attr->padMode(), attr->padUp(), attr->padDown(), attr->padLeft(),
|
||||
attr->padRight(), attr->dilateW(), attr->dilateH(), attr->hasBias(), attr->activationType());
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Conv2DGradInput, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
int Conv2DGradInput::GetFormat() const { return this->primitive_->value_as_Conv2DGradInput()->format(); }
|
||||
int Conv2DGradInput::GetGroup() const { return this->primitive_->value_as_Conv2DGradInput()->group(); }
|
||||
int Conv2DGradInput::GetChannelIn() const { return this->primitive_->value_as_Conv2DGradInput()->channelIn(); }
|
||||
|
|
|
@ -49,35 +49,9 @@ class Conv2DGradInput : public PrimitiveC {
|
|||
void SetHasBias(bool has_bias);
|
||||
void SetActivationType(int activation_type);
|
||||
#else
|
||||
explicit Conv2DGradInput(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
Conv2DGradInput() = default;
|
||||
|
||||
schema::Primitive *Init(schema::Primitive *primitive) {
|
||||
flatbuffers::FlatBufferBuilder fbb(1024);
|
||||
|
||||
auto attr = primitive->value_as_Conv2DGradInput();
|
||||
MS_ASSERT(attr != nullptr);
|
||||
|
||||
auto val_offset = schema::CreateConv2DGradInput(fbb, attr->format(), attr->group(),
|
||||
attr->channelIn(), attr->channelOut(),
|
||||
attr->kernelW(), attr->kernelH(), attr->strideW(), attr->strideH(),
|
||||
attr->padMode(), attr->padUp(), attr->padDown(), attr->padLeft(),
|
||||
attr->padRight(), attr->dilateW(), attr->dilateH(),
|
||||
attr->hasBias(), attr->activationType());
|
||||
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Conv2DGradInput, val_offset.o);
|
||||
fbb.Finish(prim_offset);
|
||||
|
||||
auto buf = fbb.GetBufferPointer();
|
||||
MS_ASSERT(buf != nullptr);
|
||||
auto buf_bak = new char[fbb.GetSize()];
|
||||
memcpy(buf_bak, buf, fbb.GetSize());
|
||||
|
||||
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
|
||||
auto prim = const_cast<schema::Primitive *>(root);
|
||||
|
||||
delete[] buf_bak;
|
||||
fbb.Clear();
|
||||
return prim;
|
||||
}
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
int GetFormat() const;
|
||||
int GetGroup() const;
|
||||
|
|
|
@ -0,0 +1,32 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/ops/cos.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifndef PRIMITIVE_WRITEABLE
|
||||
int Cos::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto val_offset = schema::CreateCos(*fbb);
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Cos, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
#endif
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -21,7 +21,7 @@
|
|||
#include <set>
|
||||
#include <cmath>
|
||||
#include "ir/dtype/type_id.h"
|
||||
#include "src/ops/primitive_c.h"
|
||||
#include "src/ops/arithmetic_self.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -31,27 +31,9 @@ class Cos : public ArithmeticSelf {
|
|||
Cos() = default;
|
||||
explicit Cos(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {}
|
||||
#else
|
||||
explicit Cos(schema::Primitive *primitive) : ArithmeticSelf(primitive) {}
|
||||
Cos() = default;
|
||||
|
||||
schema::Primitive *Init(schema::Primitive *primitive) {
|
||||
flatbuffers::FlatBufferBuilder fbb(1024);
|
||||
|
||||
auto val_offset = schema::CreateCos(fbb);
|
||||
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Cos, val_offset.o);
|
||||
fbb.Finish(prim_offset);
|
||||
|
||||
auto buf = fbb.GetBufferPointer();
|
||||
MS_ASSERT(buf != nullptr);
|
||||
auto buf_bak = new char[fbb.GetSize()];
|
||||
memcpy(buf_bak, buf, fbb.GetSize());
|
||||
|
||||
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
|
||||
auto prim = const_cast<schema::Primitive *>(root);
|
||||
|
||||
delete[] buf_bak;
|
||||
fbb.Clear();
|
||||
return prim;
|
||||
}
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
};
|
||||
} // namespace lite
|
||||
|
|
|
@ -26,7 +26,25 @@ void Crop::SetAxis(int64_t axis) { this->primitive_->value.AsCrop()->axis = axis
|
|||
void Crop::SetOffsets(const std::vector<int64_t> &offsets) { this->primitive_->value.AsCrop()->offsets = offsets; }
|
||||
|
||||
#else
|
||||
|
||||
int Crop::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto attr = primitive->value_as_Crop();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "value_as_Crop return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
std::vector<int64_t> offsets;
|
||||
if (attr->offsets() != nullptr) {
|
||||
for (int i = 0; i < static_cast<int>(attr->offsets()->size()); i++) {
|
||||
offsets.push_back(attr->offsets()->data()[i]);
|
||||
}
|
||||
}
|
||||
auto val_offset = schema::CreateCropDirect(*fbb, attr->axis(), &offsets);
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Crop, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
int64_t Crop::GetAxis() const { return this->primitive_->value_as_Crop()->axis(); }
|
||||
std::vector<int64_t> Crop::GetOffsets() const {
|
||||
auto fb_vector = this->primitive_->value_as_Crop()->offsets();
|
||||
|
|
|
@ -35,35 +35,9 @@ class Crop : public PrimitiveC {
|
|||
void SetAxis(int64_t axis);
|
||||
void SetOffsets(const std::vector<int64_t> &offsets);
|
||||
#else
|
||||
explicit Crop(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
Crop() = default;
|
||||
|
||||
schema::Primitive *Init(schema::Primitive *primitive) {
|
||||
flatbuffers::FlatBufferBuilder fbb(1024);
|
||||
|
||||
auto attr = primitive->value_as_Crop();
|
||||
MS_ASSERT(attr != nullptr);
|
||||
|
||||
auto offsets = std::make_unique<std::vector<int64_t>>();
|
||||
for (int i = 0; i < static_cast<int>(attr->offsets()->size()); i++) {
|
||||
offsets->push_back(attr->offsets()->data()[i]);
|
||||
}
|
||||
|
||||
auto val_offset = schema::CreateCropDirect(fbb, attr->axis(), offsets.release());
|
||||
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Crop, val_offset.o);
|
||||
fbb.Finish(prim_offset);
|
||||
|
||||
auto buf = fbb.GetBufferPointer();
|
||||
MS_ASSERT(buf != nullptr);
|
||||
auto buf_bak = new char[fbb.GetSize()];
|
||||
memcpy(buf_bak, buf, fbb.GetSize());
|
||||
|
||||
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
|
||||
auto prim = const_cast<schema::Primitive *>(root);
|
||||
|
||||
delete[] buf_bak;
|
||||
fbb.Clear();
|
||||
return prim;
|
||||
}
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
int64_t GetAxis() const;
|
||||
|
|
|
@ -58,7 +58,22 @@ void DeConv2D::SetActivationType(int activation_type) {
|
|||
}
|
||||
|
||||
#else
|
||||
|
||||
int DeConv2D::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto attr = primitive->value_as_DeConv2D();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "value_as_DeConv2D return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto val_offset = schema::CreateDeConv2D(
|
||||
*fbb, attr->format(), attr->group(), attr->channelIn(), attr->channelOut(), attr->kernelW(), attr->kernelH(),
|
||||
attr->strideW(), attr->strideH(), attr->padMode(), attr->padUp(), attr->padDown(), attr->padLeft(),
|
||||
attr->padRight(), attr->dilateW(), attr->dilateH(), attr->hasBias(), attr->activationType());
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_DeConv2D, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
int DeConv2D::GetFormat() const { return this->primitive_->value_as_DeConv2D()->format(); }
|
||||
int DeConv2D::GetGroup() const { return this->primitive_->value_as_DeConv2D()->group(); }
|
||||
int DeConv2D::GetChannelIn() const { return this->primitive_->value_as_DeConv2D()->channelIn(); }
|
||||
|
|
|
@ -49,34 +49,9 @@ class DeConv2D : public PrimitiveC {
|
|||
void SetHasBias(bool has_bias);
|
||||
void SetActivationType(int activation_type);
|
||||
#else
|
||||
explicit DeConv2D(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
DeConv2D() = default;
|
||||
|
||||
schema::Primitive *Init(schema::Primitive *primitive) {
|
||||
flatbuffers::FlatBufferBuilder fbb(1024);
|
||||
|
||||
auto attr = primitive->value_as_DeConv2D();
|
||||
MS_ASSERT(attr != nullptr);
|
||||
|
||||
auto val_offset = schema::CreateDeConv2D(fbb, attr->format(), attr->group(), attr->channelIn(), attr->channelOut(),
|
||||
attr->kernelW(), attr->kernelH(), attr->strideW(), attr->strideH(),
|
||||
attr->padMode(), attr->padUp(), attr->padDown(), attr->padLeft(),
|
||||
attr->padRight(), attr->dilateW(), attr->dilateH(),
|
||||
attr->hasBias(), attr->activationType());
|
||||
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_DeConv2D, val_offset.o);
|
||||
fbb.Finish(prim_offset);
|
||||
|
||||
auto buf = fbb.GetBufferPointer();
|
||||
MS_ASSERT(buf != nullptr);
|
||||
auto buf_bak = new char[fbb.GetSize()];
|
||||
memcpy(buf_bak, buf, fbb.GetSize());
|
||||
|
||||
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
|
||||
auto prim = const_cast<schema::Primitive *>(root);
|
||||
|
||||
delete[] buf_bak;
|
||||
fbb.Clear();
|
||||
return prim;
|
||||
}
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
int GetFormat() const;
|
||||
|
|
|
@ -70,7 +70,24 @@ void DeDepthwiseConv2D::SetActivationType(int activation_type) {
|
|||
}
|
||||
|
||||
#else
|
||||
int DeDepthwiseConv2D::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
|
||||
auto attr = primitive->value_as_DeDepthwiseConv2D();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "value_as_DeDepthwiseConv2D return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
auto val_offset = schema::CreateDeDepthwiseConv2D(
|
||||
*fbb, attr->format(), attr->channelIn(), attr->channelMultiplier(), attr->kernelW(), attr->kernelH(),
|
||||
attr->strideW(), attr->strideH(), attr->padMode(), attr->padUp(), attr->padDown(), attr->padLeft(),
|
||||
attr->padRight(), attr->dilateW(), attr->dilateH(), attr->hasBias(), attr->activationType());
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_DeDepthwiseConv2D, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
int DeDepthwiseConv2D::GetFormat() const { return this->primitive_->value_as_DeDepthwiseConv2D()->format(); }
|
||||
int DeDepthwiseConv2D::GetChannelIn() const { return this->primitive_->value_as_DeDepthwiseConv2D()->channelIn(); }
|
||||
int DeDepthwiseConv2D::GetChannelMultiplier() const {
|
||||
|
|
|
@ -48,34 +48,9 @@ class DeDepthwiseConv2D : public PrimitiveC {
|
|||
void SetHasBias(bool has_bias);
|
||||
void SetActivationType(int activation_type);
|
||||
#else
|
||||
explicit DeDepthwiseConv2D(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
DeDepthwiseConv2D() = default;
|
||||
|
||||
schema::Primitive *Init(schema::Primitive *primitive) {
|
||||
flatbuffers::FlatBufferBuilder fbb(1024);
|
||||
|
||||
auto attr = primitive->value_as_DeDepthwiseConv2D();
|
||||
MS_ASSERT(attr != nullptr);
|
||||
|
||||
auto val_offset = schema::CreateDeDepthwiseConv2D(fbb, attr->format(), attr->channelIn(), attr->channelMultiplier(),
|
||||
attr->kernelW(), attr->kernelH(), attr->strideW(), attr->strideH(),
|
||||
attr->padMode(), attr->padUp(), attr->padDown(), attr->padLeft(),
|
||||
attr->padRight(), attr->dilateW(), attr->dilateH(),
|
||||
attr->hasBias(), attr->activationType());
|
||||
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_DeDepthwiseConv2D, val_offset.o);
|
||||
fbb.Finish(prim_offset);
|
||||
|
||||
auto buf = fbb.GetBufferPointer();
|
||||
MS_ASSERT(buf != nullptr);
|
||||
auto buf_bak = new char[fbb.GetSize()];
|
||||
memcpy(buf_bak, buf, fbb.GetSize());
|
||||
|
||||
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
|
||||
auto prim = const_cast<schema::Primitive *>(root);
|
||||
|
||||
delete[] buf_bak;
|
||||
fbb.Clear();
|
||||
return prim;
|
||||
}
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
int GetFormat() const;
|
||||
|
|
|
@ -26,7 +26,19 @@ void DepthToSpace::SetBlockSize(int block_size) { this->primitive_->value.AsDept
|
|||
void DepthToSpace::SetFormat(int format) { this->primitive_->value.AsDepthToSpace()->format = (schema::Format)format; }
|
||||
|
||||
#else
|
||||
|
||||
int DepthToSpace::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto attr = primitive->value_as_DepthToSpace();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "value_as_DepthToSpace return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto val_offset = schema::CreateDepthToSpace(*fbb, attr->blockSize(), attr->format());
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_DepthToSpace, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
int DepthToSpace::GetBlockSize() const { return this->primitive_->value_as_DepthToSpace()->blockSize(); }
|
||||
int DepthToSpace::GetFormat() const { return this->primitive_->value_as_DepthToSpace()->format(); }
|
||||
|
||||
|
|
|
@ -34,30 +34,9 @@ class DepthToSpace : public PrimitiveC {
|
|||
void SetBlockSize(int block_size);
|
||||
void SetFormat(int format);
|
||||
#else
|
||||
explicit DepthToSpace(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
DepthToSpace() = default;
|
||||
|
||||
schema::Primitive *Init(schema::Primitive *primitive) {
|
||||
flatbuffers::FlatBufferBuilder fbb(1024);
|
||||
|
||||
auto attr = primitive->value_as_DepthToSpace();
|
||||
MS_ASSERT(attr != nullptr);
|
||||
|
||||
auto val_offset = schema::CreateDepthToSpace(fbb, attr->blockSize(), attr->format());
|
||||
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_DepthToSpace, val_offset.o);
|
||||
fbb.Finish(prim_offset);
|
||||
|
||||
auto buf = fbb.GetBufferPointer();
|
||||
MS_ASSERT(buf != nullptr);
|
||||
auto buf_bak = new char[fbb.GetSize()];
|
||||
memcpy(buf_bak, buf, fbb.GetSize());
|
||||
|
||||
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
|
||||
auto prim = const_cast<schema::Primitive *>(root);
|
||||
|
||||
delete[] buf_bak;
|
||||
fbb.Clear();
|
||||
return prim;
|
||||
}
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
int GetBlockSize() const;
|
||||
|
|
|
@ -232,7 +232,22 @@ int DepthwiseConv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNode
|
|||
}
|
||||
|
||||
#else
|
||||
|
||||
int DepthwiseConv2D::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto attr = primitive->value_as_DepthwiseConv2D();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "value_as_DepthwiseConv2D return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto val_offset = schema::CreateDepthwiseConv2D(
|
||||
*fbb, attr->format(), attr->channelIn(), attr->channelMultiplier(), attr->kernelW(), attr->kernelH(),
|
||||
attr->strideW(), attr->strideH(), attr->padMode(), attr->padUp(), attr->padDown(), attr->padLeft(),
|
||||
attr->padRight(), attr->dilateW(), attr->dilateH(), attr->hasBias(), attr->activationType());
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_DepthwiseConv2D, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
int DepthwiseConv2D::GetFormat() const { return this->primitive_->value_as_DepthwiseConv2D()->format(); }
|
||||
int DepthwiseConv2D::GetChannelIn() const { return this->primitive_->value_as_DepthwiseConv2D()->channelIn(); }
|
||||
int DepthwiseConv2D::GetChannelMultiplier() const {
|
||||
|
|
|
@ -33,7 +33,7 @@ class DepthwiseConv2D : public PrimitiveC {
|
|||
DepthwiseConv2D() = default;
|
||||
explicit DepthwiseConv2D(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
|
||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
|
||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
|
||||
void SetFormat(int format);
|
||||
void SetChannelIn(int channel_in);
|
||||
void SetChannelMultiplier(int channel_multiplier);
|
||||
|
@ -58,35 +58,9 @@ class DepthwiseConv2D : public PrimitiveC {
|
|||
#else
|
||||
|
||||
public:
|
||||
explicit DepthwiseConv2D(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
DepthwiseConv2D() = default;
|
||||
|
||||
schema::Primitive *Init(schema::Primitive *primitive) {
|
||||
flatbuffers::FlatBufferBuilder fbb(1024);
|
||||
|
||||
auto attr = primitive->value_as_DepthwiseConv2D();
|
||||
MS_ASSERT(attr != nullptr);
|
||||
|
||||
auto val_offset = schema::CreateDepthwiseConv2D(fbb, attr->format(),
|
||||
attr->channelIn(), attr->channelMultiplier(),
|
||||
attr->kernelW(), attr->kernelH(), attr->strideW(), attr->strideH(),
|
||||
attr->padMode(), attr->padUp(), attr->padDown(), attr->padLeft(),
|
||||
attr->padRight(), attr->dilateW(), attr->dilateH(),
|
||||
attr->hasBias(), attr->activationType());
|
||||
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_DepthwiseConv2D, val_offset.o);
|
||||
fbb.Finish(prim_offset);
|
||||
|
||||
auto buf = fbb.GetBufferPointer();
|
||||
MS_ASSERT(buf != nullptr);
|
||||
auto buf_bak = new char[fbb.GetSize()];
|
||||
memcpy(buf_bak, buf, fbb.GetSize());
|
||||
|
||||
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
|
||||
auto prim = const_cast<schema::Primitive *>(root);
|
||||
|
||||
delete[] buf_bak;
|
||||
fbb.Clear();
|
||||
return prim;
|
||||
}
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
|
||||
public:
|
||||
|
|
|
@ -28,9 +28,9 @@ class Dequant : public PrimitiveC {
|
|||
MS_DECLARE_PARENT(Dequant, PrimitiveC);
|
||||
Dequant() = default;
|
||||
explicit Dequant(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
|
||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
|
||||
#else
|
||||
explicit Dequant(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
Dequant() = default;
|
||||
#endif
|
||||
};
|
||||
} // namespace lite
|
||||
|
|
|
@ -88,7 +88,22 @@ void DetectionPostProcess::SetUseRegularNms(bool use_regular_nms) {
|
|||
}
|
||||
|
||||
#else
|
||||
|
||||
int DetectionPostProcess::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto attr = primitive->value_as_DetectionPostProcess();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "value_as_DetectionPostProcess return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto val_offset = schema::CreateDetectionPostProcess(
|
||||
*fbb, attr->format(), attr->inputSize(), attr->hScale(), attr->wScale(), attr->xScale(), attr->yScale(),
|
||||
attr->NmsIouThreshold(), attr->NmsScoreThreshold(), attr->MaxDetections(), attr->DetectionsPreClass(),
|
||||
attr->MaxClassesPreDetection(), attr->NumClasses(), attr->UseRegularNms());
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_DetectionPostProcess, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
int DetectionPostProcess::GetFormat() const { return this->primitive_->value_as_DetectionPostProcess()->format(); }
|
||||
int DetectionPostProcess::GetInputSize() const {
|
||||
return this->primitive_->value_as_DetectionPostProcess()->inputSize();
|
||||
|
|
|
@ -45,36 +45,9 @@ class DetectionPostProcess : public PrimitiveC {
|
|||
void SetNumClasses(int64_t num_classes);
|
||||
void SetUseRegularNms(bool use_regular_nms);
|
||||
#else
|
||||
explicit DetectionPostProcess(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
DetectionPostProcess() = default;
|
||||
|
||||
schema::Primitive *Init(schema::Primitive *primitive) {
|
||||
flatbuffers::FlatBufferBuilder fbb(1024);
|
||||
|
||||
auto attr = primitive->value_as_DetectionPostProcess();
|
||||
MS_ASSERT(attr != nullptr);
|
||||
|
||||
auto val_offset = schema::CreateDetectionPostProcess(fbb, attr->format(), attr->inputSize(),
|
||||
attr->hScale(), attr->wScale(),
|
||||
attr->xScale(), attr->yScale(),
|
||||
attr->NmsIouThreshold(), attr->NmsScoreThreshold(),
|
||||
attr->MaxDetections(), attr->DetectionsPreClass(),
|
||||
attr->MaxClassesPreDetection(), attr->NumClasses(),
|
||||
attr->UseRegularNms());
|
||||
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_DetectionPostProcess, val_offset.o);
|
||||
fbb.Finish(prim_offset);
|
||||
|
||||
auto buf = fbb.GetBufferPointer();
|
||||
MS_ASSERT(buf != nullptr);
|
||||
auto buf_bak = new char[fbb.GetSize()];
|
||||
memcpy(buf_bak, buf, fbb.GetSize());
|
||||
|
||||
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
|
||||
auto prim = const_cast<schema::Primitive *>(root);
|
||||
|
||||
delete[] buf_bak;
|
||||
fbb.Clear();
|
||||
return prim;
|
||||
}
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
int GetFormat() const;
|
||||
int GetInputSize() const;
|
||||
|
|
|
@ -26,7 +26,19 @@ void Div::SetActivationType(int activation_type) {
|
|||
}
|
||||
|
||||
#else
|
||||
|
||||
int Div::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto attr = primitive->value_as_Div();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "value_as_Div return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto val_offset = schema::CreateDiv(*fbb, attr->activationType());
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Div, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
int Div::GetActivationType() const { return this->primitive_->value_as_Div()->activationType(); }
|
||||
|
||||
#endif
|
||||
|
|
|
@ -34,30 +34,9 @@ class Div : public Arithmetic {
|
|||
void SetActivationType(int activation_type);
|
||||
|
||||
#else
|
||||
explicit Div(schema::Primitive *primitive) : Arithmetic(primitive) {}
|
||||
Div() = default;
|
||||
|
||||
schema::Primitive *Init(schema::Primitive *primitive) {
|
||||
flatbuffers::FlatBufferBuilder fbb(1024);
|
||||
|
||||
auto attr = primitive->value_as_Div();
|
||||
MS_ASSERT(attr != nullptr);
|
||||
|
||||
auto val_offset = schema::CreateDiv(fbb, attr->activationType());
|
||||
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Div, val_offset.o);
|
||||
fbb.Finish(prim_offset);
|
||||
|
||||
auto buf = fbb.GetBufferPointer();
|
||||
MS_ASSERT(buf != nullptr);
|
||||
auto buf_bak = new char[fbb.GetSize()];
|
||||
memcpy(buf_bak, buf, fbb.GetSize());
|
||||
|
||||
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
|
||||
auto prim = const_cast<schema::Primitive *>(root);
|
||||
|
||||
delete[] buf_bak;
|
||||
fbb.Clear();
|
||||
return prim;
|
||||
}
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
int GetActivationType() const;
|
||||
};
|
||||
|
|
|
@ -24,7 +24,19 @@ float Dropout::GetRatio() const { return this->primitive_->value.AsDropout()->ra
|
|||
void Dropout::SetRatio(float ratio) { this->primitive_->value.AsDropout()->ratio = ratio; }
|
||||
|
||||
#else
|
||||
|
||||
int Dropout::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto attr = primitive->value_as_Dropout();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "value_as_Dropout return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto val_offset = schema::CreateDropout(*fbb, attr->ratio());
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Dropout, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
float Dropout::GetRatio() const { return this->primitive_->value_as_Dropout()->ratio(); }
|
||||
|
||||
#endif
|
||||
|
|
|
@ -34,30 +34,9 @@ class Dropout : public PrimitiveC {
|
|||
void SetRatio(float ratio);
|
||||
|
||||
#else
|
||||
explicit Dropout(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
Dropout() = default;
|
||||
|
||||
schema::Primitive *Init(schema::Primitive *primitive) {
|
||||
flatbuffers::FlatBufferBuilder fbb(1024);
|
||||
|
||||
auto attr = primitive->value_as_Dropout();
|
||||
MS_ASSERT(attr != nullptr);
|
||||
|
||||
auto val_offset = schema::CreateDropout(fbb, attr->ratio());
|
||||
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Dropout, val_offset.o);
|
||||
fbb.Finish(prim_offset);
|
||||
|
||||
auto buf = fbb.GetBufferPointer();
|
||||
MS_ASSERT(buf != nullptr);
|
||||
auto buf_bak = new char[fbb.GetSize()];
|
||||
memcpy(buf_bak, buf, fbb.GetSize());
|
||||
|
||||
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
|
||||
auto prim = const_cast<schema::Primitive *>(root);
|
||||
|
||||
delete[] buf_bak;
|
||||
fbb.Clear();
|
||||
return prim;
|
||||
}
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
float GetRatio() const;
|
||||
};
|
||||
|
|
|
@ -24,7 +24,19 @@ int Eltwise::GetMode() const { return this->primitive_->value.AsEltwise()->mode;
|
|||
void Eltwise::SetMode(int mode) { this->primitive_->value.AsEltwise()->mode = (schema::EltwiseMode)mode; }
|
||||
|
||||
#else
|
||||
|
||||
int Eltwise::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto attr = primitive->value_as_Eltwise();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "value_as_Eltwise return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto val_offset = schema::CreateEltwise(*fbb, attr->mode());
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Eltwise, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
int Eltwise::GetMode() const { return this->primitive_->value_as_Eltwise()->mode(); }
|
||||
|
||||
#endif
|
||||
|
|
|
@ -34,30 +34,9 @@ class Eltwise : public PrimitiveC {
|
|||
void SetMode(int mode);
|
||||
|
||||
#else
|
||||
explicit Eltwise(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
Eltwise() = default;
|
||||
|
||||
schema::Primitive *Init(schema::Primitive *primitive) {
|
||||
flatbuffers::FlatBufferBuilder fbb(1024);
|
||||
|
||||
auto attr = primitive->value_as_Eltwise();
|
||||
MS_ASSERT(attr != nullptr);
|
||||
|
||||
auto val_offset = schema::CreateEltwise(fbb, attr->mode());
|
||||
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Eltwise, val_offset.o);
|
||||
fbb.Finish(prim_offset);
|
||||
|
||||
auto buf = fbb.GetBufferPointer();
|
||||
MS_ASSERT(buf != nullptr);
|
||||
auto buf_bak = new char[fbb.GetSize()];
|
||||
memcpy(buf_bak, buf, fbb.GetSize());
|
||||
|
||||
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
|
||||
auto prim = const_cast<schema::Primitive *>(root);
|
||||
|
||||
delete[] buf_bak;
|
||||
fbb.Clear();
|
||||
return prim;
|
||||
}
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
int GetMode() const;
|
||||
};
|
||||
|
|
|
@ -24,7 +24,19 @@ float Elu::GetAlpha() const { return this->primitive_->value.AsElu()->alpha; }
|
|||
void Elu::SetAlpha(float alpha) { this->primitive_->value.AsElu()->alpha = alpha; }
|
||||
|
||||
#else
|
||||
|
||||
int Elu::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto attr = primitive->value_as_Elu();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "value_as_Elu return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto val_offset = schema::CreateElu(*fbb, attr->alpha());
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Elu, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
float Elu::GetAlpha() const { return this->primitive_->value_as_Elu()->alpha(); }
|
||||
|
||||
#endif
|
||||
|
|
|
@ -34,30 +34,9 @@ class Elu : public PrimitiveC {
|
|||
void SetAlpha(float alpha);
|
||||
|
||||
#else
|
||||
explicit Elu(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
Elu() = default;
|
||||
|
||||
schema::Primitive *Init(schema::Primitive *primitive) {
|
||||
flatbuffers::FlatBufferBuilder fbb(1024);
|
||||
|
||||
auto attr = primitive->value_as_Elu();
|
||||
MS_ASSERT(attr != nullptr);
|
||||
|
||||
auto val_offset = schema::CreateElu(fbb, attr->alpha());
|
||||
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Elu, val_offset.o);
|
||||
fbb.Finish(prim_offset);
|
||||
|
||||
auto buf = fbb.GetBufferPointer();
|
||||
MS_ASSERT(buf != nullptr);
|
||||
auto buf_bak = new char[fbb.GetSize()];
|
||||
memcpy(buf_bak, buf, fbb.GetSize());
|
||||
|
||||
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
|
||||
auto prim = const_cast<schema::Primitive *>(root);
|
||||
|
||||
delete[] buf_bak;
|
||||
fbb.Clear();
|
||||
return prim;
|
||||
}
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
float GetAlpha() const;
|
||||
};
|
||||
|
|
|
@ -24,7 +24,21 @@ float EmbeddingLookup::GetMaxNorm() const { return this->primitive_->value.AsEmb
|
|||
void EmbeddingLookup::SetMaxNorm(float max_norm) { this->primitive_->value.AsEmbeddingLookup()->maxNorm = max_norm; }
|
||||
|
||||
#else
|
||||
int EmbeddingLookup::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
|
||||
auto attr = primitive->value_as_EmbeddingLookup();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "value_as_EmbeddingLookup return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
auto val_offset = schema::CreateEmbeddingLookup(*fbb, attr->maxNorm());
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_EmbeddingLookup, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
float EmbeddingLookup::GetMaxNorm() const { return this->primitive_->value_as_EmbeddingLookup()->maxNorm(); }
|
||||
|
||||
#endif
|
||||
|
|
|
@ -34,30 +34,9 @@ class EmbeddingLookup : public PrimitiveC {
|
|||
void SetMaxNorm(float max_norm);
|
||||
|
||||
#else
|
||||
explicit EmbeddingLookup(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
EmbeddingLookup() = default;
|
||||
|
||||
schema::Primitive *Init(schema::Primitive *primitive) {
|
||||
flatbuffers::FlatBufferBuilder fbb(1024);
|
||||
|
||||
auto attr = primitive->value_as_EmbeddingLookup();
|
||||
MS_ASSERT(attr != nullptr);
|
||||
|
||||
auto val_offset = schema::CreateEmbeddingLookup(fbb, attr->maxNorm());
|
||||
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_EmbeddingLookup, val_offset.o);
|
||||
fbb.Finish(prim_offset);
|
||||
|
||||
auto buf = fbb.GetBufferPointer();
|
||||
MS_ASSERT(buf != nullptr);
|
||||
auto buf_bak = new char[fbb.GetSize()];
|
||||
memcpy(buf_bak, buf, fbb.GetSize());
|
||||
|
||||
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
|
||||
auto prim = const_cast<schema::Primitive *>(root);
|
||||
|
||||
delete[] buf_bak;
|
||||
fbb.Clear();
|
||||
return prim;
|
||||
}
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
float GetMaxNorm() const;
|
||||
|
|
|
@ -38,7 +38,32 @@ void EmbeddingLookupSparse::SetMaxNortm(float max_nortm) {
|
|||
}
|
||||
|
||||
#else
|
||||
|
||||
int EmbeddingLookupSparse::UnPackToFlatBuilder(const schema::Primitive *primitive,
|
||||
flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto attr = primitive->value_as_EmbeddingLookupSparse();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "value_as_EmbeddingLookupSparse return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
std::vector<int32_t> spIds;
|
||||
if (attr->spIds() != nullptr) {
|
||||
for (int i = 0; i < static_cast<int>(attr->spIds()->size()); i++) {
|
||||
spIds.push_back(attr->spIds()->data()[i]);
|
||||
}
|
||||
}
|
||||
std::vector<float> spWeights;
|
||||
if (attr->spWeights() != nullptr) {
|
||||
for (int i = 0; i < static_cast<int>(attr->spWeights()->size()); i++) {
|
||||
spWeights.push_back(attr->spWeights()->data()[i]);
|
||||
}
|
||||
}
|
||||
auto val_offset = schema::CreateEmbeddingLookupSparseDirect(*fbb, &spIds, &spWeights);
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_EmbeddingLookupSparse, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
std::vector<int> EmbeddingLookupSparse::GetSpIds() const {
|
||||
auto fb_vector = this->primitive_->value_as_EmbeddingLookupSparse()->spIds();
|
||||
return std::vector<int>(fb_vector->begin(), fb_vector->end());
|
||||
|
|
|
@ -36,39 +36,9 @@ class EmbeddingLookupSparse : public PrimitiveC {
|
|||
void SetSpWeights(const std::vector<float> &sp_weights);
|
||||
void SetMaxNortm(float max_nortm);
|
||||
#else
|
||||
explicit EmbeddingLookupSparse(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
EmbeddingLookupSparse() = default;
|
||||
|
||||
schema::Primitive *Init(schema::Primitive *primitive) {
|
||||
flatbuffers::FlatBufferBuilder fbb(1024);
|
||||
|
||||
auto attr = primitive->value_as_EmbeddingLookupSparse();
|
||||
MS_ASSERT(attr != nullptr);
|
||||
|
||||
auto spIds = std::make_unique<std::vector<int32_t>>();
|
||||
for (int i = 0; i < static_cast<int>(attr->spIds()->size()); i++) {
|
||||
spIds->push_back(attr->spIds()->data()[i]);
|
||||
}
|
||||
auto spWeights = std::make_unique<std::vector<float>>();
|
||||
for (int i = 0; i < static_cast<int>(attr->spWeights()->size()); i++) {
|
||||
spWeights->push_back(attr->spWeights()->data()[i]);
|
||||
}
|
||||
|
||||
auto val_offset = schema:: CreateEmbeddingLookupSparseDirect(fbb, spIds.release(), spWeights.release());
|
||||
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_EmbeddingLookupSparse, val_offset.o);
|
||||
fbb.Finish(prim_offset);
|
||||
|
||||
auto buf = fbb.GetBufferPointer();
|
||||
MS_ASSERT(buf != nullptr);
|
||||
auto buf_bak = new char[fbb.GetSize()];
|
||||
memcpy(buf_bak, buf, fbb.GetSize());
|
||||
|
||||
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
|
||||
auto prim = const_cast<schema::Primitive *>(root);
|
||||
|
||||
delete[] buf_bak;
|
||||
fbb.Clear();
|
||||
return prim;
|
||||
}
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
std::vector<int> GetSpIds() const;
|
||||
std::vector<float> GetSpWeights() const;
|
||||
|
|
|
@ -0,0 +1,33 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/ops/equal.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifndef PRIMITIVE_WRITEABLE
|
||||
int Equal::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto val_offset = schema::CreateEqual(*fbb);
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Equal, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
#endif
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -32,27 +32,9 @@ class Equal : public Arithmetic {
|
|||
Equal() = default;
|
||||
explicit Equal(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
|
||||
#else
|
||||
explicit Equal(schema::Primitive *primitive) : Arithmetic(primitive) {}
|
||||
Equal() = default;
|
||||
|
||||
schema::Primitive *Init(schema::Primitive *primitive) {
|
||||
flatbuffers::FlatBufferBuilder fbb(1024);
|
||||
|
||||
auto val_offset = schema::CreateEqual(fbb);
|
||||
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Equal, val_offset.o);
|
||||
fbb.Finish(prim_offset);
|
||||
|
||||
auto buf = fbb.GetBufferPointer();
|
||||
MS_ASSERT(buf != nullptr);
|
||||
auto buf_bak = new char[fbb.GetSize()];
|
||||
memcpy(buf_bak, buf, fbb.GetSize());
|
||||
|
||||
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
|
||||
auto prim = const_cast<schema::Primitive *>(root);
|
||||
|
||||
delete[] buf_bak;
|
||||
fbb.Clear();
|
||||
return prim;
|
||||
}
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
};
|
||||
} // namespace lite
|
||||
|
|
|
@ -0,0 +1,33 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/ops/exp.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifndef PRIMITIVE_WRITEABLE
|
||||
int Exp::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto val_offset = schema::CreateExp(*fbb);
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Exp, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
#endif
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -32,27 +32,9 @@ class Exp : public ArithmeticSelf {
|
|||
Exp() = default;
|
||||
explicit Exp(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {}
|
||||
#else
|
||||
explicit Exp(schema::Primitive *primitive) : ArithmeticSelf(primitive) {}
|
||||
Exp() = default;
|
||||
|
||||
schema::Primitive *Init(schema::Primitive *primitive) {
|
||||
flatbuffers::FlatBufferBuilder fbb(1024);
|
||||
|
||||
auto val_offset = schema::CreateExp(fbb);
|
||||
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Exp, val_offset.o);
|
||||
fbb.Finish(prim_offset);
|
||||
|
||||
auto buf = fbb.GetBufferPointer();
|
||||
MS_ASSERT(buf != nullptr);
|
||||
auto buf_bak = new char[fbb.GetSize()];
|
||||
memcpy(buf_bak, buf, fbb.GetSize());
|
||||
|
||||
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
|
||||
auto prim = const_cast<schema::Primitive *>(root);
|
||||
|
||||
delete[] buf_bak;
|
||||
fbb.Clear();
|
||||
return prim;
|
||||
}
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
};
|
||||
} // namespace lite
|
||||
|
|
|
@ -24,7 +24,20 @@ int ExpandDims::GetDim() const { return this->primitive_->value.AsExpandDims()->
|
|||
void ExpandDims::SetDim(int dim) { this->primitive_->value.AsExpandDims()->dim = dim; }
|
||||
|
||||
#else
|
||||
int ExpandDims::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto attr = primitive->value_as_ExpandDims();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "value_as_ExpandDims return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
auto val_offset = schema::CreateExpandDims(*fbb, attr->dim());
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_ExpandDims, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
int ExpandDims::GetDim() const { return this->primitive_->value_as_ExpandDims()->dim(); }
|
||||
|
||||
#endif
|
||||
|
|
|
@ -34,30 +34,9 @@ class ExpandDims : public PrimitiveC {
|
|||
void SetDim(int dim);
|
||||
|
||||
#else
|
||||
explicit ExpandDims(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
ExpandDims() = default;
|
||||
|
||||
schema::Primitive *Init(schema::Primitive *primitive) {
|
||||
flatbuffers::FlatBufferBuilder fbb(1024);
|
||||
|
||||
auto attr = primitive->value_as_ExpandDims();
|
||||
MS_ASSERT(attr != nullptr);
|
||||
|
||||
auto val_offset = schema::CreateExpandDims(fbb, attr->dim());
|
||||
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_ExpandDims, val_offset.o);
|
||||
fbb.Finish(prim_offset);
|
||||
|
||||
auto buf = fbb.GetBufferPointer();
|
||||
MS_ASSERT(buf != nullptr);
|
||||
auto buf_bak = new char[fbb.GetSize()];
|
||||
memcpy(buf_bak, buf, fbb.GetSize());
|
||||
|
||||
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
|
||||
auto prim = const_cast<schema::Primitive *>(root);
|
||||
|
||||
delete[] buf_bak;
|
||||
fbb.Clear();
|
||||
return prim;
|
||||
}
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
int GetDim() const;
|
||||
|
|
|
@ -32,7 +32,21 @@ void FakeQuantWithMinMaxVars::SetNumBits(int num_bits) {
|
|||
}
|
||||
|
||||
#else
|
||||
int FakeQuantWithMinMaxVars::UnPackToFlatBuilder(const schema::Primitive *primitive,
|
||||
flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto attr = primitive->value_as_FakeQuantWithMinMaxVars();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "value_as_FakeQuantWithMinMaxVars return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
auto val_offset = schema::CreateFakeQuantWithMinMaxVars(*fbb, attr->narrowRange(), attr->numBits());
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_FakeQuantWithMinMaxVars, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
bool FakeQuantWithMinMaxVars::GetNarrowRange() const {
|
||||
return this->primitive_->value_as_FakeQuantWithMinMaxVars()->narrowRange();
|
||||
}
|
||||
|
|
|
@ -34,31 +34,9 @@ class FakeQuantWithMinMaxVars : public PrimitiveC {
|
|||
void SetNarrowRange(bool narrow_range);
|
||||
void SetNumBits(int num_bits);
|
||||
#else
|
||||
explicit FakeQuantWithMinMaxVars(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
FakeQuantWithMinMaxVars() = default;
|
||||
|
||||
schema::Primitive *Init(schema::Primitive *primitive) {
|
||||
flatbuffers::FlatBufferBuilder fbb(1024);
|
||||
|
||||
auto attr = primitive->value_as_FakeQuantWithMinMaxVars();
|
||||
MS_ASSERT(attr != nullptr);
|
||||
|
||||
auto val_offset = schema::CreateFakeQuantWithMinMaxVars(fbb, attr->narrowRange(), attr->numBits());
|
||||
auto prim_offset = schema::CreatePrimitive(fbb,
|
||||
schema::PrimitiveType_FakeQuantWithMinMaxVars, val_offset.o);
|
||||
fbb.Finish(prim_offset);
|
||||
|
||||
auto buf = fbb.GetBufferPointer();
|
||||
MS_ASSERT(buf != nullptr);
|
||||
auto buf_bak = new char[fbb.GetSize()];
|
||||
memcpy(buf_bak, buf, fbb.GetSize());
|
||||
|
||||
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
|
||||
auto prim = const_cast<schema::Primitive *>(root);
|
||||
|
||||
delete[] buf_bak;
|
||||
fbb.Clear();
|
||||
return prim;
|
||||
}
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
bool GetNarrowRange() const;
|
||||
int GetNumBits() const;
|
||||
|
|
|
@ -24,7 +24,25 @@ std::vector<int> Fill::GetDims() const { return this->primitive_->value.AsFill()
|
|||
void Fill::SetDims(const std::vector<int> &dims) { this->primitive_->value.AsFill()->dims = dims; }
|
||||
|
||||
#else
|
||||
|
||||
int Fill::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto attr = primitive->value_as_Fill();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "value_as_Fill return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
std::vector<int32_t> dims;
|
||||
if (attr->dims() != nullptr) {
|
||||
for (int i = 0; i < static_cast<int>(attr->dims()->size()); i++) {
|
||||
dims.push_back(attr->dims()->data()[i]);
|
||||
}
|
||||
}
|
||||
auto val_offset = schema::CreateFillDirect(*fbb, &dims);
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Fill, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
std::vector<int> Fill::GetDims() const {
|
||||
auto fb_vector = this->primitive_->value_as_Fill()->dims();
|
||||
return std::vector<int>(fb_vector->begin(), fb_vector->end());
|
||||
|
|
|
@ -35,35 +35,9 @@ class Fill : public PrimitiveC {
|
|||
void SetDims(const std::vector<int> &dims);
|
||||
|
||||
#else
|
||||
explicit Fill(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
Fill() = default;
|
||||
|
||||
schema::Primitive *Init(schema::Primitive *primitive) {
|
||||
flatbuffers::FlatBufferBuilder fbb(1024);
|
||||
|
||||
auto attr = primitive->value_as_Fill();
|
||||
MS_ASSERT(attr != nullptr);
|
||||
|
||||
auto dims = std::make_unique<std::vector<int32_t>>();
|
||||
for (int i = 0; i < static_cast<int>(attr->dims()->size()); i++) {
|
||||
dims->push_back(attr->dims()->data()[i]);
|
||||
}
|
||||
|
||||
auto val_offset = schema::CreateFillDirect(fbb, dims.release());
|
||||
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Fill, val_offset.o);
|
||||
fbb.Finish(prim_offset);
|
||||
|
||||
auto buf = fbb.GetBufferPointer();
|
||||
MS_ASSERT(buf != nullptr);
|
||||
auto buf_bak = new char[fbb.GetSize()];
|
||||
memcpy(buf_bak, buf, fbb.GetSize());
|
||||
|
||||
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
|
||||
auto prim = const_cast<schema::Primitive *>(root);
|
||||
|
||||
delete[] buf_bak;
|
||||
fbb.Clear();
|
||||
return prim;
|
||||
}
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
std::vector<int> GetDims() const;
|
||||
|
|
|
@ -77,6 +77,15 @@ int Flatten::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &in
|
|||
}
|
||||
return RET_OK;
|
||||
}
|
||||
#else
|
||||
int Flatten::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto val_offset = schema::CreateFlatten(*fbb);
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Flatten, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
#endif
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -31,32 +31,13 @@ class Flatten : public PrimitiveC {
|
|||
MS_DECLARE_PARENT(Flatten, PrimitiveC);
|
||||
Flatten() = default;
|
||||
explicit Flatten(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
|
||||
#else
|
||||
explicit Flatten(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
Flatten() = default;
|
||||
|
||||
schema::Primitive *Init(schema::Primitive *primitive) {
|
||||
flatbuffers::FlatBufferBuilder fbb(1024);
|
||||
|
||||
auto val_offset = schema::CreateFlatten(fbb);
|
||||
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Flatten, val_offset.o);
|
||||
fbb.Finish(prim_offset);
|
||||
|
||||
auto buf = fbb.GetBufferPointer();
|
||||
MS_ASSERT(buf != nullptr);
|
||||
auto buf_bak = new char[fbb.GetSize()];
|
||||
memcpy(buf_bak, buf, fbb.GetSize());
|
||||
|
||||
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
|
||||
auto prim = const_cast<schema::Primitive *>(root);
|
||||
|
||||
delete[] buf_bak;
|
||||
fbb.Clear();
|
||||
return prim;
|
||||
}
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
|
||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -0,0 +1,34 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/ops/floor.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifndef PRIMITIVE_WRITEABLE
|
||||
|
||||
int Floor::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto val_offset = schema::CreateFloor(*fbb);
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Floor, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
#endif
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -21,7 +21,7 @@
|
|||
#include <set>
|
||||
#include <cmath>
|
||||
#include "ir/dtype/type_id.h"
|
||||
#include "src/ops/primitive_c.h"
|
||||
#include "src/ops/arithmetic_self.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -32,27 +32,9 @@ class Floor : public ArithmeticSelf {
|
|||
Floor() = default;
|
||||
explicit Floor(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {}
|
||||
#else
|
||||
explicit Floor(schema::Primitive *primitive) : ArithmeticSelf(primitive) {}
|
||||
Floor() = default;
|
||||
|
||||
schema::Primitive *Init(schema::Primitive *primitive) {
|
||||
flatbuffers::FlatBufferBuilder fbb(1024);
|
||||
|
||||
auto val_offset = schema::CreateFloor(fbb);
|
||||
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Floor, val_offset.o);
|
||||
fbb.Finish(prim_offset);
|
||||
|
||||
auto buf = fbb.GetBufferPointer();
|
||||
MS_ASSERT(buf != nullptr);
|
||||
auto buf_bak = new char[fbb.GetSize()];
|
||||
memcpy(buf_bak, buf, fbb.GetSize());
|
||||
|
||||
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
|
||||
auto prim = const_cast<schema::Primitive *>(root);
|
||||
|
||||
delete[] buf_bak;
|
||||
fbb.Clear();
|
||||
return prim;
|
||||
}
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
};
|
||||
} // namespace lite
|
||||
|
|
|
@ -0,0 +1,34 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/ops/floor_div.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifndef PRIMITIVE_WRITEABLE
|
||||
|
||||
int FloorDiv::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto val_offset = schema::CreateFloor(*fbb);
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Floor, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
#endif
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -32,27 +32,9 @@ class FloorDiv : public Arithmetic {
|
|||
FloorDiv() = default;
|
||||
explicit FloorDiv(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
|
||||
#else
|
||||
explicit FloorDiv(schema::Primitive *primitive) : Arithmetic(primitive) {}
|
||||
FloorDiv() = default;
|
||||
|
||||
schema::Primitive *Init(schema::Primitive *primitive) {
|
||||
flatbuffers::FlatBufferBuilder fbb(1024);
|
||||
|
||||
auto val_offset = schema::CreateFloorDiv(fbb);
|
||||
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_FloorDiv, val_offset.o);
|
||||
fbb.Finish(prim_offset);
|
||||
|
||||
auto buf = fbb.GetBufferPointer();
|
||||
MS_ASSERT(buf != nullptr);
|
||||
auto buf_bak = new char[fbb.GetSize()];
|
||||
memcpy(buf_bak, buf, fbb.GetSize());
|
||||
|
||||
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
|
||||
auto prim = const_cast<schema::Primitive *>(root);
|
||||
|
||||
delete[] buf_bak;
|
||||
fbb.Clear();
|
||||
return prim;
|
||||
}
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
};
|
||||
} // namespace lite
|
||||
|
|
|
@ -0,0 +1,34 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/ops/floor_mod.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifndef PRIMITIVE_WRITEABLE
|
||||
|
||||
int FloorMod::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto val_offset = schema::CreateFloorMod(*fbb);
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_FloorMod, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
#endif
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -32,27 +32,9 @@ class FloorMod : public Arithmetic {
|
|||
FloorMod() = default;
|
||||
explicit FloorMod(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
|
||||
#else
|
||||
explicit FloorMod(schema::Primitive *primitive) : Arithmetic(primitive) {}
|
||||
FloorMod() = default;
|
||||
|
||||
schema::Primitive *Init(schema::Primitive *primitive) {
|
||||
flatbuffers::FlatBufferBuilder fbb(1024);
|
||||
|
||||
auto val_offset = schema::CreateFloorMod(fbb);
|
||||
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_FloorMod, val_offset.o);
|
||||
fbb.Finish(prim_offset);
|
||||
|
||||
auto buf = fbb.GetBufferPointer();
|
||||
MS_ASSERT(buf != nullptr);
|
||||
auto buf_bak = new char[fbb.GetSize()];
|
||||
memcpy(buf_bak, buf, fbb.GetSize());
|
||||
|
||||
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
|
||||
auto prim = const_cast<schema::Primitive *>(root);
|
||||
|
||||
delete[] buf_bak;
|
||||
fbb.Clear();
|
||||
return prim;
|
||||
}
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
};
|
||||
} // namespace lite
|
||||
|
|
|
@ -31,7 +31,21 @@ void FullConnection::SetActivationType(int activationType) {
|
|||
this->primitive_->value.AsFullConnection()->activationType = (schema::ActivationType)activationType;
|
||||
}
|
||||
#else
|
||||
int FullConnection::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto attr = primitive->value_as_FullConnection();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "value_as_FullConnection return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
auto val_offset =
|
||||
schema::CreateFullConnection(*fbb, attr->hasBias(), attr->axis(), attr->useAxis(), attr->activationType());
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_FullConnection, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
bool FullConnection::GetHasBias() const { return this->primitive_->value_as_FullConnection()->hasBias(); }
|
||||
int FullConnection::GetAxis() const { return this->primitive_->value_as_FullConnection()->axis(); }
|
||||
bool FullConnection::GetUseAxis() const { return this->primitive_->value_as_FullConnection()->useAxis(); }
|
||||
|
|
|
@ -36,31 +36,9 @@ class FullConnection : public PrimitiveC {
|
|||
void SetUseAxis(bool use_axis);
|
||||
void SetActivationType(int activationType);
|
||||
#else
|
||||
explicit FullConnection(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
FullConnection() = default;
|
||||
|
||||
schema::Primitive *Init(schema::Primitive *primitive) {
|
||||
flatbuffers::FlatBufferBuilder fbb(1024);
|
||||
|
||||
auto attr = primitive->value_as_FullConnection();
|
||||
MS_ASSERT(attr != nullptr);
|
||||
|
||||
auto val_offset = schema::CreateFullConnection(fbb, attr->hasBias(), attr->axis(),
|
||||
attr->useAxis(), attr->activationType());
|
||||
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_FullConnection, val_offset.o);
|
||||
fbb.Finish(prim_offset);
|
||||
|
||||
auto buf = fbb.GetBufferPointer();
|
||||
MS_ASSERT(buf != nullptr);
|
||||
auto buf_bak = new char[fbb.GetSize()];
|
||||
memcpy(buf_bak, buf, fbb.GetSize());
|
||||
|
||||
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
|
||||
auto prim = const_cast<schema::Primitive *>(root);
|
||||
|
||||
delete[] buf_bak;
|
||||
fbb.Clear();
|
||||
return prim;
|
||||
}
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
bool GetHasBias() const;
|
||||
|
|
|
@ -28,7 +28,20 @@ void FusedBatchNorm::SetMomentum(float momentum) { this->primitive_->value.AsFus
|
|||
void FusedBatchNorm::SetSpatial(int spatial) { this->primitive_->value.AsFusedBatchNorm()->spatial = spatial; }
|
||||
|
||||
#else
|
||||
int FusedBatchNorm::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto attr = primitive->value_as_FusedBatchNorm();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "value_as_FusedBatchNorm return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
auto val_offset = schema::CreateFusedBatchNorm(*fbb, attr->epsilon(), attr->momentum(), attr->spatial());
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_FusedBatchNorm, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
float FusedBatchNorm::GetEpsilon() const { return this->primitive_->value_as_FusedBatchNorm()->epsilon(); }
|
||||
float FusedBatchNorm::GetMomentum() const { return this->primitive_->value_as_FusedBatchNorm()->momentum(); }
|
||||
int FusedBatchNorm::GetSpatial() const { return this->primitive_->value_as_FusedBatchNorm()->spatial(); }
|
||||
|
|
|
@ -35,30 +35,9 @@ class FusedBatchNorm : public PrimitiveC {
|
|||
void SetMomentum(float momentum);
|
||||
void SetSpatial(int spatial);
|
||||
#else
|
||||
explicit FusedBatchNorm(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
FusedBatchNorm() = default;
|
||||
|
||||
schema::Primitive *Init(schema::Primitive *primitive) {
|
||||
flatbuffers::FlatBufferBuilder fbb(1024);
|
||||
|
||||
auto attr = primitive->value_as_FusedBatchNorm();
|
||||
MS_ASSERT(attr != nullptr);
|
||||
|
||||
auto val_offset = schema::CreateFusedBatchNorm(fbb, attr->epsilon(), attr->momentum(), attr->spatial());
|
||||
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_FusedBatchNorm, val_offset.o);
|
||||
fbb.Finish(prim_offset);
|
||||
|
||||
auto buf = fbb.GetBufferPointer();
|
||||
MS_ASSERT(buf != nullptr);
|
||||
auto buf_bak = new char[fbb.GetSize()];
|
||||
memcpy(buf_bak, buf, fbb.GetSize());
|
||||
|
||||
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
|
||||
auto prim = const_cast<schema::Primitive *>(root);
|
||||
|
||||
delete[] buf_bak;
|
||||
fbb.Clear();
|
||||
return prim;
|
||||
}
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
float GetEpsilon() const;
|
||||
float GetMomentum() const;
|
||||
|
|
|
@ -29,7 +29,20 @@ void Gather::SetAxis(int axis) { this->primitive_->value.AsGather()->axis = axis
|
|||
void Gather::SetBatchDims(int batch_dims) { this->primitive_->value.AsGather()->batchDims = batch_dims; }
|
||||
|
||||
#else
|
||||
int Gather::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto attr = primitive->value_as_Gather();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "value_as_Gather return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
auto val_offset = schema::CreateGather(*fbb, attr->axis(), attr->batchDims());
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Gather, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
int Gather::GetAxis() const { return this->primitive_->value_as_Gather()->axis(); }
|
||||
int Gather::GetBatchDims() const { return this->primitive_->value_as_Gather()->batchDims(); }
|
||||
|
||||
|
|
|
@ -34,30 +34,9 @@ class Gather : public PrimitiveC {
|
|||
void SetAxis(int axis);
|
||||
void SetBatchDims(int batch_dims);
|
||||
#else
|
||||
explicit Gather(schema::Primitive *primitive) : PrimitiveC(primitive) {}
|
||||
Gather() = default;
|
||||
|
||||
schema::Primitive *Init(schema::Primitive *primitive) {
|
||||
flatbuffers::FlatBufferBuilder fbb(1024);
|
||||
|
||||
auto attr = primitive->value_as_Gather();
|
||||
MS_ASSERT(attr != nullptr);
|
||||
|
||||
auto val_offset = schema::CreateGather(fbb, attr->axis(), attr->batchDims());
|
||||
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Gather, val_offset.o);
|
||||
fbb.Finish(prim_offset);
|
||||
|
||||
auto buf = fbb.GetBufferPointer();
|
||||
MS_ASSERT(buf != nullptr);
|
||||
auto buf_bak = new char[fbb.GetSize()];
|
||||
memcpy(buf_bak, buf, fbb.GetSize());
|
||||
|
||||
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
|
||||
auto prim = const_cast<schema::Primitive *>(root);
|
||||
|
||||
delete[] buf_bak;
|
||||
fbb.Clear();
|
||||
return prim;
|
||||
}
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
int GetAxis() const;
|
||||
|
|
|
@ -24,7 +24,20 @@ int GatherNd::GetBatchDims() const { return this->primitive_->value.AsGatherNd()
|
|||
void GatherNd::SetBatchDims(int batch_dims) { this->primitive_->value.AsGatherNd()->batchDims = batch_dims; }
|
||||
|
||||
#else
|
||||
int GatherNd::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto attr = primitive->value_as_GatherNd();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "value_as_GatherNd return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
auto val_offset = schema::CreateGatherNd(*fbb, attr->batchDims());
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_GatherNd, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
int GatherNd::GetBatchDims() const { return this->primitive_->value_as_GatherNd()->batchDims(); }
|
||||
|
||||
#endif
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue