lite expression

This commit is contained in:
yoni 2022-04-14 15:22:36 +03:00 committed by jianghui58
parent 6c9590d13d
commit 35ac798bfa
96 changed files with 6882 additions and 110 deletions

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -25,13 +25,13 @@
#include "include/api/types.h" #include "include/api/types.h"
namespace mindspore { namespace mindspore {
constexpr int iter_th = 1000;
class MixPrecisionCfg { class MixPrecisionCfg {
public: public:
MixPrecisionCfg() { MixPrecisionCfg() {
this->dynamic_loss_scale_ = false; this->dynamic_loss_scale_ = false;
this->loss_scale_ = 128.0f; this->loss_scale_ = 128.0f;
this->num_of_not_nan_iter_th_ = 1000; this->num_of_not_nan_iter_th_ = iter_th;
} }
~MixPrecisionCfg() = default; ~MixPrecisionCfg() = default;
@ -53,6 +53,5 @@ class TrainCfg {
MixPrecisionCfg mix_precision_cfg_; /**< Mix precision configuration */ MixPrecisionCfg mix_precision_cfg_; /**< Mix precision configuration */
bool accumulate_gradients_ = false; bool accumulate_gradients_ = false;
}; };
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_CFG_H #endif // MINDSPORE_INCLUDE_API_CFG_H

View File

@ -24,23 +24,38 @@
#include "include/api/types.h" #include "include/api/types.h"
namespace mindspore { namespace mindspore {
class NetData;
class Net;
class MS_API Graph { class MS_API Graph {
public: public:
class GraphData; class GraphData;
enum Type : uint32_t {
kExpressionGraph = 0, ///< graph as expression - can auto grad
kExecutableGraph = 1, ///< graph is loaded as is
kUnknownTypeGraph = 0xffffffff
};
Graph(); Graph();
explicit Graph(const std::shared_ptr<GraphData> &graph_data); explicit Graph(const std::shared_ptr<GraphData> &graph_data);
explicit Graph(std::shared_ptr<GraphData> &&graph_data); explicit Graph(std::shared_ptr<GraphData> &&graph_data);
explicit Graph(std::nullptr_t); explicit Graph(std::nullptr_t);
~Graph(); ~Graph();
explicit Graph(Type executable);
explicit Graph(Net *net);
enum ModelType ModelType() const; enum ModelType ModelType() const;
bool operator==(std::nullptr_t) const; bool operator==(std::nullptr_t) const;
bool operator!=(std::nullptr_t) const; bool operator!=(std::nullptr_t) const;
bool IsExecutable() { return graph_type_ == kExecutableGraph; }
private: private:
friend class GraphCell; friend class GraphCell;
friend class ModelImpl; friend class ModelImpl;
friend class NetImpl;
friend class Model;
std::shared_ptr<GraphData> graph_data_; std::shared_ptr<GraphData> graph_data_;
std::shared_ptr<NetData> net_data_;
Type graph_type_ = kExecutableGraph;
}; };
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_GRAPH_H #endif // MINDSPORE_INCLUDE_API_GRAPH_H

View File

@ -33,6 +33,9 @@
namespace mindspore { namespace mindspore {
class ModelImpl; class ModelImpl;
class Metrics; class Metrics;
class Net;
class Node;
class Expr;
namespace dataset { namespace dataset {
class Dataset; class Dataset;
@ -109,6 +112,17 @@ class MS_API Model {
Status Build(GraphCell graph, const std::shared_ptr<Context> &model_context = nullptr, Status Build(GraphCell graph, const std::shared_ptr<Context> &model_context = nullptr,
const std::shared_ptr<TrainCfg> &train_cfg = nullptr); const std::shared_ptr<TrainCfg> &train_cfg = nullptr);
/// \brief Build train model
///
/// \param[in] graph A forward network
/// \param[in] optimizer An optimizer node
/// \param[in] inputs Inputs expression for the trained network (ex: input, label )
/// \param[in] model_contex A context used to store options during execution.
/// \param[in] train_cfg A config used by training
/// \return Status
Status Build(GraphCell graph, Node *optimizer, std::vector<Expr *> inputs,
const std::shared_ptr<Context> &model_context, const std::shared_ptr<TrainCfg> &train_cfg);
/// \brief Builds a Transfer Learning model where the backbone weights are fixed and the head weights are trainable /// \brief Builds a Transfer Learning model where the backbone weights are fixed and the head weights are trainable
/// ///
/// \param[in] backbone The static, non-learnable part of the graph /// \param[in] backbone The static, non-learnable part of the graph

141
include/api/net.h Normal file
View File

@ -0,0 +1,141 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_INCLUDE_API_NET_H
#define MINDSPORE_INCLUDE_API_NET_H
#include <memory>
#include <vector>
#include <unordered_set>
#include <string>
#include "include/api/types.h"
#include "include/api/data_type.h"
#include "include/api/cfg.h"
namespace mindspore {
/// \brief Register node or sub network
#define REG(_name) Register(_name, #_name)
class Expr;
class NodeImpl;
class NetImpl;
class NodeSet;
class Graph;
class NetData;
class NetBase {
public:
NetBase() = default;
virtual std::vector<Expr *> operator()(const std::vector<Expr *> &inputs) = 0;
virtual uint32_t type() = 0;
};
class Node : public NetBase {
public:
Node();
virtual ~Node();
/// \brief Create output expression from node
/// \param[in] name Name of input (like "labels" etc.)
///
/// \return Expression
Expr *Create(std::string name);
/// \brief Run node on inputs. This operator is used in Net::construct()
///
/// \param[in] inputs Inputs expression for the node.
/// \return Output node expression vector
std::vector<Expr *> operator()(const std::vector<Expr *> &inputs) override;
uint32_t type() final;
private:
friend NodeImpl;
std::shared_ptr<NodeImpl> impl_ = nullptr;
};
class Net : public NetBase, public std::enable_shared_from_this<Net> {
public:
Net();
virtual ~Net();
explicit Net(std::string name);
explicit Net(const Graph &g);
/// \brief Define the relation between network inputs and outputs
///
/// \param[in] inputs expression vector
///
/// \return expression vector
virtual std::vector<Expr *> construct(const std::vector<Expr *> &inputs);
/// \brief Addition operation
///
/// \param[in] inputs Two elements to add
///
/// \return expression vector (single element)
/// \brief Execution operator. Connect inputs to outputs via user defined construct
///
/// \return expression vector
std::vector<Expr *> operator()(const std::vector<Expr *> &inputs);
void Register(Net *net, std::string &&name);
void Register(Node *node, std::string &&name);
/// \brief Find the trainable params for the trained network
///
/// \return NodeSet for all trainable nodes
std::shared_ptr<NodeSet> trainable_params();
virtual void Add(NetBase *element);
/// \brief Input shape
///
/// \param[in] idx input index
///
/// \return Specific input shape vector
const std::vector<int> InputShape(int idx);
/// \brief Output shape
///
/// \param[in] idx Output index
///
/// \return Specific output shape vector
const std::vector<int> OutputShape(int idx);
uint32_t type() final;
private:
friend NetImpl;
friend NetData;
std::shared_ptr<NetImpl> impl_;
};
class SoftMaxCrossEntropyCfg {
public:
std::string reduction = "mean"; /**< Specifies reduction mode. The optional values are "none", "mean", "sum" */
};
class AdamConfig {
public:
float learning_rate_ = 1e-3;
float beta1_ = 0.9;
float beta2_ = 0.999;
float eps_ = 1e-08;
bool use_nesterov_ = false;
};
namespace NN {
Net *NetWithLoss(Net *net, Node *loss);
Graph *GraphWithLoss(Graph *g, Node *loss);
Node *Adam(std::shared_ptr<NodeSet> learn, const AdamConfig &cfg);
Node *SoftmaxCrossEntropy(const SoftMaxCrossEntropyCfg &cfg);
std::unique_ptr<Node> Input(std::vector<int> dims, DataType data_type = DataType::kNumberTypeFloat32, int fmt = NHWC);
}; // namespace NN
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_NET_H

View File

@ -14,6 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */
#include <stdio.h>
#include "nnacl/infer/conv2d_grad_filter_infer.h" #include "nnacl/infer/conv2d_grad_filter_infer.h"
#include "nnacl/infer/infer_register.h" #include "nnacl/infer/infer_register.h"
@ -26,28 +27,33 @@ int Conv2dGradFilterInferShape(const TensorC *const *inputs, size_t inputs_size,
if (inputs_size < 3 || outputs_size != 1) { if (inputs_size < 3 || outputs_size != 1) {
return NNACL_ERR; return NNACL_ERR;
} }
if (inputs[0]->format_ != Format_NHWC || inputs[1]->format_ != Format_NHWC) { if (inputs[FIRST_INPUT]->format_ != Format_NHWC || inputs[SECOND_INPUT]->format_ != Format_NHWC) {
return NNACL_FORMAT_ERROR; return NNACL_FORMAT_ERROR;
} }
SetDataTypeFormat(outputs[0], inputs[0]); SetDataTypeFormat(outputs[FIRST_INPUT], inputs[FIRST_INPUT]);
if (inputs[2]->shape_size_ < 1 || inputs[2]->data_ == NULL) { if (inputs[THIRD_INPUT]->shape_size_ < DIMENSION_1D || inputs[THIRD_INPUT]->data_ == NULL) {
return NNACL_ERR; return NNACL_ERR;
} }
if (inputs[2]->shape_[0] < 0) { if (inputs[THIRD_INPUT]->shape_[kNCHW_N] < 0) {
return NNACL_ERR; return NNACL_ERR;
} }
size_t filter_shape_size = (size_t)(inputs[2]->shape_[0]); size_t filter_shape_size = (size_t)(inputs[THIRD_INPUT]->shape_[kNCHW_N]);
if (filter_shape_size != 4) { if (filter_shape_size != DIMENSION_4D) {
return NNACL_ERR; return NNACL_ERR;
} }
int filter_shape[MAX_SHAPE_SIZE]; int filter_shape[MAX_SHAPE_SIZE];
const int nchw2nhwc[4] = {0, 2, 3, 1}; if (inputs[THIRD_INPUT]->format_ == Format_NCHW || inputs[THIRD_INPUT]->format_ == Format_KCHW) {
for (size_t i = 0; i < filter_shape_size; i++) { const int nchw2nhwc[] = {kNCHW_N, kNCHW_H, kNCHW_W, kNCHW_C};
filter_shape[i] = *((int *)(inputs[2]->data_) + nchw2nhwc[i]); for (size_t i = 0; i < filter_shape_size; i++) {
filter_shape[i] = *((int *)(inputs[THIRD_INPUT]->data_) + nchw2nhwc[i]);
}
} else if (inputs[THIRD_INPUT]->format_ == Format_NHWC || inputs[THIRD_INPUT]->format_ == Format_KHWC) {
memcpy(filter_shape, inputs[THIRD_INPUT]->data_, filter_shape_size * sizeof(int));
} else {
return NNACL_ERR;
} }
SetShapeArray(outputs[0], filter_shape, filter_shape_size); SetShapeArray(outputs[0], filter_shape, filter_shape_size);
return NNACL_OK; return NNACL_OK;
} }

View File

@ -37,20 +37,26 @@ int Conv2dGradInputInferShape(const TensorC *const *inputs, size_t inputs_size,
} }
SetDataTypeFormat(out, in0); SetDataTypeFormat(out, in0);
if (inputs[2]->shape_size_ < 1 || inputs[2]->data_ == NULL) { if (inputs[THIRD_INPUT]->shape_size_ < 1 || inputs[THIRD_INPUT]->data_ == NULL) {
return NNACL_ERR; return NNACL_ERR;
} }
size_t data_size = (size_t)inputs[2]->shape_[0]; size_t data_size = (size_t)inputs[2]->shape_[0];
if (data_size != 4) { if (data_size != 4) {
return NNACL_ERR; return NNACL_ERR;
} }
int shape[MAX_SHAPE_SIZE]; int shape[MAX_SHAPE_SIZE];
const int nchw2nhwc[4] = {0, 2, 3, 1}; if (inputs[THIRD_INPUT]->format_ == Format_NCHW || inputs[THIRD_INPUT]->format_ == Format_KCHW) {
for (size_t i = 0; i < data_size; i++) { const int nchw2nhwc[4] = {kNCHW_N, kNCHW_H, kNCHW_W, kNCHW_C};
shape[i] = *((int *)(inputs[2]->data_) + nchw2nhwc[i]); for (size_t i = 0; i < data_size; i++) {
shape[i] = *((int *)(inputs[THIRD_INPUT]->data_) + nchw2nhwc[i]);
}
} else if (inputs[THIRD_INPUT]->format_ == Format_NHWC || inputs[THIRD_INPUT]->format_ == Format_KHWC) {
memcpy(shape, inputs[THIRD_INPUT]->data_, data_size * sizeof(int));
} else {
return NNACL_ERR;
} }
SetShapeArray(out, shape, data_size); SetShapeArray(out, shape, data_size);
return NNACL_OK; return NNACL_OK;
} }

View File

@ -157,7 +157,7 @@ int Conv2dInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC *
param->output_h_ = out_shape[DIMENSION_1D]; param->output_h_ = out_shape[DIMENSION_1D];
param->output_w_ = out_shape[DIMENSION_2D]; param->output_w_ = out_shape[DIMENSION_2D];
param->output_channel_ = out_shape[DIMENSION_3D]; param->output_channel_ = out_shape[DIMENSION_3D];
param->out_format_ = out_tensor->format_;
return NNACL_OK; return NNACL_OK;
} }

View File

@ -25,6 +25,7 @@
#include "nnacl/intrinsics/ms_simd_instructions.h" #include "nnacl/intrinsics/ms_simd_instructions.h"
#define C0NUM 0
#define C1NUM 1 #define C1NUM 1
#define C2NUM 2 #define C2NUM 2
#define C3NUM 3 #define C3NUM 3
@ -33,6 +34,8 @@
#define C6NUM 6 #define C6NUM 6
#define C7NUM 7 #define C7NUM 7
#define C8NUM 8 #define C8NUM 8
#define C9NUM 9
#define C10NUM 10
#define C12NUM 12 #define C12NUM 12
#define C13NUM 13 #define C13NUM 13
#define C16NUM 16 #define C16NUM 16

View File

@ -235,7 +235,7 @@ FuncGraphPtr MindIRLoader::LoadMindIR(const std::string &file_name) {
_fullpath(abs_path_buff, file_name.c_str(), PATH_MAX); _fullpath(abs_path_buff, file_name.c_str(), PATH_MAX);
#else #else
if (!realpath(file_name.c_str(), abs_path_buff)) { if (!realpath(file_name.c_str(), abs_path_buff)) {
MS_LOG(ERROR) << "Load MindIR get absolute path failed"; MS_LOG(ERROR) << "Load MindIR get absolute path of " << file_name << " failed";
} }
#endif #endif
// Read graph // Read graph

View File

@ -275,9 +275,34 @@ if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "full")
${CMAKE_CURRENT_SOURCE_DIR}/train/train_loop.cc ${CMAKE_CURRENT_SOURCE_DIR}/train/train_loop.cc
) )
endif() endif()
file(GLOB CXX_API_EXPRESSION
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/expression/*.cc
)
file(GLOB EXPRESSION_OPS
${CMAKE_CURRENT_SOURCE_DIR}/expression/ops/*.cc
)
set(EXPRESSION_SRC
${CXX_API_EXPRESSION}
${CMAKE_CURRENT_SOURCE_DIR}/expression/export.cc
${CMAKE_CURRENT_SOURCE_DIR}/expression/expr.cc
${CMAKE_CURRENT_SOURCE_DIR}/expression/import.cc
${CMAKE_CURRENT_SOURCE_DIR}/expression/net.cc
${CMAKE_CURRENT_SOURCE_DIR}/expression/node.cc
${CMAKE_CURRENT_SOURCE_DIR}/expression/ops.cc
${CMAKE_CURRENT_SOURCE_DIR}/expression/ops_utils.cc
${CMAKE_CURRENT_SOURCE_DIR}/expression/param.cc
${CMAKE_CURRENT_SOURCE_DIR}/expression/sequential.cc
${EXPRESSION_OPS}
)
set(TRAIN_SRC set(TRAIN_SRC
${API_TRAIN_SRC} ${API_TRAIN_SRC}
${TRAIN_SRC_WITH_MD} ${TRAIN_SRC_WITH_MD}
${EXPRESSION_SRC}
${CMAKE_CURRENT_SOURCE_DIR}/common/quant_utils.cc ${CMAKE_CURRENT_SOURCE_DIR}/common/quant_utils.cc
${CMAKE_CURRENT_SOURCE_DIR}/train/train_populate_parameter.cc ${CMAKE_CURRENT_SOURCE_DIR}/train/train_populate_parameter.cc
${CMAKE_CURRENT_SOURCE_DIR}/train/train_session.cc ${CMAKE_CURRENT_SOURCE_DIR}/train/train_session.cc

View File

@ -0,0 +1,145 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "include/api/net.h"
#include "include/api/status.h"
#include "src/cxx_api/expression/node_impl.h"
#include "src/cxx_api/expression/net_impl.h"
#include "src/expression/ops.h"
#include "src/expression/cfg.h"
namespace mindspore {
uint32_t Node::type() { return kNodeType; }
std::vector<Expr *> Node::operator()(const std::vector<Expr *> &inputs) {
auto in = Expr::convert(inputs);
if (impl_ == nullptr) {
MS_LOG(ERROR) << "empty implementation";
return {};
}
if (impl_->node() == nullptr) {
MS_LOG(ERROR) << "expression node is not attached";
return {};
}
auto out = impl_->node()->construct(in);
return Expr::convert(out);
}
Expr *Node::Create(std::string name) {
auto expr = impl_->node()->create(name);
return reinterpret_cast<Expr *>(expr);
}
Node::Node() {
auto impl = std::make_shared<NodeImpl>();
impl_ = impl;
impl_->set_pnode(this);
}
Node::~Node() {
impl_->set_pnode(nullptr);
auto node = impl_->node();
if (node != nullptr) {
impl_->set_node(nullptr);
delete node;
}
}
Net::Net(std::string name) {
auto impl = std::make_shared<NetImpl>();
if (impl == nullptr) {
MS_LOG(ERROR) << "Cannot allocate network implementation";
return;
}
impl_ = impl;
impl_->set_pnet(std::shared_ptr<Net>(this));
auto netl = new (std::nothrow) lite::Net(name);
if (netl == nullptr) {
MS_LOG(ERROR) << "Cannot allocate network lite";
return;
}
netl->set_impl(impl);
impl_->set_net(netl);
}
Net::Net() : Net("") {}
Net::Net(const Graph &g) {
auto net = NetImpl::GetNet(g);
impl_ = net->impl_;
}
void Net::Add(NetBase *element) { MS_LOG(WARNING) << "Only sequential can add element"; }
uint32_t Net::type() { return kNetType; }
std::vector<Expr *> Net::construct(const std::vector<Expr *> &inputs) {
auto in = Expr::convert(inputs);
auto out = impl_->net()->construct(in);
return Expr::convert(out);
}
std::vector<Expr *> Net::operator()(const std::vector<Expr *> &inputs) {
auto in = Expr::convert(inputs);
auto x = construct(inputs);
impl_->net()->input_ = in;
auto out = Expr::convert(x);
impl_->net()->output_ = out;
impl_->net()->real_output_ = out;
return x;
}
void Net::Register(Net *net, std::string &&name) {
if (net != nullptr) {
auto net_lite = net->impl_->net();
impl_->net()->Register(net_lite, std::move(name));
}
}
void Net::Register(Node *node, std::string &&name) {
if (node != nullptr) {
auto impl = NodeImpl::GetImpl(node);
if (impl == nullptr) {
MS_LOG(ERROR) << "missing implementation";
return;
}
auto node_lite = impl->node();
impl_->net()->Register(node_lite, std::move(name));
}
}
std::shared_ptr<NodeSet> Net::trainable_params() {
auto node_set = std::make_shared<NodeSet>();
if (node_set == nullptr) {
MS_LOG(ERROR) << "new NodeSet failed.";
return nullptr;
}
node_set->set_ = impl_->net()->trainable_params();
return node_set;
}
const std::vector<int> Net::InputShape(int idx) { return impl_->InputShape(idx); }
const std::vector<int> Net::OutputShape(int idx) { return impl_->OutputShape(idx); }
Net::~Net() {
if (impl_ != nullptr) {
if ((impl_->pnet() == nullptr) || (impl_->pnet() == this)) {
impl_->set_pnet(nullptr);
impl_->set_net(nullptr);
impl_.reset();
}
}
}
} // namespace mindspore

View File

@ -0,0 +1,222 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/cxx_api/expression/net_impl.h"
#include <vector>
#include <utility>
#include "include/api/serialization.h"
#include "src/expression/import.h"
#include "src/expression/ops.h"
#include "src/cxx_api/model/model_impl.h"
namespace {
constexpr size_t kFlatbuffersBuilderInitSize = 1024;
};
namespace mindspore {
Sequential::Sequential() {}
lite::Node *Sequential::GetNode(NetBase *element) {
lite::Node *lite_node = nullptr;
switch (element->type()) {
case kNodeType: {
Node *node = reinterpret_cast<Node *>(element);
auto impl = NodeImpl::GetImpl(node);
if (impl == nullptr) {
MS_LOG(ERROR) << "cannot find node implement";
return nullptr;
}
lite_node = impl->node();
break;
}
case kNetType: {
auto net = reinterpret_cast<Net *>(element);
auto impl = NetImpl::GetImpl(net);
if (impl == nullptr) {
MS_LOG(ERROR) << "cannot find node implement";
return nullptr;
}
lite_node = impl->net();
break;
}
}
return lite_node;
}
void Sequential::Add(NetBase *element) {
lite::Node *node = GetNode(element);
auto impl = NetImpl::GetImpl(this);
if (impl == nullptr) {
MS_LOG(ERROR) << "No implementation";
return;
}
impl->net()->Add(node);
}
NetWithLoss::NetWithLoss(Net *net, Node *loss) : net_(net), loss_fn_(loss) {
REG(net_);
Register(loss_fn_, "_loss_fn");
}
std::vector<Expr *> NetWithLoss::construct(const std::vector<Expr *> &inputs) {
if (inputs.size() != C2NUM) {
MS_LOG(ERROR) << "need 2 inputs for loss";
return {};
}
auto input = inputs[FIRST_INPUT];
auto label = inputs[SECOND_INPUT];
auto x = (*net_)({input});
x = (*loss_fn_)({x[FIRST_INPUT], label});
return x;
}
NetImpl::~NetImpl() {}
NetImpl::NetImpl(std::shared_ptr<Net> p) { pnet_ = p; }
NetImpl::NetImpl(Graph *g) { pnet_ = g->net_data_->net(); }
std::vector<lite::EXPR *> MS_API NetImpl::construct(const std::vector<lite::EXPR *> &inputs) {
auto in = Expr::convert(inputs);
auto out = pnet_->construct(in);
return Expr::convert(out);
}
Net *MS_API NetImpl::Connect(std::shared_ptr<Net> net, lite::Net *lnet) {
auto impl = GetImpl(net.get());
if (impl == nullptr) {
MS_LOG(ERROR) << "missing implementation";
return nullptr;
}
impl->set_pnet(net);
lnet->set_impl(impl);
impl->set_net(lnet);
return net.get();
}
Status NetImpl::Import(const char *model_buf, Graph *graph) {
auto mg = schema::GetMetaGraph(model_buf);
auto net = new (std::nothrow) Net();
if (net == nullptr) {
MS_LOG(ERROR) << "Cannot allocate network";
return kLiteMemoryFailed;
}
lite::Import import;
auto lite_net = import.ImportMs(mg);
if (lite_net == nullptr) {
MS_LOG(ERROR) << "failed to import net";
return kLiteMemoryFailed;
}
lite_net->SetRealOutput();
Connect(net->shared_from_this(), lite_net);
*graph = Graph(net);
return kSuccess;
}
Status NetImpl::TrainNet(Node *optimizer, const std::vector<Expr *> &inputs) {
auto impl = NodeImpl::GetImpl(optimizer);
if (impl == nullptr) {
MS_LOG(ERROR) << "missing implementation ";
return kLiteNullptr;
}
auto opt = impl->node();
auto in = Expr::convert(inputs);
auto ret_net = net()->TrainNet(opt, in);
if (ret_net == nullptr) {
MS_LOG(ERROR) << "failed to train network";
return kLiteNullptr;
}
return kSuccess;
}
std::unique_ptr<Graph> NetImpl::MakeMs() {
auto mgraph = std::make_unique<Graph>(Graph::Type::kExecutableGraph);
if (mgraph == nullptr) {
MS_LOG(ERROR) << "Cannot allocate graph";
return nullptr;
}
auto trained_graph = net()->MakeMs();
if (trained_graph == nullptr) {
MS_LOG(ERROR) << "cannot create flat buffer";
return nullptr;
}
flatbuffers::FlatBufferBuilder builder(kFlatbuffersBuilderInitSize);
auto offset = schema::MetaGraph::Pack(builder, trained_graph.get());
builder.Finish(offset);
schema::FinishMetaGraphBuffer(builder, offset);
auto buffer = builder.GetBufferPointer();
size_t size = builder.GetSize();
auto status = Serialization::Load(buffer, size, mindspore::kMindIR, mgraph.get());
if (status != kSuccess) {
MS_LOG(ERROR) << "failed to load flatbuffer to graph";
return nullptr;
}
return mgraph;
}
const std::vector<int> NetImpl::InputShape(int idx) { return net_->InputShape(idx); }
const std::vector<int> NetImpl::OutputShape(int idx) { return net_->OutputShape(idx); }
void NetImpl::ReplaceNet(Graph *g, std::shared_ptr<Net> n) { g->net_data_->net().swap(n); }
ExpressionLoader expression_registrator = CreateExpressionLoader(NetImpl::Import);
namespace NN {
Net *Sequential() {
auto net = new (std::nothrow) mindspore::Sequential();
if (net == nullptr) {
MS_LOG(ERROR) << "Cannot allocate ";
return nullptr;
}
auto netl = lite::NN::Sequential();
return NetImpl::Connect(net->shared_from_this(), netl);
}
Net *NetWithLoss(Net *net, Node *loss) {
auto loss_net = new (std::nothrow) mindspore::NetWithLoss(net, loss);
if (net == nullptr) {
MS_LOG(ERROR) << "Cannot allocate loss net";
return nullptr;
}
return loss_net;
}
Graph *GraphWithLoss(Graph *graph, Node *loss) {
auto net = NetImpl::GetNet(*graph);
if (net == nullptr) {
MS_LOG(ERROR) << "Cannot allocate network";
return nullptr;
}
auto loss_net = NetWithLoss(net.get(), loss);
if (loss_net == nullptr) {
MS_LOG(ERROR) << "Cannot allocate network";
return nullptr;
}
NetImpl::ReplaceNet(graph, loss_net->shared_from_this());
return graph;
}
Net *NetWithLoss(Graph *g, Node *loss) {
auto net = new (std::nothrow) Net(*g);
if (net == nullptr) {
MS_LOG(ERROR) << "Cannot allocate net";
return nullptr;
}
return NetWithLoss(net, loss);
}
} // namespace NN
} // namespace mindspore

View File

@ -0,0 +1,95 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_CXX_API_EXPRESSION_NET_IMPL_H_
#define MINDSPORE_LITE_SRC_CXX_API_EXPRESSION_NET_IMPL_H_
#include <algorithm>
#include <set>
#include <memory>
#include <vector>
#include <utility>
#include "include/api/cfg.h"
#include "include/api/data_type.h"
#include "include/api/graph.h"
#include "include/api/status.h"
#include "include/api/net.h"
#include "src/cxx_api/expression/node_impl.h"
#include "src/cxx_api/graph/net_data.h"
#include "src/expression/net.h"
#include "src/expression/ops.h"
namespace mindspore {
constexpr uint32_t kNodeType = 1;
constexpr uint32_t kNetType = 2;
class Sequential : public Net {
public:
Sequential();
void Add(NetBase *n) override;
private:
std::vector<NetBase *> ops_;
lite::Node *GetNode(NetBase *element);
};
class NetWithLoss : public Net {
public:
NetWithLoss(Net *net, Node *loss);
std::vector<Expr *> construct(const std::vector<Expr *> &inputs) override;
private:
Net *net_{nullptr};
Node *loss_fn_{nullptr};
};
class MS_API NetImpl {
public:
virtual ~NetImpl();
explicit NetImpl(std::shared_ptr<Net> p);
explicit NetImpl(Graph *g);
NetImpl() = default;
void set_net(lite::Net *net) {
if (net_ != nullptr) {
net_->set_impl(nullptr);
delete net_;
}
net_ = net;
}
void erase_net() { net_ = nullptr; }
void set_pnet(std::shared_ptr<Net> net) { pnet_ = net; }
Net *pnet() { return pnet_.get(); }
lite::Net *net() { return net_; }
std::vector<lite::EXPR *> construct(const std::vector<lite::EXPR *> &inputs);
static std::shared_ptr<mindspore::NetImpl> &GetImpl(Net *net) { return net->impl_; }
static Net *Connect(std::shared_ptr<Net> net, lite::Net *lnet);
static std::shared_ptr<Net> &GetNet(const Graph &g) { return g.net_data_->net(); }
static void SetNet(Graph *g, std::shared_ptr<Net> n) { g->net_data_->set_net(n); }
static void ReplaceNet(Graph *g, std::shared_ptr<Net> n);
static Status Import(const char *model_buf, Graph *graph);
Status TrainNet(Node *optimizer, const std::vector<Expr *> &inputs);
const std::vector<int> InputShape(int idx);
const std::vector<int> OutputShape(int idx);
std::unique_ptr<Graph> MakeMs();
void Release() { pnet_.reset(); }
private:
std::shared_ptr<Net> pnet_;
lite::Net *net_ = nullptr;
};
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_CXX_API_EXPRESSION_NET_IMPL_H_

View File

@ -0,0 +1,50 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/cxx_api/expression/node_impl.h"
#include <vector>
#include "include/api/net.h"
#include "src/expression/ops.h"
namespace mindspore {
Node *NodeImpl::Connect(lite::Node *lnode) {
auto node = std::make_unique<Node>();
if (node == nullptr) {
MS_LOG(ERROR) << "Cannot allocate node";
return nullptr;
}
if (lnode == nullptr) {
MS_LOG(ERROR) << "lite node is null";
return nullptr;
}
auto pnode = node.release();
auto impl = GetImpl(pnode);
if (impl == nullptr) {
MS_LOG(ERROR) << "missing implementation";
return nullptr;
}
impl->set_node(lnode);
lnode->set_impl(impl);
return pnode;
}
namespace NN {
std::unique_ptr<Node> Input(std::vector<int> dims, DataType data_type, int fmt) {
auto type = static_cast<TypeId>(data_type);
auto lite_node = lite::NN::Input(dims, type, fmt);
return std::unique_ptr<Node>(NodeImpl::Connect(lite_node));
}
} // namespace NN
} // namespace mindspore

View File

@ -0,0 +1,71 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_CXX_API_EXPRESSION_NODE_IMPL_H_
#define MINDSPORE_LITE_SRC_CXX_API_EXPRESSION_NODE_IMPL_H_
#include <algorithm>
#include <set>
#include <memory>
#include <vector>
#include "include/api/net.h"
#include "include/api/cfg.h"
#include "include/api/data_type.h"
#include "src/expression/node.h"
#include "src/expression/expr.h"
namespace mindspore {
using lite::EXPR;
class NodeSet {
public:
std::set<lite::Node *> set_;
};
class Expr : public EXPR {
public:
static std::vector<EXPR *> convert(const std::vector<Expr *> &input) {
std::vector<EXPR *> vec(input.size());
(void)std::transform(input.begin(), input.end(), vec.begin(), [](Expr *e) { return reinterpret_cast<EXPR *>(e); });
return vec;
}
static std::vector<Expr *> convert(const std::vector<EXPR *> &input) {
std::vector<Expr *> vec(input.size());
(void)std::transform(input.begin(), input.end(), vec.begin(), [](EXPR *e) { return reinterpret_cast<Expr *>(e); });
return vec;
}
};
class MS_API NodeImpl {
public:
std::vector<Expr *> operator()(const std::vector<Expr *> &inputs) {
auto in = Expr::convert(inputs);
auto out = (*node_)(in);
return Expr::convert(out);
}
lite::Node *node() { return node_; }
void set_node(lite::Node *node) { node_ = node; }
void set_pnode(Node *node) { pnode_ = node; }
Node *pnode() { return pnode_; }
static Node *Connect(lite::Node *lnode);
static std::shared_ptr<NodeImpl> &GetImpl(Node *node) { return node->impl_; }
private:
Node *pnode_{nullptr};
lite::Node *node_{nullptr};
};
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_CXX_API_EXPRESSION_NODE_IMPL_H_

View File

@ -16,7 +16,9 @@
#include "include/api/graph.h" #include "include/api/graph.h"
#include "include/api/cell.h" #include "include/api/cell.h"
#include "include/api/net.h"
#include "src/cxx_api/graph/graph_data.h" #include "src/cxx_api/graph/graph_data.h"
#include "src/cxx_api/graph/net_data.h"
namespace mindspore { namespace mindspore {
Graph::Graph() : graph_data_(nullptr) {} Graph::Graph() : graph_data_(nullptr) {}
@ -25,8 +27,15 @@ Graph::Graph(const std::shared_ptr<GraphData> &graph_data) : graph_data_(graph_d
Graph::Graph(std::shared_ptr<GraphData> &&graph_data) : graph_data_(graph_data) {} Graph::Graph(std::shared_ptr<GraphData> &&graph_data) : graph_data_(graph_data) {}
Graph::Graph(Graph::Type type) : graph_type_(type) {}
Graph::~Graph() {} Graph::~Graph() {}
Graph::Graph(Net *net) : graph_type_(kExpressionGraph) {
auto shared = std::make_shared<NetData>(net->shared_from_this());
net_data_ = shared;
}
Graph::Graph(std::nullptr_t) : graph_data_(nullptr) {} Graph::Graph(std::nullptr_t) : graph_data_(nullptr) {}
bool Graph::operator==(std::nullptr_t) const { return graph_data_ == nullptr; } bool Graph::operator==(std::nullptr_t) const { return graph_data_ == nullptr; }

View File

@ -0,0 +1,21 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/cxx_api/graph/net_data.h"
#include "src/cxx_api/expression/net_impl.h"
namespace mindspore {
NetData::~NetData() { net_->impl_->Release(); }
} // namespace mindspore

View File

@ -0,0 +1,35 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_CXX_API_GRAPH_NET_DATA_H_
#define MINDSPORE_LITE_SRC_CXX_API_GRAPH_NET_DATA_H_
#include <memory>
#include "include/api/net.h"
namespace mindspore {
class NetData {
public:
explicit NetData(const std::shared_ptr<Net> &net) : net_(net) {}
virtual ~NetData();
void set_net(std::shared_ptr<Net> net) { net_ = net; }
std::shared_ptr<Net> &net() { return net_; }
private:
std::shared_ptr<Net> net_;
};
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_CXX_API_GRAPH_NET_DATA_H_

View File

@ -22,11 +22,15 @@
#ifdef ENABLE_LITE_ACL #ifdef ENABLE_LITE_ACL
#include "acl/acl_base.h" #include "acl/acl_base.h"
#endif #endif
#include "flatbuffers/flatbuffers.h"
#include "include/api/callback/callback.h" #include "include/api/callback/callback.h"
#include "include/api/context.h" #include "include/api/context.h"
#include "include/api/dual_abi_helper.h" #include "include/api/dual_abi_helper.h"
#include "include/api/types.h" #include "include/api/types.h"
#include "include/api/serialization.h"
#include "include/api/graph.h"
#include "src/common/log_adapter.h" #include "src/common/log_adapter.h"
#include "src/cxx_api/expression/net_impl.h"
#include "src/cxx_api/callback/callback_adapter.h" #include "src/cxx_api/callback/callback_adapter.h"
#include "src/cxx_api/callback/callback_impl.h" #include "src/cxx_api/callback/callback_impl.h"
#include "src/cxx_api/model/model_impl.h" #include "src/cxx_api/model/model_impl.h"

View File

@ -56,6 +56,14 @@ CreateTrainSessionProto *CreateTrainSessionCallbackHolder(CreateTrainSessionProt
return proto_; return proto_;
} }
ExpressionLoader CreateExpressionLoader(ExpressionLoader loader) {
static ExpressionLoader loader_ = nullptr;
if (loader != nullptr) {
loader_ = loader;
}
return loader_;
}
Status ModelImpl::Build(const void *model_data, size_t data_size, ModelType model_type, Status ModelImpl::Build(const void *model_data, size_t data_size, ModelType model_type,
const std::shared_ptr<Context> &ms_context) { const std::shared_ptr<Context> &ms_context) {
if (model_data == nullptr) { if (model_data == nullptr) {

View File

@ -49,6 +49,9 @@ typedef std::shared_ptr<lite::LiteSession>(CreateTrainSessionProto)(std::shared_
lite::InnerContext *context); lite::InnerContext *context);
CreateTrainSessionProto *CreateTrainSessionCallbackHolder(CreateTrainSessionProto *proto = nullptr); CreateTrainSessionProto *CreateTrainSessionCallbackHolder(CreateTrainSessionProto *proto = nullptr);
using ExpressionLoader = std::function<Status(const char *, Graph *)>;
ExpressionLoader CreateExpressionLoader(ExpressionLoader loader = nullptr);
namespace session { namespace session {
class Metrics; class Metrics;
class TrainLoopCallBack; class TrainLoopCallBack;
@ -90,6 +93,7 @@ class ModelImpl {
static bool CheckModelSupport(const std::string &device_type, ModelType model_type); static bool CheckModelSupport(const std::string &device_type, ModelType model_type);
bool IsTrainModel(); bool IsTrainModel();
std::unique_ptr<Graph> BuildTrain(Node *optimizer, std::vector<Expr *> inputs);
Status SetupVirtualBatch(int virtual_batch_multiplier, float lr, float momentum); Status SetupVirtualBatch(int virtual_batch_multiplier, float lr, float momentum);
Status SetLearningRate(float learning_rate); Status SetLearningRate(float learning_rate);
float GetLearningRate(); float GetLearningRate();

View File

@ -20,6 +20,7 @@
#include "include/api/graph.h" #include "include/api/graph.h"
#include "include/api/types.h" #include "include/api/types.h"
#include "include/model.h" #include "include/model.h"
#include "src/cxx_api/expression/net_impl.h"
#include "src/cxx_api/graph/graph_data.h" #include "src/cxx_api/graph/graph_data.h"
#include "src/cxx_api/model/model_impl.h" #include "src/cxx_api/model/model_impl.h"
#include "src/cxx_api/converters.h" #include "src/cxx_api/converters.h"
@ -27,6 +28,8 @@
#include "src/lite_session.h" #include "src/lite_session.h"
namespace mindspore { namespace mindspore {
std::function<int(void *)> ExpressionCallback;
Key::Key(const char *dec_key, size_t key_len) { Key::Key(const char *dec_key, size_t key_len) {
len = 0; len = 0;
if (key_len >= max_key_len) { if (key_len >= max_key_len) {
@ -115,19 +118,31 @@ Status Serialization::Load(const std::vector<char> &file, ModelType model_type,
MS_LOG(ERROR) << "Read model file failed"; MS_LOG(ERROR) << "Read model file failed";
return kLiteNullptr; return kLiteNullptr;
} }
auto model = if (graph->IsExecutable()) {
std::shared_ptr<lite::Model>(lite::ImportFromBuffer(static_cast<const char *>(model_buf), model_size, true)); auto model =
if (model == nullptr) { std::shared_ptr<lite::Model>(lite::ImportFromBuffer(static_cast<const char *>(model_buf), model_size, true));
MS_LOG(ERROR) << "New model failed."; if (model == nullptr) {
return kLiteNullptr; MS_LOG(ERROR) << "New model failed.";
return kLiteNullptr;
}
auto graph_data = std::shared_ptr<Graph::GraphData>(new (std::nothrow) Graph::GraphData(model));
if (graph_data == nullptr) {
MS_LOG(ERROR) << "New graph data failed.";
return kLiteMemoryFailed;
}
*graph = Graph(graph_data);
return kSuccess;
} else {
auto loader = CreateExpressionLoader();
if (loader == nullptr) {
MS_LOG(ERROR) << "Unsupported Feature.";
delete[] model_buf;
return kLiteError;
}
loader(model_buf, graph);
delete[] model_buf;
return kSuccess;
} }
auto graph_data = std::shared_ptr<Graph::GraphData>(new (std::nothrow) Graph::GraphData(model));
if (graph_data == nullptr) {
MS_LOG(ERROR) << "New graph data failed.";
return kLiteMemoryFailed;
}
*graph = Graph(graph_data);
return kSuccess;
} }
Status Serialization::Load(const std::vector<std::vector<char>> &files, ModelType model_type, Status Serialization::Load(const std::vector<std::vector<char>> &files, ModelType model_type,

View File

@ -17,6 +17,7 @@
#include "include/api/model.h" #include "include/api/model.h"
#include "include/api/types.h" #include "include/api/types.h"
#include "include/api/context.h" #include "include/api/context.h"
#include "include/api/net.h"
#include "include/api/callback/callback.h" #include "include/api/callback/callback.h"
#include "include/api/dual_abi_helper.h" #include "include/api/dual_abi_helper.h"
#include "src/cxx_api/model/model_impl.h" #include "src/cxx_api/model/model_impl.h"
@ -106,11 +107,39 @@ Status Model::Evaluate(std::shared_ptr<dataset::Dataset> ds, std::vector<TrainCa
return (ret == mindspore::lite::RET_OK) ? kSuccess : kLiteError; return (ret == mindspore::lite::RET_OK) ? kSuccess : kLiteError;
} }
Status Model::Build(GraphCell lossGraphCell, Node *optimizer, std::vector<Expr *> inputs,
const std::shared_ptr<Context> &model_context, const std::shared_ptr<TrainCfg> &train_cfg) {
std::stringstream err_msg;
if (impl_ == nullptr) {
impl_ = std::make_shared<ModelImpl>();
if (impl_ == nullptr) {
MS_LOG(ERROR) << "Model implement is null.";
return kLiteFileError;
}
}
auto lossGraph = lossGraphCell.GetGraph();
if (lossGraph == nullptr) {
err_msg << "Invalid null graph";
MS_LOG(ERROR) << err_msg.str();
return Status(kLiteNullptr, err_msg.str());
}
impl_->SetContext(model_context);
impl_->SetConfig(train_cfg);
impl_->SetGraph(lossGraph);
auto graph = impl_->BuildTrain(optimizer, inputs);
auto status = Build(GraphCell(*graph), model_context, train_cfg);
if (status != mindspore::kSuccess) {
MS_LOG(ERROR) << "Error " << status << " during model build";
return status;
}
return kSuccess; // status
}
Status Model::BuildTransferLearning(GraphCell backbone, GraphCell head, const std::shared_ptr<Context> &context, Status Model::BuildTransferLearning(GraphCell backbone, GraphCell head, const std::shared_ptr<Context> &context,
const std::shared_ptr<TrainCfg> &train_cfg) { const std::shared_ptr<TrainCfg> &train_cfg) {
std::stringstream err_msg; std::stringstream err_msg;
if (impl_ == nullptr) { if (impl_ == nullptr) {
impl_ = std::shared_ptr<ModelImpl>(new (std::nothrow) ModelImpl()); impl_ = std::make_shared<ModelImpl>();
if (impl_ == nullptr) { if (impl_ == nullptr) {
MS_LOG(ERROR) << "Model implement is null."; MS_LOG(ERROR) << "Model implement is null.";
return kLiteFileError; return kLiteFileError;
@ -137,5 +166,4 @@ Status Model::BuildTransferLearning(GraphCell backbone, GraphCell head, const st
} }
return kSuccess; return kSuccess;
} }
} // namespace mindspore } // namespace mindspore

View File

@ -18,15 +18,18 @@
#include <memory> #include <memory>
#include <unordered_map> #include <unordered_map>
#include <algorithm> #include <algorithm>
#include "src/cxx_api/expression/node_impl.h"
#include "include/api/types.h" #include "include/api/types.h"
#include "include/api/context.h" #include "include/api/context.h"
#include "include/api/dual_abi_helper.h" #include "include/api/dual_abi_helper.h"
#include "include/api/serialization.h"
#include "include/lite_session.h" #include "include/lite_session.h"
#include "include/context.h" #include "include/context.h"
#include "include/api/callback/callback.h" #include "include/api/callback/callback.h"
#include "include/api/metrics/metrics.h" #include "include/api/metrics/metrics.h"
#include "src/lite_model.h" #include "src/lite_model.h"
#include "src/runtime/inner_allocator.h" #include "src/runtime/inner_allocator.h"
#include "src/cxx_api/expression/net_impl.h"
#include "src/cxx_api/converters.h" #include "src/cxx_api/converters.h"
#include "src/cxx_api/graph/graph_data.h" #include "src/cxx_api/graph/graph_data.h"
#include "src/cxx_api/tensor/tensor_impl.h" #include "src/cxx_api/tensor/tensor_impl.h"
@ -79,6 +82,32 @@ Status ModelImpl::BuildTransferLearning(const std::shared_ptr<Graph> &backbone,
return kLiteError; return kLiteError;
} }
std::unique_ptr<Graph> ModelImpl::BuildTrain(Node *optimizer, std::vector<Expr *> inputs) {
auto opt_impl = NodeImpl::GetImpl(optimizer);
if (opt_impl == nullptr) {
MS_LOG(ERROR) << "missing optimizer node implementation";
return nullptr;
}
auto opt = opt_impl->node();
auto in = Expr::convert(inputs);
auto net_impl = NetImpl::GetImpl(graph_->net_data_->net().get());
if (net_impl == nullptr) {
MS_LOG(ERROR) << "missing net implementation";
return nullptr;
}
auto trained_net = net_impl->net()->TrainNet(opt, in);
if (trained_net == nullptr) {
MS_LOG(ERROR) << "failed to train network";
return nullptr;
}
auto mgraph = net_impl->MakeMs();
if (mgraph == nullptr) {
MS_LOG(ERROR) << "failed to create graph";
return nullptr;
}
return mgraph;
}
Status ModelImpl::PrepareMetrics(Model *model, std::vector<session::Metrics *> *out_ms, Status ModelImpl::PrepareMetrics(Model *model, std::vector<session::Metrics *> *out_ms,
std::vector<session::Metrics *> *adapter_ms) { std::vector<session::Metrics *> *adapter_ms) {
if (out_ms == nullptr || adapter_ms == nullptr) { if (out_ms == nullptr || adapter_ms == nullptr) {

View File

@ -0,0 +1,68 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_CFG_H_
#define MINDSPORE_LITE_SRC_EXPRESSION_CFG_H_
#include <vector>
#include <string>
namespace mindspore {
namespace lite {
class ConvConfig {
public:
ConvConfig() = default;
int in_channel_ = 3; /**< The channel number of the input of the Conv2d layer */
int out_channel_ = 3; /**< The channel number of the output tensor of the Conv2d layer */
std::vector<int64_t> kernel_size_ = {3, 3}; /**< Specifies the height and width of the 2D convolution kernel. */
std::vector<int64_t> stride_ = {1, 1}; /**< The movement stride of the 2D convolution kernel */
std::vector<int64_t> padding_ = {0, 0, 0, 0}; /**< The top, bottom, left, and right padding input */
std::vector<int64_t> dilation_ = {1, 1}; /**< diletion height and width*/
int group_ = 1; // < Splits filter into groups, `in_channels` and `out_channels` must be
// divisible by `group`. If the group is equal to `in_channels` and `out_channels`,
// this 2D convolution layer also can be called 2D depthwise convolution layer */
bool has_bias = false; /** < Whether the Conv2d layer has a bias parameter */
std::string weight_init_ =
"normal"; /**< Initialization method of weight parameter ("normal","uniform", "ones", "zeros") */
std::string pad_mode_ = "same"; /**< Specifies padding mode. The optional values are "same", "valid", "pad" */
private:
std::string bias_init_ = "zeros";
std::string data_format;
};
class DenseConfig {
public:
int in_channels_; /**< The number of channels in the input space */
int out_channels_; /**< The number of channels in the output space */
bool has_bias_ = false; /** Specifies whether the layer uses a bias vector **/
private:
std::string weight_init_ = "normal";
std::string bias_init_ = "zeros";
std::string activation_ = "none";
};
class PoolingConfig {
public:
PoolingConfig() = default;
std::vector<int64_t> kernel_size_ = {1, 1}; /**< Specifies the height and width of the 2D kernel. */
std::vector<int64_t> stride_ = {1, 1}; /**< The movement stride of the 2D kernel */
std::string pad_mode_ = "same"; /**< Specifies padding mode. The optional values are "same", "valid" */
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_EXPRESSION_CFG_H_

View File

@ -0,0 +1,76 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <utility>
#include "src/expression/export.h"
#include "src/expression/ops.h"
#include "src/common/utils.h"
#include "nnacl/conv_parameter.h"
#include "include/errorcode.h"
namespace mindspore {
namespace lite {
constexpr static int kFmkVal = 3;
int ExportSession::Init(const std::string model_name, std::string version) {
meta_graph_ = new (std::nothrow) mindspore::schema::MetaGraphT();
if (meta_graph_ == nullptr) {
MS_LOG(ERROR) << "cannot allocate meta_graph";
return RET_ERROR;
}
meta_graph_->fmkType = kFmkVal;
meta_graph_->name = model_name;
meta_graph_->version = version;
return RET_OK;
}
bool ExportSession::IsToDependOnly(EXPR *expr) {
auto itr = outmap_.find(expr);
if (itr != outmap_.end() && !itr->second.empty()) {
for (auto expr : itr->second) {
auto node = expr->node();
if (node->primitive() != schema::PrimitiveType_Depend) return false;
}
return true;
}
return false;
}
int ExportSession::SetInputOutput(const std::vector<EXPR *> &inputs, const std::vector<EXPR *> &outputs) {
for (auto &in : inputs) {
auto id = GetOutput(in);
meta_graph_->inputIndex.push_back(id);
}
for (auto &out : outputs) {
auto id = GetOutput(out);
meta_graph_->outputIndex.push_back(id);
}
auto sub_graph = std::make_unique<mindspore::schema::SubGraphT>();
if (sub_graph == nullptr) {
MS_LOG(ERROR) << "cannot allocate SubGraphT";
return RET_ERROR;
}
auto model_name = meta_graph_->name;
sub_graph->name = model_name + "_subgraph";
sub_graph->inputIndices = meta_graph_->inputIndex;
sub_graph->outputIndices = meta_graph_->outputIndex;
for (size_t i = 0; i < meta_graph_->nodes.size(); i++) sub_graph->nodeIndices.push_back(i);
for (size_t i = 0; i < meta_graph_->allTensors.size(); i++) sub_graph->tensorIndices.push_back(i);
meta_graph_->subGraph.emplace_back(std::move(sub_graph));
return RET_OK;
}
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,52 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_EXPORT_H_
#define MINDSPORE_LITE_SRC_EXPRESSION_EXPORT_H_
#include <vector>
#include <memory>
#include <unordered_map>
#include <string>
#include <list>
#include <map>
#include <iostream>
#include "src/expression/expr.h"
namespace mindspore {
namespace schema {
struct MetaGraphT;
}
namespace lite {
class ExportSession {
public:
explicit ExportSession(std::map<EXPR *, std::list<EXPR *>> &outmap) : outmap_(outmap) {}
int Init(const std::string model_name, std::string version);
void UpdateOutput(EXPR *expr, int id) { output_tensors_[expr] = id; }
int GetOutput(EXPR *expr) { return output_tensors_.at(expr); }
schema::MetaGraphT *&meta_graph() { return meta_graph_; }
int SetInputOutput(const std::vector<EXPR *> &inputs, const std::vector<EXPR *> &outputs);
bool IsToDependOnly(EXPR *expr);
private:
schema::MetaGraphT *meta_graph_{nullptr};
std::unordered_map<EXPR *, int> output_tensors_; // output tensors per EXPR
std::map<EXPR *, std::list<EXPR *>> &outmap_;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_EXPRESSION_EXPORT_H_

View File

@ -0,0 +1,98 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <string>
#include <iostream>
#include "src/expression/expr.h"
#include "src/expression/node.h"
namespace mindspore {
namespace lite {
std::string EXPR::name() { return node_->name(); }
void EXPR::Travers(std::function<bool(EXPR *e, EXPR *itr)> cb) {
if (!visited) {
visited = true;
for (auto &itr : params_) {
if (cb(this, itr)) {
itr->Travers(cb);
}
}
}
}
void EXPR::Replace(EXPR **old, EXPR **n, std::vector<Node *> *to_delete) {
if (!visited) {
visited = true;
for (auto &itr : params_)
if (itr == *old) {
to_delete->push_back(itr->node());
itr = *n;
}
for (auto &itr : params_) itr->Replace(old, n, to_delete);
}
}
void EXPR::Replace(const std::vector<EXPR *> &vec, std::vector<EXPR *> *old, std::vector<EXPR *> *n) {
std::vector<Node *> to_delete;
for (auto &e : vec) {
for (std::size_t i = 0; i < old->size(); i++) {
e->Replace(&old->at(i), &n->at(i), &to_delete);
}
}
for (auto &itr : to_delete) delete itr;
for (auto e : vec) e->Clear();
}
void EXPR::Clear() {
EXPR *item = this;
if (visited == false) return;
visited = false;
while (item->params_.size() == 1) {
item = item->params_.front();
if (item->visited == false) return;
item->visited = false;
}
for (auto &itr : item->params_) itr->Clear();
}
void EXPR::Clear(std::vector<EXPR *> vec) {
for (auto e : vec) e->Clear();
}
void EXPR::CreateOutputMap(std::vector<EXPR *> vec, std::map<EXPR *, std::list<EXPR *>> *outmap) {
for (auto e : vec) {
e->Travers([&](EXPR *e, EXPR *itr) {
(*outmap)[itr].push_back(e);
return true;
});
}
Clear(vec);
}
void EXPR::PrintDot(std::vector<EXPR *> vec) {
std::cout << "digraph \"expr\" { " << std::endl;
for (auto e : vec) {
e->Travers([](EXPR *e, EXPR *itr) {
std::cout << "\"" << itr->node_->name() << "\"->"
<< "\"" << e->node_->name() << "\"" << std::endl;
return true;
});
}
std::cout << "}" << std::endl;
Clear(vec);
}
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,70 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_EXPR_H_
#define MINDSPORE_LITE_SRC_EXPRESSION_EXPR_H_
#include <vector>
#include <list>
#include <memory>
#include <map>
#include <functional>
#include <string>
#include "include/ms_tensor.h"
#include "include/api/format.h"
namespace mindspore {
namespace lite {
class Node;
class EXPR {
public:
explicit EXPR(Node *node) : node_(node) { SetSize(1); }
static void PrintDot(std::vector<EXPR *> vec);
static void Replace(const std::vector<EXPR *> &vec, std::vector<EXPR *> *old, std::vector<EXPR *> *n);
static void CreateOutputMap(std::vector<EXPR *> vec, std::map<EXPR *, std::list<EXPR *>> *outmap);
static void Clear(std::vector<EXPR *> vec);
void Travers(std::function<bool(EXPR *e, EXPR *itr)> cb);
std::string name();
EXPR *GetInput(int idx) { return params_.at(idx); }
void set_node(Node *node) { node_ = node; }
Node *node() { return node_; }
bool visited = false;
void set_params(std::vector<EXPR *> params) { params_ = params; }
void set_params(int idx, EXPR *expr) { params_[idx] = expr; }
void add_params(EXPR *e) { params_.push_back(e); }
std::vector<EXPR *> &params() { return params_; }
EXPR *params(int i) { return params_[i]; }
void SetSize(int n) { params_.resize(n); }
void SetDims(std::vector<int> dims) { dims_ = dims; }
std::vector<int> &dims() { return dims_; }
void set_format(int fmt) { format_ = fmt; }
int format() { return format_; }
void set_data_type(TypeId data_type) { data_type_ = data_type; }
TypeId data_type() { return data_type_; }
private:
void Replace(EXPR **old, EXPR **n, std::vector<Node *> *to_delete);
std::vector<EXPR *> params_;
Node *node_{nullptr};
void Clear();
std::vector<int> dims_;
int format_ = NHWC;
TypeId data_type_ = kNumberTypeFloat32;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_EXPRESSION_EXPR_H_

View File

@ -0,0 +1,180 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include "src/expression/import.h"
#include "ops/populate/populate_register.h"
#include "src/expression/ops.h"
#include "src/expression/ops/activation.h"
#include "src/expression/ops/batchnorm.h"
#include "src/expression/ops/biasadd.h"
#include "src/expression/ops/conv.h"
#include "src/expression/ops/dense.h"
#include "src/expression/ops/pooling.h"
#include "src/expression/ops/reshape.h"
#include "src/expression/ops/transpose.h"
namespace mindspore {
namespace lite {
std::unordered_map<mindspore::schema::PrimitiveType, import_func> ImportReg::import_map_;
import_func ImportReg::GetImportFunc(mindspore::schema::PrimitiveType type) {
auto f = import_map_.find(type);
if (f == import_map_.end()) {
return nullptr;
}
return f->second;
}
OpParameter *Import::GetAttr(const schema::Primitive *prim) {
auto parameter_gen = PopulateRegistry::GetInstance()->GetParameterCreator(prim->value_type(), SCHEMA_CUR);
if (parameter_gen == nullptr) {
MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " << schema::EnumNamePrimitiveType(prim->value_type());
return nullptr;
}
auto parameter = parameter_gen(prim);
if (parameter == nullptr) {
MS_LOG(ERROR) << "parameter is nullptr.";
return nullptr;
}
return parameter;
}
std::unique_ptr<Node> Import::CreateNode(const schema::CNode *cnode) {
auto param = GetAttr(cnode->primitive());
auto type = cnode->primitive()->value_type();
auto fn = ImportReg::GetImportFunc(type);
if (fn == nullptr) {
MS_LOG(ERROR) << "Cannot find importer for " << schema::EnumNamePrimitiveType(type);
return nullptr;
}
auto node = fn();
if (node == nullptr) {
MS_LOG(ERROR) << "Cannot allocate node" << cnode->name()->str();
return nullptr;
}
node->SetOpParam(param);
node->set_name(cnode->name()->str());
node->set_primitive(type);
return std::unique_ptr<Node>(node);
}
Net *Import::ImportMs(std::string file_name) {
std::ifstream infile;
infile.open(file_name, std::ios::binary | std::ios::in);
if (!infile.good()) {
MS_LOG(ERROR) << "cannot read " << file_name << std::endl;
return nullptr;
}
infile.seekg(0, std::ios::end);
int length = infile.tellg();
infile.seekg(0, std::ios::beg);
auto data_ptr = std::make_unique<int8_t[]>(length);
auto *data = data_ptr.get();
infile.read(reinterpret_cast<char *>(data), length);
infile.close();
flatbuffers::Verifier verifier = flatbuffers::Verifier(reinterpret_cast<const uint8_t *>(data), length);
bool res = schema::VerifyMetaGraphBuffer(verifier);
if (res != true) {
MS_LOG(ERROR) << "fault file: " << file_name << "(" << length << ")\n";
return nullptr;
} else {
MS_LOG(INFO) << "verify pass file: " << file_name << "(" << length << ")\n";
}
buffer_ = data_ptr.get();
auto metaGraph = schema::GetMetaGraph(data_ptr.release());
return ImportMs(metaGraph);
}
Net *Import::ImportMs(const schema::MetaGraph *metaGraph) {
if (metaGraph == nullptr) {
MS_LOG(ERROR) << "null input";
return nullptr;
}
std::string NetName = "Network";
if (metaGraph->name() != nullptr) NetName = metaGraph->name()->str();
auto net = std::make_unique<Net>(NetName);
std::unordered_map<int, EXPR *> outputs;
// save inputs
for (size_t i = 0; i < metaGraph->inputIndex()->size(); i++) {
auto tensor_id = metaGraph->inputIndex()->Get(i);
const schema::Tensor *tensor = metaGraph->allTensors()->Get(tensor_id);
auto input = new (std::nothrow) InputM(tensor);
if (input == nullptr) {
MS_LOG(ERROR) << "Cannot allocate input";
return nullptr;
}
auto e = input->expr();
outputs[tensor_id] = e;
net->PushInput(e);
}
for (size_t i = 0; i < metaGraph->nodes()->size(); i++) {
auto Cnode = metaGraph->nodes()->Get(i);
std::vector<EXPR *> param_tensors;
for (size_t j = 0; j < Cnode->inputIndex()->size(); j++) {
int tensor_id = Cnode->inputIndex()->Get(j);
const schema::Tensor *tensor = metaGraph->allTensors()->Get(tensor_id);
auto iter = outputs.find(tensor_id);
if (iter == outputs.end()) {
// create value node if not exist
if (tensor->nodeType() != NodeType::NodeType_CNode) {
auto valnode = new (std::nothrow) InputM(tensor);
if (valnode == nullptr) {
MS_LOG(ERROR) << "Cannot allocate valnode";
return nullptr;
}
outputs[tensor_id] = valnode->expr();
param_tensors.push_back(valnode->expr());
net->PushOp(valnode);
} else {
MS_LOG(ERROR) << "did not found input tensor " << tensor_id;
return nullptr;
}
} else {
param_tensors.push_back(iter->second);
}
}
// create expression from node //
auto node = CreateNode(Cnode);
if (node != nullptr) {
node->SetOutputs(Cnode->outputIndex()->size());
std::vector<EXPR *> e = (*node)(param_tensors);
for (size_t j = 0; j < Cnode->outputIndex()->size(); j++) {
int tensor_id = Cnode->outputIndex()->Get(j);
outputs[tensor_id] = e.at(j);
}
} else {
MS_LOG(ERROR) << "failed to create node " << Cnode->name();
return nullptr;
}
auto node_ptr = node.release();
net->PushOp(node_ptr);
node_ptr->SetLearn();
}
for (size_t i = 0; i < metaGraph->outputIndex()->size(); i++) {
auto tensor_id = metaGraph->outputIndex()->Get(i);
auto iter = outputs.find(tensor_id);
if (iter == outputs.end()) {
MS_LOG(ERROR) << "could not find source for tensor " << tensor_id;
return nullptr;
} else {
net->PushOutput(iter->second);
}
}
return net.release();
}
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,61 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_IMPORT_H_
#define MINDSPORE_LITE_SRC_EXPRESSION_IMPORT_H_
#include <string>
#include <unordered_map>
#include <memory>
#include "nnacl/op_base.h"
#include "src/expression/net.h"
namespace mindspore {
namespace lite {
using import_func = std::function<Node *()>;
template <typename T>
Node *ReturnNode() {
return new (std::nothrow) T();
}
class ImportReg {
public:
explicit ImportReg(mindspore::schema::PrimitiveType type, import_func func) { import_map_[type] = func; }
static import_func GetImportFunc(mindspore::schema::PrimitiveType type);
private:
static std::unordered_map<mindspore::schema::PrimitiveType, import_func> import_map_;
};
class Import {
private:
int8_t *buffer_ = nullptr;
OpParameter *GetAttr(const schema::Primitive *prim);
std::unique_ptr<Node> CreateNode(const schema::CNode *cnode);
public:
Net *ImportMs(const schema::MetaGraph *meta_graph);
Net *ImportMs(std::string file_name);
~Import() {
delete[] buffer_;
buffer_ = nullptr;
}
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_EXPRESSION_IMPORT_H_

View File

@ -0,0 +1,268 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/expression/net.h"
#include <vector>
#include "src/cxx_api/expression/net_impl.h"
#include "src/expression/ops.h"
#include "src/expression/export.h"
#include "src/expression/ops/addn.h"
#include "src/expression/ops/arithmetic.h"
#include "src/common/storage.h"
#include "tools/common/meta_graph_serializer.h"
namespace mindspore {
namespace lite {
void Net::update_name(std::string name) {
if (!this->name().empty())
Node::update_name(name);
else
set_name(name);
for (auto &itr : ops_) {
itr->update_name(name);
}
}
std::vector<EXPR *> Net::operator()(const std::initializer_list<EXPR *> &&inputs) {
std::vector<EXPR *> vec = inputs;
std::vector<EXPR *> x;
if (impl_ == nullptr) {
x = construct(inputs);
} else {
x = impl_->construct(vec);
}
return x;
}
std::vector<EXPR *> Net::operator()(const std::vector<EXPR *> &inputs) {
std::vector<EXPR *> x;
if (impl_ == nullptr) {
x = construct(inputs);
} else {
x = impl_->construct(inputs);
}
input_ = inputs;
output_ = x;
real_output_ = x;
return x;
}
std::vector<EXPR *> Net::construct(const std::vector<EXPR *> &inputs) {
if (!output_.empty()) {
if (input_.size() != inputs.size()) {
MS_LOG(ERROR) << "input size mismatch, should be " << input_.size() << " got " << inputs.size();
return {};
}
auto in_ptr = inputs;
EXPR::Replace(output_, &input_, &in_ptr);
} else {
MS_LOG(ERROR) << "no network construction function";
}
return output_;
}
void Net::TopoSortUtil(Node *node, std::stack<Node *> *stack) {
visited_.insert(node);
for (size_t i = 0; i < node->OutputsNum(); i++) {
auto expr = node->expr(i);
auto itr = outmap_.find(expr);
if (itr != outmap_.end()) {
for (auto &e : itr->second)
if (visited_.find(e->node()) == visited_.end()) {
TopoSortUtil(e->node(), stack);
}
}
}
stack->push(node);
}
std::vector<Node *> Net::Sort() {
std::stack<Node *> stack;
outmap_.clear();
EXPR::CreateOutputMap(output_, &outmap_);
for (auto &itr : outmap_) {
EXPR *e = itr.first;
if (visited_.find(e->node()) == visited_.end()) {
TopoSortUtil(e->node(), &stack);
}
}
std::vector<Node *> res;
while (stack.empty() == false) {
res.push_back(stack.top());
stack.pop();
}
visited_.clear();
return res;
}
std::unique_ptr<schema::MetaGraphT> Net::MakeMs() {
auto nodes = Sort();
auto s = new (std::nothrow) ExportSession(outmap_);
if (s == nullptr) {
MS_LOG(ERROR) << "Cannot allocate export session";
return nullptr;
}
session_.reset(s);
session_->Init(name(), Version());
for (auto node : nodes) {
auto res = node->MakeEntry(session_.get());
if (res != RET_OK) {
MS_LOG(ERROR) << "failed in MakeEntry: " << node->name();
return nullptr;
}
}
session_->SetInputOutput(input_, real_output_);
auto res = session_->meta_graph();
return std::unique_ptr<schema::MetaGraphT>(res);
}
std::unique_ptr<schema::MetaGraphT> Net::MakeMs(const std::string file_name) {
auto graph = MakeMs();
Save(*graph, file_name);
return graph;
}
std::set<Node *> Net::trainable_params() {
std::set<Node *> res;
for (auto &node : ops_) {
res.merge(node->trainable_params());
}
return res;
}
int Net::BuildGrad(Node *optimizer) {
std::set<Node *> learn = optimizer->trainable_params();
auto NetOrder = Sort();
optimizer_.reset(optimizer);
optimizer->AddNetOutput(&output_);
std::map<std::pair<EXPR *, EXPR *>, EXPR *> backprop;
for (auto itr = NetOrder.rbegin(); itr != NetOrder.rend(); itr++) {
Node *node = *itr;
EXPR *yt = nullptr;
if (node->primitive() == schema::PrimitiveType_NONE) continue;
if (outmap_.find(node->expr()) == outmap_.end() || outmap_[node->expr()].size() == 0) {
yt = node->expr();
} else {
std::vector<EXPR *> add_params;
for (auto &output : outmap_[node->expr()]) {
auto link = std::make_pair(node->expr(), output);
auto grad = backprop[link];
add_params.push_back(grad);
}
if (add_params.size() == 1) {
yt = add_params.front();
} else {
auto addn = new (std::nothrow) AddN(0);
if (addn == nullptr) {
MS_LOG(ERROR) << "Cannot allocate add operator";
return RET_ERROR;
}
PushOp(addn);
addn->update_name(name());
yt = (*addn)(add_params).front();
}
}
auto inGrads = node->Grad(yt);
for (size_t i = 0; i < node->inputs().size(); i++) {
EXPR *inGrad{nullptr};
if (i < inGrads.size()) {
inGrad = inGrads[i];
} else {
inGrad = nullptr;
}
auto input = node->input(i);
if (learn.find(input->node()) != learn.end()) {
auto opt = optimizer->Clone(inGrad, input);
if (opt.size() == 0) {
MS_LOG(ERROR) << "failed to create optimizer";
return RET_ERROR;
}
if (inGrad == nullptr) {
MS_LOG(ERROR) << "illegal null value for grad";
return RET_ERROR;
}
if (opt.size() == 0) {
MS_LOG(ERROR) << "optimizer for " << input->node()->name() << " failure";
return RET_ERROR;
}
auto opt_op = opt.at(0)->node();
PushOp(opt_op);
opt_op->update_name(node->name());
output_.push_back(opt.at(0));
}
auto link = std::make_pair(input, node->expr());
backprop[link] = inGrad;
}
}
return RET_OK;
}
std::vector<EXPR *> Net::add(const std::vector<EXPR *> &input) {
auto _add = NN::Add();
_add->set_name(name() + "/" + _add->name());
ops_.push_back(_add);
return (*_add)(input);
}
Net *Net::TrainNet(Node *optimizer, Node *loss_fn, const std::vector<EXPR *> &inputs) {
auto net = new (std::nothrow) NetWithLoss(this, loss_fn);
if (net == nullptr) {
MS_LOG(ERROR) << "Cannot allocate loss network";
return nullptr;
}
return net->TrainNet(optimizer, inputs);
}
Net *Net::TrainNet(Node *optimizer, const std::vector<EXPR *> &inputs) {
auto x = (*this)(inputs);
auto res = BuildGrad(optimizer);
if (res != RET_OK) {
MS_LOG(ERROR) << "Build gradient network failed";
return nullptr;
}
real_output_ = x;
return this;
}
int Net::Save(const schema::MetaGraphT &graph, std::string file_name) { return Storage::Save(graph, file_name); }
const std::vector<int> Net::OutputShape(int idx) {
if (static_cast<size_t>(idx) >= real_output_.size()) {
MS_LOG(ERROR) << "index (" << idx << ") exceed output size (" << real_output_.size() << ")";
return {};
}
return real_output_.at(idx)->dims();
}
const std::vector<int> Net::InputShape(int idx) {
if (static_cast<size_t>(idx) >= input_.size()) {
MS_LOG(ERROR) << "index (" << idx << ") exceed input size (" << input_.size() << ")";
return {};
}
return input_.at(idx)->dims();
}
Net::~Net() {
if (impl_ != nullptr) {
impl_->erase_net();
auto pnet = impl_->pnet();
if (pnet != nullptr) {
impl_->set_pnet(nullptr);
}
}
impl_ = nullptr;
}
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,114 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_NET_H_
#define MINDSPORE_LITE_SRC_EXPRESSION_NET_H_
#include <stack>
#include <memory>
#include <set>
#include <map>
#include <utility>
#include <string>
#include <unordered_set>
#include <list>
#include <vector>
#include "src/expression/node.h"
#include "inner/model_generated.h"
namespace mindspore {
class Net;
class NetImpl;
namespace lite {
#define REG(_name) Register(_name, #_name)
class ExportSession;
class Net : public Node {
public:
Net() = default;
virtual ~Net();
explicit Net(std::string name) : Node(name) {}
std::vector<EXPR *> construct(const std::vector<EXPR *> &inputs) override;
std::vector<EXPR *> operator()(const std::vector<EXPR *> &inputs) override;
std::vector<EXPR *> operator()(const std::initializer_list<EXPR *> &&inputs) override;
void update_name(std::string name) override;
Net *TrainNet(Node *optimizer, Node *loss_fn, const std::vector<EXPR *> &inputs);
Net *TrainNet(Node *optimizer, const std::vector<EXPR *> &inputs);
void PrintDot() { EXPR::PrintDot(output_); }
void PushOutput(EXPR *e) { output_.push_back(e); }
void PushInput(EXPR *e) { input_.push_back(e); }
void SetRealOutput() { real_output_ = output_; }
std::set<Node *> trainable_params() override;
std::vector<Node *> Sort();
int BuildGrad(Node *optimizer);
int BuildGrad(Node *optimizer, std::set<Node *> learnable);
std::unique_ptr<schema::MetaGraphT> MakeMs();
std::unique_ptr<schema::MetaGraphT> MakeMs(std::string file_name);
schema::MetaGraph *meta_graph() { return meta_graph_; }
int Save(const schema::MetaGraphT &graph, const std::string filename);
void set_impl(std::shared_ptr<mindspore::NetImpl> impl) { impl_ = impl; }
const std::vector<int> InputShape(int idx);
const std::vector<int> OutputShape(int idx);
protected:
std::vector<EXPR *> add(const std::vector<EXPR *> &input);
void Register(Node *node, std::string &&name) {
if (node != nullptr) {
PushOp(node);
node->update_name(name);
}
}
private:
friend mindspore::Net;
std::unordered_set<Node *> visited_;
std::map<EXPR *, std::list<EXPR *>> outmap_; // outputs per expression
std::map<EXPR *, std::list<EXPR *>> inmap_; // inputs per expression
std::vector<EXPR *> output_; // network output expression
std::vector<EXPR *> real_output_; // network output for export
std::vector<EXPR *> input_; // network input expression
schema::MetaGraph *meta_graph_; // imported meta_graph
std::unique_ptr<ExportSession> session_; // export session
std::unique_ptr<Node> optimizer_;
void TopoSortUtil(Node *v, std::stack<Node *> *stack);
void CreateOutputMap(std::vector<EXPR *> vec, std::map<Node *, std::list<Node *>> *outmap);
std::shared_ptr<mindspore::NetImpl> impl_;
};
class NetWithLoss : public Net {
public:
NetWithLoss(Net *net, Node *loss) : net_(net), loss_fn_(loss) {
REG(net_);
REG(loss_fn_);
loss_fn_->set_name("_loss_fn");
}
std::vector<EXPR *> construct(const std::vector<EXPR *> &inputs) {
auto input = inputs[0];
auto label = inputs[1];
auto x = (*net_)({input});
x = (*loss_fn_)({x[0], label});
return {x.front()};
}
private:
Net *net_{nullptr};
Node *loss_fn_{nullptr};
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_EXPRESSION_NET_H_

View File

@ -0,0 +1,271 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <algorithm>
#include <utility>
#include <functional>
#include "src/expression/node.h"
#include "src/expression/ops.h"
#include "src/expression/export.h"
#include "src/runtime/infer_manager.h"
#include "src/common/utils.h"
#include "src/cxx_api/expression/net_impl.h"
namespace mindspore {
namespace lite {
int Node::name_id;
std::vector<EXPR *> Node::construct(const std::vector<EXPR *> &inputs) {
if (inputs.size() >= expr()->params().size()) {
expr()->set_params(inputs);
} else {
for (std::size_t i = 0; i < inputs.size(); i++) {
expr()->set_params(i, inputs[i]);
}
}
auto ret = InferShape();
if (ret != RET_OK) {
MS_LOG(ERROR) << "error infershape for node " << name();
return {};
}
std::vector<EXPR *> res(expr_.size());
(void)std::transform(expr_.begin(), expr_.end(), res.begin(), [](const EXPR &e) { return const_cast<EXPR *>(&e); });
return res;
}
std::vector<EXPR *> Node::Grad(EXPR *expr) {
MS_LOG(ERROR) << name() << " (" << schema::EnumNamePrimitiveType(primitive()) << ") does not have grad defined";
return {};
}
int Node::CreateTensorFromExpr(const std::vector<EXPR *> &expr, std::vector<Tensor *> *tensors, bool is_input) {
MS_ASSERT(tensors != nullptr);
int ret = RET_OK;
for (auto e : expr) {
// Tensor -> TensorC
if (is_input && e->node()->primitive() == schema::PrimitiveType_Depend) {
continue;
}
auto type = (e->node()->primitive() != schema::PrimitiveType_NONE) ? Category::VAR : Category::CONST_TENSOR;
auto t = std::make_unique<Tensor>(e->data_type(), e->dims(), (mindspore::Format)e->format(), type);
if (t == nullptr) {
ret = RET_NULL_PTR;
break;
}
// copy data if any
if (type == Category::CONST_TENSOR) {
void *dst = t->MutableData();
if (dst == nullptr) {
ret = RET_NULL_PTR;
break;
}
if (e->node()->data() && (e->node()->data()->data().size() > 0)) {
uint8_t *src = e->node()->data()->data().data();
memcpy(dst, src, t->Size());
}
}
tensors->push_back(t.release());
}
return ret;
}
void Node::FreeAllTensors(std::vector<Tensor *> *tensors) {
MS_ASSERT(tensors != nullptr);
for (auto &t : *tensors) {
delete t;
}
tensors->clear();
}
int Node::InferShape() {
auto ret = RET_OK;
std::vector<Tensor *> in_tensors;
std::vector<Tensor *> out_tensors;
// build in \ out tensors
ret = CreateTensorFromExpr(expr()->params(), &in_tensors, true);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Failed in create in tensors";
FreeAllTensors(&in_tensors);
return RET_ERROR;
}
std::vector<EXPR *> expr(expr_.size());
(void)std::transform(expr_.begin(), expr_.end(), expr.begin(), [](const EXPR &e) { return const_cast<EXPR *>(&e); });
ret = CreateTensorFromExpr(expr, &out_tensors);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Failed in create out tensors";
FreeAllTensors(&in_tensors);
FreeAllTensors(&out_tensors);
return RET_ERROR;
}
// Do infer Shape
ret = KernelInferShape(in_tensors, out_tensors, OpParam());
if (ret != RET_OK) {
MS_LOG(ERROR) << "failed in infer shape for " << name();
FreeAllTensors(&in_tensors);
FreeAllTensors(&out_tensors);
return RET_ERROR;
}
// copy infer shape into expr
for (uint32_t i = 0; i < expr_.size(); i++) {
auto e = &expr_.at(i);
auto o = out_tensors.at(i);
e->set_format((o->format()));
e->set_data_type(o->data_type());
e->SetDims(o->shape());
}
// cleanup
FreeAllTensors(&in_tensors);
FreeAllTensors(&out_tensors);
return ret;
}
EXPR *Node::CreateWeights(std::vector<int> dims, TypeId data_type, int format, Param::Mode mode, std::string name) {
auto weights = new (std::nothrow) InputM(dims);
if (weights == nullptr) {
MS_LOG(ERROR) << "Cannot allocate weights";
return nullptr;
}
weights->set_name(this->name() + "/" + name);
int size = std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<int>());
weights->data()->SetSize(size);
weights->data()->Fill(mode);
PushOp(weights);
return weights->expr();
}
Node *Node::CreateConstTensor(int index, std::vector<int> dims, TypeId data_type, int format, std::string name,
const void *data) {
auto tensor = NN::Input(dims, data_type, format);
int elem_size = DataTypeSize(data_type);
tensor->set_name(this->name() + "/" + name);
int size = std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<int>()) * elem_size;
tensor->data()->SetSize(size);
tensor->data()->Copy(reinterpret_cast<const uint8_t *>(data), size);
expr()->set_params(index, tensor->expr());
PushOp(tensor);
return tensor;
}
int Node::MakeEntry(ExportSession *session) {
std::vector<uint32_t> input_idx;
std::vector<uint32_t> output_idx;
std::vector<uint8_t> empty;
if (primitive() == schema::PrimitiveType_Depend) return RET_OK;
// create node input
size_t inputs = InputsNum();
for (size_t i = 0; i < inputs; i++) {
EXPR *ex = expr()->GetInput(i);
if (ex->node()->primitive() == schema::PrimitiveType_Depend) continue;
uint32_t id = session->GetOutput(ex);
input_idx.push_back(id);
}
size_t outputs = OutputsNum();
size_t last_id = session->meta_graph()->allTensors.size();
int type = (primitive() == schema::PrimitiveType_NONE) ? NodeType_ValueNode : NodeType_CNode;
auto data = (type == NodeType_ValueNode) ? this->data()->data() : empty;
if (data.empty()) type = NodeType_CNode; // input is Cnode !!?
int idx = 0;
for (size_t i = 0; i < outputs; i++) {
if (session->IsToDependOnly(expr(i))) continue;
output_idx.push_back(last_id + idx);
session->UpdateOutput(expr(i), last_id + idx);
auto odims = dims(i);
auto data_type = expr(i)->data_type();
auto format = expr(i)->format();
std::string footer = (i > 0) ? ("-" + std::to_string(i)) : "";
auto otensor = CreateTensor(name() + footer, type, data_type, odims, format, data);
std::cout << "tensor -" << last_id + idx << ": " << name() + footer << std::endl;
idx++;
session->meta_graph()->allTensors.emplace_back(std::move(otensor));
}
if (primitive() != schema::PrimitiveType_NONE) {
if (output_idx.size() == 0) {
return RET_OK;
}
auto cnode = CreateCNode(input_idx, output_idx);
auto ret = UnPopulate(cnode);
if (ret != RET_OK) {
MS_LOG(ERROR) << "failed to populate cnode";
return RET_ERROR;
}
session->meta_graph()->nodes.emplace_back(std::move(cnode));
}
return RET_OK;
}
std::unique_ptr<schema::CNodeT> Node::CreateCNode(std::vector<uint32_t> inputIndex, std::vector<uint32_t> outputIndex) {
auto cnode = std::make_unique<schema::CNodeT>();
cnode->primitive = std::make_unique<schema::PrimitiveT>();
cnode->primitive->value.type = primitive();
cnode->name = name();
cnode->inputIndex = inputIndex;
cnode->outputIndex = outputIndex;
return cnode;
}
int Node::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
MS_LOG(ERROR) << "Node " << schema::EnumNamePrimitiveType(primitive()) << " cannot be exported";
return RET_ERROR;
}
std::unique_ptr<mindspore::schema::TensorT> Node::CreateTensor(std::string name, int type, int data_type,
const std::vector<int32_t> dims, int format,
const std::vector<uint8_t> &data) {
auto tensorT = std::make_unique<mindspore::schema::TensorT>();
tensorT->nodeType = type;
tensorT->dims = dims;
tensorT->format = static_cast<schema::Format>(format);
tensorT->name = name;
tensorT->refCount = 0;
tensorT->offset = 0;
tensorT->dataType = data_type;
tensorT->data = data;
tensorT->enableHuffmanCode = false;
if (tensorT->nodeType == mindspore::lite::NodeType_ValueNode) {
tensorT->data = data;
}
return tensorT;
}
int Node::SetOutputs(int num) {
EXPR e(this);
e.SetSize(0);
for (auto i = expr_.size(); i < static_cast<size_t>(num); i++) {
expr_.emplace_back(e);
}
return RET_OK;
}
Node::~Node() {
for (auto &op : ops_) {
delete op;
}
ops_.clear();
if (impl_ != nullptr) {
impl_->set_node(nullptr);
auto pnode = impl_->pnode();
if (pnode != nullptr) {
impl_->set_pnode(nullptr);
delete pnode;
}
}
impl_ = nullptr;
}
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,156 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_NODE_H_
#define MINDSPORE_LITE_SRC_EXPRESSION_NODE_H_
#include <stdlib.h>
#include <vector>
#include <string>
#include <iostream>
#include <memory>
#include <set>
#include "src/expression/export.h"
#include "inner/model_generated.h"
#include "src/expression/param.h"
#include "src/expression/expr.h"
#include "src/tensor.h"
#include "nnacl/op_base.h"
namespace mindspore {
class NodeImpl;
namespace schema {
struct TensorT;
struct CNodeT;
} // namespace schema
namespace lite {
class Node {
public:
const std::string kGradName = "Gradients";
explicit Node(const std::string name) : opParam_(nullptr), name_(name) { expr_.emplace_back(this); }
virtual ~Node();
Node() : Node("") {}
explicit Node(Node *node) : Node(*node) {}
EXPR *create(std::string name) {
name_ = name;
return &expr_[0];
}
virtual std::vector<EXPR *> operator()(const std::vector<EXPR *> &inputs) {
auto x = construct(inputs);
return x;
}
virtual std::vector<EXPR *> operator()(const std::initializer_list<EXPR *> &&inputs) {
std::vector<EXPR *> vec = inputs;
auto x = construct(vec);
return x;
}
virtual std::vector<EXPR *> operator()(const std::initializer_list<EXPR *> &inputs) {
std::vector<EXPR *> vec = inputs;
auto x = construct(vec);
return x;
}
void set_primitive(schema::PrimitiveType primitive) {
primitive_ = primitive;
if (OpParam() != nullptr) opParam_->type_ = primitive_;
}
schema::PrimitiveType primitive() { return primitive_; }
virtual std::vector<EXPR *> construct(const std::vector<EXPR *> &inputs);
std::string name() { return name_; }
void set_name(std::string name) { name_ = name; }
virtual void update_name(std::string name) { set_name(name + "/" + name_); }
size_t Load(std::string file_name, size_t offset = 0) { return offset; }
OpParameter *OpParam() const { return opParam_.get(); }
virtual void Add(Node *node) {}
virtual std::vector<EXPR *> Clone(EXPR *grad, EXPR *weight) { return {}; }
void SetOpParam(std::shared_ptr<OpParameter> opParam) { opParam_ = opParam; }
void SetOpParam(void *opParam) { opParam_.reset(reinterpret_cast<OpParameter *>(opParam), free); }
static std::string UniqueName(const std::string &name) { return name + "-" + std::to_string(name_id++); }
static std::string UniqueName(std::string &&name) { return name + "-" + std::to_string(name_id++); }
template <typename T>
int CloneOpParam(std::shared_ptr<OpParameter> opParam) {
auto t = reinterpret_cast<T *>(opParam.get());
auto obj = new (std::nothrow) T(*t); // copy content
if (obj == nullptr) {
MS_LOG(ERROR) << "Cannot allocate obj";
return RET_ERROR;
}
opParam_.reset(reinterpret_cast<OpParameter *>(obj));
return RET_OK;
}
template <typename T>
int CloneOpParam(OpParameter *opParam) {
auto t = reinterpret_cast<T *>(opParam);
auto obj = new (std::nothrow) T(*t); // copy content
if (obj == nullptr) {
MS_LOG(ERROR) << "Cannot allocate obj";
return RET_ERROR;
}
opParam_.reset(reinterpret_cast<OpParameter *>(obj));
return RET_OK;
}
virtual Param *weight() { return nullptr; }
EXPR *expr(int i) { return &expr_[i]; }
EXPR *expr() { return expr(0); }
std::vector<EXPR *> inputs() { return expr()[0].params(); }
size_t InputsNum() { return expr()[0].params().size(); }
size_t OutputsNum() { return expr_.size(); }
EXPR *input(int idx) { return expr()[0].params().at(idx); }
EXPR *output(int idx) { return expr(idx); }
EXPR *CreateWeights(std::vector<int> dims, TypeId data_type, int format, Param::Mode mode, std::string name);
Node *CreateConstTensor(int index, std::vector<int> dims, TypeId data_type, int format, std::string name,
const void *data);
virtual std::vector<EXPR *> Grad(EXPR *expr);
virtual Param *data() { return nullptr; }
bool IsLearn(Node *node) { return learnable_.find(node) != learnable_.end(); }
virtual void SetLearn() {}
virtual std::set<Node *> trainable_params() { return learnable_; }
std::vector<int> &dims() { return expr()->dims(); }
std::vector<int> &dims(int i) { return expr(i)->dims(); }
// export
int MakeEntry(ExportSession *session);
void PushOp(Node *n) { ops_.push_back(n); }
virtual void AddNetOutput(std::vector<EXPR *> *output) {}
int SetOutputs(int num);
std::shared_ptr<OpParameter> opParam_;
void set_impl(std::shared_ptr<NodeImpl> impl) { impl_ = impl; }
protected:
std::vector<EXPR> expr_; // hold outputs
std::vector<Node *> ops_; // all nodes or subnets
int InferShape();
void AddLearn(Node *node) { learnable_.insert(node); }
void AssignLearn(std::set<Node *> &&learn) { learnable_ = learn; }
std::unique_ptr<schema::CNodeT> CreateCNode(std::vector<uint32_t> inputIndex, std::vector<uint32_t> outputIndex);
virtual int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode);
std::unique_ptr<schema::TensorT> CreateTensor(std::string name, int type, int data_type,
const std::vector<int32_t> dims, int format,
const std::vector<uint8_t> &data);
private:
int CreateTensorFromExpr(const std::vector<EXPR *> &expr, std::vector<Tensor *> *tensors, bool is_input = false);
void FreeAllTensors(std::vector<Tensor *> *tensors);
static int name_id;
std::set<Node *> learnable_; // set of nodes with learnable parameters
std::string name_;
schema::PrimitiveType primitive_;
std::shared_ptr<NodeImpl> impl_;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_EXPRESSION_NODE_H_

View File

@ -0,0 +1,66 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <numeric>
#include <algorithm>
#include "src/expression/ops.h"
#include "src/expression/ops_utils.h"
#include "src/expression/param.h"
#include "include/api/cfg.h"
#include "src/expression/sequential.h"
namespace mindspore {
namespace lite {
void InputM::SetUp(const std::vector<int> &dims, TypeId data_type, int fmt) {
expr()->SetSize(C0NUM);
expr()->SetDims(dims);
expr()->set_data_type(data_type);
expr()->set_format(fmt);
set_primitive(schema::PrimitiveType_NONE);
}
InputM::InputM(const std::vector<int> &dims, TypeId data_type, int fmt) : Node() { SetUp(dims, data_type, fmt); }
InputM::InputM(const schema::Tensor *tensor) : Node() {
std::vector<int> dims(tensor->dims()->size());
(void)std::transform(tensor->dims()->begin(), tensor->dims()->end(), dims.begin(), [](int32_t x) { return x; });
SetUp(dims, static_cast<TypeId>(tensor->dataType()), tensor->format());
if (tensor->name()) set_name(tensor->name()->str());
if (tensor->data() != nullptr) data_.Copy(tensor->data()->data(), tensor->data()->size());
}
namespace NN {
Node *Input(const std::vector<int> &dims, TypeId data_type, int fmt) {
auto i = new (std::nothrow) InputM(dims, data_type, fmt);
if (i == nullptr) {
MS_LOG(ERROR) << "Cannot allocate input expression ";
return nullptr;
}
return i;
}
Net *Sequential() {
auto s = new (std::nothrow) mindspore::lite::Sequential();
if (s == nullptr) {
MS_LOG(ERROR) << "Cannot allocate sequential expression";
return nullptr;
}
return s;
}
}; // namespace NN
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,69 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_H_
#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_H_
#include <vector>
#include <string>
#include <set>
#include "include/api/net.h"
#include "src/expression/cfg.h"
#include "src/expression/net.h"
#include "inner/model_generated.h"
namespace mindspore {
namespace lite {
class InputM : public Node {
public:
explicit InputM(const schema::Tensor *tensor);
explicit InputM(const std::vector<int> &dims, TypeId data_type = kNumberTypeFloat32, int fmt = NHWC);
Param *data() override { return &data_; }
private:
void SetUp(const std::vector<int> &dims, TypeId data_type, int fmt);
Param data_;
};
namespace NN {
Node *Conv2D(const ConvConfig &cfg);
Node *Relu();
Node *Dense(const DenseConfig &cfg);
Node *Flatten();
Node *Input(const std::vector<int> &dims, TypeId data_type = kNumberTypeFloat32, int fmt = NHWC);
Node *Add();
Node *Sub();
Node *Div();
Node *Mul();
Node *Neg();
Node *SoftmaxCrossEntropy();
Net *Sequential();
Node *Adam(std::set<Node *> &&learn, const AdamConfig &cfg);
Node *Softmax(int axis = -1);
Node *BatchNorm2D(int outp, float momentum = 0.1, float epsilon = 1e-5f);
Node *Sigmoid();
Node *DropOut(float ration = 0.5);
Node *ReLU6();
Node *Reshape(const std::vector<int> &shape);
Node *ReduceMean(bool keep_dims, const std::vector<int> &dims);
Node *ReduceSum(bool keep_dims, const std::vector<int> &dims);
Node *Tile(const std::vector<int> &multiples);
Node *MaxPool2D(const PoolingConfig &cfg);
Node *AvgPool2D(const PoolingConfig &cfg);
} // namespace NN
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_H_

View File

@ -0,0 +1,133 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/expression/ops/activation.h"
#include "nnacl/fp32/activation_fp32.h"
#include "src/expression/import.h"
#include "src/expression/ops.h"
#include "src/cxx_api/expression/node_impl.h"
namespace mindspore {
namespace lite {
ActM::ActM(schema::ActivationType type) : Node() {
auto op_param = malloc(sizeof(ActivationParameter));
if (op_param == nullptr) {
MS_LOG(ERROR) << " cannot allocate ActivationParameter";
return;
}
SetOpParam(op_param);
set_primitive(schema::PrimitiveType_Activation);
ActivationParameter *act_param = reinterpret_cast<ActivationParameter *>(opParam_.get());
act_param->type_ = type;
act_param->alpha_ = 0.f;
act_param->min_val_ = 0.f;
act_param->max_val_ = 0.f;
}
std::vector<EXPR *> ActM::Grad(EXPR *yt) {
auto actGrad = new (std::nothrow) ActGradM(this);
if (actGrad == nullptr) {
MS_LOG(ERROR) << "Cannot allocate activation grad node";
return {};
}
PushOp(actGrad);
auto param = reinterpret_cast<ActivationParameter *>(actGrad->OpParam());
EXPR *ag = nullptr;
actGrad->expr()->SetSize(C2NUM);
if ((param->type_ == schema::ActivationType_SIGMOID) || (param->type_ == schema::ActivationType_TANH)) {
ag = (*actGrad)({output(0), yt}).front();
} else if ((param->type_ == schema::ActivationType_HSWISH) || (param->type_ == schema::ActivationType_HSIGMOID) ||
(param->type_ == schema::ActivationType_RELU6)) {
ag = (*actGrad)({yt, input(0)}).front();
} else if (param->type_ == schema::ActivationType_GELU) {
actGrad->expr()->SetSize(C3NUM);
ag = (*actGrad)({yt, input(0), output(0)}).front();
} else {
ag = (*actGrad)({yt, output(0)}).front();
}
std::vector<EXPR *> res = {ag};
return res;
}
int ActM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
auto act_param = reinterpret_cast<const ActivationParameter *>(OpParam());
auto prim = new (std::nothrow) schema::ActivationT;
if (prim == nullptr) {
MS_LOG(ERROR) << "Cannot allocate activation primitive";
return RET_ERROR;
}
prim->activation_type = static_cast<decltype(prim->activation_type)>(act_param->type_);
prim->alpha = act_param->alpha_;
prim->min_val = act_param->min_val_;
prim->max_val = act_param->max_val_;
cnode->primitive->value.value = prim;
return RET_OK;
}
static ImportReg reg(schema::PrimitiveType_Activation, ReturnNode<ActM>);
ActGradM::ActGradM(Node *node) {
CloneOpParam<ActivationParameter>(node->OpParam());
set_primitive(schema::PrimitiveType_ActivationGrad);
set_name(node->name() + "/" + kGradName + "/actGrad");
}
std::vector<EXPR *> ActGradM::Grad(EXPR *yt) { return {}; }
int ActGradM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
auto act_param = reinterpret_cast<const ActivationParameter *>(OpParam());
auto prim = new (std::nothrow) schema::ActivationGradT;
if (prim == nullptr) {
MS_LOG(ERROR) << "Cannot allocate activation grad primitive";
return RET_ERROR;
}
prim->activation_type = static_cast<decltype(prim->activation_type)>(act_param->type_);
prim->alpha = act_param->alpha_;
cnode->primitive->value.value = prim;
return RET_OK;
}
static ImportReg regGrad(schema::PrimitiveType_ActivationGrad, ReturnNode<ActGradM>);
namespace NN {
Node *ReLU6() {
auto r = new (std::nothrow) ActM(schema::ActivationType_RELU6);
if (r == nullptr) {
MS_LOG(ERROR) << "Cannot allocate relu6";
return nullptr;
}
r->set_name(Node::UniqueName("ReLU6"));
return r;
}
Node *Sigmoid() {
auto s = new (std::nothrow) ActM(schema::ActivationType_SIGMOID);
if (s == nullptr) {
MS_LOG(ERROR) << "Cannot allocate sigmoid";
return nullptr;
}
s->set_name(Node::UniqueName("Sigmoid"));
return s;
}
Node *Relu() {
auto r = new (std::nothrow) ActM(schema::ActivationType_RELU);
if (r == nullptr) {
MS_LOG(ERROR) << "Cannot allocate relu";
return nullptr;
}
r->set_name(r->UniqueName("Relu"));
return r;
}
} // namespace NN
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,44 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_ACTIVATION_H_
#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_ACTIVATION_H_
#include <vector>
#include <memory>
#include "src/expression/net.h"
#include "inner/model_generated.h"
namespace mindspore {
namespace lite {
class ActM : public Node {
public:
ActM() = default;
explicit ActM(schema::ActivationType type);
std::vector<EXPR *> Grad(EXPR *expr) override;
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
};
class ActGradM : public Node {
public:
ActGradM() : Node() {} // for Import
explicit ActGradM(Node *act); // for Grad
std::vector<EXPR *> Grad(EXPR *expr) override;
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_ACTIVATION_H_

View File

@ -0,0 +1,142 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/expression/ops/adam.h"
#include <memory>
#include <set>
#include <utility>
#include "src/expression/ops.h"
#include "src/expression/ops/assign.h"
#include "src/expression/ops/arithmetic.h"
#include "nnacl/fp32_grad/optimizer.h"
#include "include/api/net.h"
#include "src/cxx_api/expression/node_impl.h"
namespace mindspore {
namespace NN {
Node *Adam(std::shared_ptr<NodeSet> learn, const AdamConfig &cfg) {
auto lite_node = lite::NN::Adam(std::move(learn->set_), cfg);
return NodeImpl::Connect(lite_node);
}
} // namespace NN
namespace lite {
std::vector<EXPR *> AdamM::Clone(EXPR *grad, EXPR *weight) {
auto adam = new (std::nothrow) AdamM();
if (adam == nullptr) {
MS_LOG(ERROR) << "Cannot allocate adam";
return {};
}
adam->set_name("optimizer-Adam");
adam->CloneOpParam<AdamParameter>(OpParam());
adam->update_name(weight->node()->name());
adam->set_primitive(primitive());
adam->expr()->SetSize(C10NUM);
// setup weight and momentum
adam->expr()->set_params(C0NUM, weight);
auto dims = grad->dims();
auto m = adam->CreateWeights(dims, kNumberTypeFloat32, KHWC, Param::ZEROS, "m");
adam->expr()->set_params(C1NUM, m);
auto v = adam->CreateWeights(dims, kNumberTypeFloat32, KHWC, Param::ZEROS, "v");
adam->expr()->set_params(C2NUM, v);
// copy parameters
for (int i = C3NUM; i < C9NUM; i++) {
adam->expr()->set_params(i, this->input(i));
}
adam->expr()->set_params(C9NUM, grad);
return (*adam)(adam->inputs());
}
AdamM::AdamM(std::set<Node *> &&learn, const AdamConfig &cfg) {
auto op_param = reinterpret_cast<AdamParameter *>(malloc(sizeof(AdamParameter)));
if (op_param == nullptr) {
MS_LOG(ERROR) << " cannot allocate ActivationParameter";
return;
}
AssignLearn(std::move(learn));
memset(op_param, 0, sizeof(AdamParameter));
op_param->use_nesterov_ = cfg.use_nesterov_;
SetOpParam(op_param);
set_primitive(schema::PrimitiveType_Adam);
set_name("optimizer-Adam");
// Adam Network
expr()->SetSize(C10NUM);
auto assign1 = new (std::nothrow) AssignM(0);
if (assign1 == nullptr) {
MS_LOG(ERROR) << "Cannot allocate assign";
return;
}
PushOp(assign1);
auto assign2 = new (std::nothrow) AssignM(0);
if (assign2 == nullptr) {
MS_LOG(ERROR) << "Cannot allocate assign";
return;
}
PushOp(assign2);
auto mul1 = NN::Mul();
if (mul1 == nullptr) {
MS_LOG(ERROR) << "Cannot allocate mul";
return;
}
PushOp(mul1);
auto mul2 = NN::Mul();
if (mul2 == nullptr) {
MS_LOG(ERROR) << "Cannot allocate mul";
return;
}
PushOp(mul2);
auto tmp = 1.0f;
mul1->CreateConstTensor(C0NUM, {1}, kNumberTypeFloat32, KHWC, "beta1-power", &tmp);
mul1->CreateConstTensor(C1NUM, {1}, kNumberTypeFloat32, KHWC, "beta1-data", &cfg.beta1_);
auto o1 = (*mul1)({});
assign1_ = (*assign1)({mul1->input(0), o1.front()}).front();
mul2->CreateConstTensor(C0NUM, {1}, kNumberTypeFloat32, KHWC, "beta2-power", &tmp);
mul2->CreateConstTensor(C1NUM, {1}, kNumberTypeFloat32, KHWC, "beta2-data", &cfg.beta2_);
auto o2 = (*mul2)({});
assign2_ = (*assign2)({mul2->input(0), o2.front()}).front();
expr()->set_params(C3NUM, o1.front());
expr()->set_params(C4NUM, o2.front());
CreateConstTensor(C5NUM, {1}, kNumberTypeFloat32, KHWC, "learning-rate", &cfg.learning_rate_);
CreateConstTensor(C6NUM, {1}, kNumberTypeFloat32, KHWC, "beta1", &cfg.beta1_);
CreateConstTensor(C7NUM, {1}, kNumberTypeFloat32, KHWC, "beta2", &cfg.beta2_);
CreateConstTensor(C8NUM, {1}, kNumberTypeFloat32, KHWC, "epsilon", &cfg.eps_);
}
int AdamM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
auto param = reinterpret_cast<const AdamParameter *>(OpParam());
auto prim = new (std::nothrow) schema::AdamT;
if (prim == nullptr) {
MS_LOG(ERROR) << "Cannot allocate " << cnode->name;
return RET_ERROR;
}
prim->use_nesterov = param->use_nesterov_;
prim->use_locking = false;
cnode->primitive->value.value = prim;
return RET_OK;
}
namespace NN {
Node *Adam(std::set<Node *> &&learn, const AdamConfig &cfg) {
auto a = new (std::nothrow) AdamM(std::move(learn), cfg);
if (a == nullptr) {
MS_LOG(ERROR) << "Cannot allocate adam";
return nullptr;
}
return a;
}
} // namespace NN
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,48 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_ADAM_H_
#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_ADAM_H_
#include <vector>
#include <set>
#include <memory>
#include "include/api/net.h"
#include "src/expression/net.h"
#include "inner/model_generated.h"
namespace mindspore {
namespace lite {
class AdamM : public Node {
public:
AdamM() = default;
AdamM(std::set<Node *> &&learn, const AdamConfig &cfg);
std::vector<EXPR *> Clone(EXPR *grad, EXPR *weight) override;
void AddNetOutput(std::vector<EXPR *> *output) override {
output->push_back(assign1_);
output->push_back(assign2_);
}
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
private:
float weight_decay_;
float loss_scale_;
EXPR *assign1_{nullptr};
EXPR *assign2_{nullptr};
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_ADAM_H_

View File

@ -0,0 +1,42 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/expression/ops/addn.h"
namespace mindspore {
namespace lite {
AddN::AddN(int dummy) {
auto op_param = calloc(1, sizeof(OpParameter));
if (op_param == nullptr) {
MS_LOG(ERROR) << "Cannot allocate parameter ";
return;
}
set_name(UniqueName("addN"));
SetOpParam(op_param);
set_primitive(schema::PrimitiveType_AddN);
}
int AddN::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
auto prim = new (std::nothrow) schema::AddNT;
if (prim == nullptr) {
MS_LOG(ERROR) << "Cannot allocate primitive";
return RET_ERROR;
}
cnode->primitive->value.value = prim;
return RET_OK;
}
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,34 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_ADDN_H_
#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_ADDN_H_
#include <memory>
#include "src/expression/node.h"
#include "inner/model_generated.h"
namespace mindspore {
namespace lite {
class AddN : public Node {
public:
explicit AddN(int dummy);
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_ADDN_H_

View File

@ -0,0 +1,223 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/expression/ops/arithmetic.h"
#include <memory>
#include "src/expression/ops/reduce.h"
#include "src/expression/ops/reshape.h"
#include "src/expression/ops_utils.h"
#include "src/expression/ops/arithmetic_self.h"
#include "src/expression/ops.h"
#include "nnacl/arithmetic.h"
#include "src/expression/import.h"
#include "src/cxx_api/expression/node_impl.h"
namespace mindspore {
namespace lite {
// Common Arithmetic Functionality
ArithmeticM::ArithmeticM(schema::PrimitiveType type) : Node() {
auto op_param = malloc(sizeof(ArithmeticParameter));
if (op_param == nullptr) {
MS_LOG(ERROR) << " cannot allocate ActivationParameter";
return;
}
SetOpParam(op_param);
expr()->SetSize(C2NUM);
set_primitive(type);
}
std::vector<EXPR *> ArithmeticM::binop_grad_common(EXPR *x, EXPR *y, EXPR *dx, EXPR *dy) {
auto shape_of_x = x->dims();
auto shape_of_y = y->dims();
auto reduce_dx = dx;
auto reduce_dy = dy;
auto rx = (BroadcastGradientArgs(shape_of_x, shape_of_y))();
if (rx[0].size()) {
auto reduce_sum = NN::ReduceSum(false, rx[0]);
PushOp(reduce_sum);
reduce_dx = (*reduce_sum)({reduce_dx}).front();
auto reshape = NN::Reshape(shape_of_x);
PushOp(reshape);
reduce_dx = (*reshape)({reduce_dx}).front();
}
if (rx[1].size()) {
auto reduce_sum = NN::ReduceSum(false, rx[1]);
PushOp(reduce_sum);
reduce_dy = (*reduce_sum)({reduce_dy}).front();
auto reshape = NN::Reshape(shape_of_y);
PushOp(reshape);
reduce_dy = (*reshape)({reduce_dy}).front();
}
std::vector<EXPR *> out = {reduce_dx, reduce_dy};
return out;
}
// Add Op
AddM::AddM(int dummy) : ArithmeticM(schema::PrimitiveType_AddFusion) { set_name(UniqueName("Add")); }
std::vector<EXPR *> AddM::Grad(EXPR *yt) { return binop_grad_common(input(0), input(1), yt, yt); }
int AddM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
auto prim = new (std::nothrow) schema::AddFusionT;
if (prim == nullptr) {
MS_LOG(ERROR) << "Cannot allocate prim";
return RET_ERROR;
}
prim->activation_type = schema::ActivationType_NO_ACTIVATION;
cnode->primitive->value.value = prim;
return RET_OK;
}
static ImportReg AddReg(schema::PrimitiveType_AddFusion, ReturnNode<AddM>);
// Div op
DivM::DivM(int dummy) : ArithmeticM(schema::PrimitiveType_RealDiv) { set_name(UniqueName("RealDiv")); }
std::vector<EXPR *> DivM::Grad(EXPR *yt) {
auto x = input(0);
auto y = input(1);
auto o = output(0);
auto div_op = NN::Div();
if (div_op == nullptr) {
MS_LOG(ERROR) << "Cannot allocate div_op";
return {};
}
PushOp(div_op);
auto neg_op = NN::Neg();
if (neg_op == nullptr) {
MS_LOG(ERROR) << "Cannot allocate neg_op";
return {};
}
PushOp(neg_op);
auto mul_op = NN::Mul();
if (mul_op == nullptr) {
MS_LOG(ERROR) << "Cannot allocate mul_op";
return {};
}
PushOp(mul_op);
auto bc_x = (*div_op)({yt, y}).front();
auto bc_y = (*neg_op)((*mul_op)({bc_x, o})).front();
return binop_grad_common(x, y, bc_x, bc_y);
}
int DivM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
auto prim = new (std::nothrow) schema::RealDivT;
if (prim == nullptr) {
MS_LOG(ERROR) << "Cannot allocate primitive";
return RET_ERROR;
}
cnode->primitive->value.value = prim;
return RET_OK;
}
static ImportReg DivReg(schema::PrimitiveType_DivFusion, ReturnNode<DivM>);
// Mul op
MulM::MulM(int dummy) : ArithmeticM(schema::PrimitiveType_MulFusion) { set_name(UniqueName("Mul")); }
std::vector<EXPR *> MulM::Grad(EXPR *yt) {
auto mul_dx = NN::Mul();
if (mul_dx == nullptr) {
MS_LOG(ERROR) << "Cannot allocate mul dx";
return {};
}
PushOp(mul_dx);
auto mul_dy = NN::Mul();
if (mul_dy == nullptr) {
MS_LOG(ERROR) << "Cannot allocate mul_dy";
return {};
}
PushOp(mul_dy);
auto x = input(0);
auto y = input(1);
auto bc_dx = (*mul_dx)({y, yt}).front();
auto bc_dy = (*mul_dy)({x, yt}).front();
return binop_grad_common(x, y, bc_dx, bc_dy);
}
int MulM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
auto prim = new (std::nothrow) schema::MulFusionT;
if (prim == nullptr) {
MS_LOG(ERROR) << "Cannot allocate prim";
return RET_ERROR;
}
prim->activation_type = schema::ActivationType_NO_ACTIVATION;
cnode->primitive->value.value = prim;
return RET_OK;
}
static ImportReg MulReg(schema::PrimitiveType_MulFusion, ReturnNode<MulM>);
// Sub op
SubM::SubM(int dummy) : ArithmeticM(schema::PrimitiveType_SubFusion) { set_name(UniqueName("Sub")); }
std::vector<EXPR *> SubM::Grad(EXPR *yt) {
auto x = input(0);
auto y = input(1);
auto neg = NN::Neg();
if (neg == nullptr) {
MS_LOG(ERROR) << "Cannot allocate neg";
return {};
}
PushOp(neg);
auto neg_grad = (*neg)({yt}).front();
return binop_grad_common(x, y, yt, neg_grad);
}
int SubM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
auto prim = new (std::nothrow) schema::SubFusionT;
if (prim == nullptr) {
MS_LOG(ERROR) << "Cannot allocate prim";
return RET_ERROR;
}
prim->activation_type = schema::ActivationType_NO_ACTIVATION;
cnode->primitive->value.value = prim;
return RET_OK;
}
static ImportReg SubReg(schema::PrimitiveType_SubFusion, ReturnNode<SubM>);
namespace NN {
Node *Add() {
auto a = new (std::nothrow) AddM(0);
if (a == nullptr) {
MS_LOG(ERROR) << "Cannot allocate a";
return nullptr;
}
return a;
}
Node *Sub() {
auto a = new (std::nothrow) SubM(0);
if (a == nullptr) {
MS_LOG(ERROR) << "Cannot allocate a";
return nullptr;
}
return a;
}
Node *Mul() {
auto a = new (std::nothrow) MulM(0);
if (a == nullptr) {
MS_LOG(ERROR) << "Cannot allocate a";
return nullptr;
}
return a;
}
Node *Div() {
auto a = new (std::nothrow) DivM(0);
if (a == nullptr) {
MS_LOG(ERROR) << "Cannot allocate a";
return nullptr;
}
return a;
}
} // namespace NN
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,75 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_ARITHMETIC_H_
#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_ARITHMETIC_H_
#include <vector>
#include <memory>
#include "src/expression/net.h"
#include "inner/model_generated.h"
namespace mindspore {
namespace lite {
class ArithmeticM : public Node {
public:
ArithmeticM() = default;
explicit ArithmeticM(schema::PrimitiveType type);
protected:
std::vector<EXPR *> binop_grad_common(EXPR *x, EXPR *y, EXPR *dx, EXPR *dy);
};
class AddM : public ArithmeticM {
public:
AddM() = default;
explicit AddM(int dummy);
std::vector<EXPR *> Grad(EXPR *expr) override;
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
};
class DivM : public ArithmeticM {
public:
DivM() = default;
explicit DivM(int dummy);
std::vector<EXPR *> Grad(EXPR *expr) override;
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
};
class MulM : public ArithmeticM {
public:
MulM() = default;
explicit MulM(int dummy);
std::vector<EXPR *> Grad(EXPR *expr) override;
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
};
class SubM : public ArithmeticM {
public:
SubM() = default;
explicit SubM(int dummy);
std::vector<EXPR *> Grad(EXPR *expr) override;
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
};
namespace NN {
Node *Add();
Node *Sub();
Node *Mul();
Node *Div();
} // namespace NN
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_ARITHMETIC_H_

View File

@ -0,0 +1,72 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/expression/ops/arithmetic_self.h"
#include <memory>
#include "src/expression/ops_utils.h"
#include "src/expression/ops.h"
#include "nnacl/arithmetic_self_parameter.h"
#include "src/expression/import.h"
namespace mindspore {
namespace lite {
// Common Arithmetic Self Functionality
ArithmeticSelfM::ArithmeticSelfM(schema::PrimitiveType type) : Node() {
auto op_param = malloc(sizeof(ArithmeticSelfParameter));
if (op_param == nullptr) {
MS_LOG(ERROR) << " cannot allocate ArithmeticSelfParameter";
return;
}
memset(op_param, 0, sizeof(ArithmeticSelfParameter));
SetOpParam(op_param);
expr()->SetSize(C1NUM);
set_primitive(type);
}
// NEG OP
NegM::NegM(int dummy) : ArithmeticSelfM(schema::PrimitiveType_NegGrad) { set_name(UniqueName("Neg")); }
std::vector<EXPR *> NegM::Grad(EXPR *yt) {
auto grad_neg = new (std::nothrow) NegM(0);
if (grad_neg == nullptr) {
MS_LOG(ERROR) << "Cannot allocate neg gradient";
return {};
}
return (*grad_neg)({yt});
}
int NegM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
auto prim = new (std::nothrow) schema::NegT;
if (prim == nullptr) {
MS_LOG(ERROR) << "Cannot allocate primitive";
return RET_ERROR;
}
cnode->primitive->value.value = prim;
return RET_OK;
}
namespace NN {
Node *Neg() {
auto a = new (std::nothrow) NegM(0);
if (a == nullptr) {
MS_LOG(ERROR) << "Cannot allocate neg node";
return nullptr;
}
return a;
}
} // namespace NN
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,46 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_ARITHMETIC_SELF_H_
#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_ARITHMETIC_SELF_H_
#include <vector>
#include <memory>
#include "src/expression/net.h"
#include "inner/model_generated.h"
namespace mindspore {
namespace lite {
class ArithmeticSelfM : public Node {
public:
explicit ArithmeticSelfM(schema::PrimitiveType type);
protected:
};
class NegM : public ArithmeticSelfM {
public:
explicit NegM(int dummy);
std::vector<EXPR *> Grad(EXPR *expr) override;
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
};
namespace NN {
Node *Neg();
}
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_ARITHMETIC_SELF_H_

View File

@ -0,0 +1,60 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/expression/ops/assign.h"
#include <memory>
#include "nnacl/reshape_parameter.h"
#include "src/expression/import.h"
namespace mindspore {
namespace lite {
AssignM::AssignM(int dummy) {
auto op_param = calloc(1, sizeof(OpParameter));
if (op_param == nullptr) {
MS_LOG(ERROR) << " cannot allocate ReshapeParameter";
return;
}
expr()->SetSize(C2NUM);
SetOpParam(op_param);
set_primitive(schema::PrimitiveType_Assign);
set_name(UniqueName("Assign"));
}
int AssignM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
auto prim = new (std::nothrow) schema::AssignT;
if (prim == nullptr) {
MS_LOG(ERROR) << "Cannot allocate primitive";
return RET_ERROR;
}
cnode->primitive->value.value = prim;
return RET_OK;
}
static ImportReg reg(schema::PrimitiveType_Reshape, ReturnNode<AssignM>);
namespace NN {
Node *Assign() {
auto node = new (std::nothrow) AssignM(0);
if (node == nullptr) {
MS_LOG(ERROR) << "Cannot allocate node";
return nullptr;
}
node->set_name(Node::UniqueName("Assign"));
return node;
}
} // namespace NN
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,35 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_ASSIGN_H_
#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_ASSIGN_H_
#include <vector>
#include <memory>
#include "src/expression/net.h"
#include "inner/model_generated.h"
namespace mindspore {
namespace lite {
class AssignM : public Node {
public:
AssignM() = default;
explicit AssignM(int dummy);
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_ASSIGN_H_

View File

@ -0,0 +1,135 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/expression/ops/batchnorm.h"
#include <memory>
#include "nnacl/batchnorm_parameter.h"
#include "nnacl/fp32_grad/batch_norm.h"
#include "src/expression/import.h"
#include "src/expression/ops.h"
#include "src/cxx_api/expression/node_impl.h"
namespace mindspore {
namespace lite {
BatchNorm2dM::BatchNorm2dM(int outp, float momentum, float epsilon) {
constexpr int bn_inputs = 5;
constexpr int bn_outputs = 5;
auto op_param = calloc(1, sizeof(BatchNormParameter));
if (op_param == nullptr) {
MS_LOG(ERROR) << " cannot allocate BatchNormParameter";
return;
}
expr()->SetSize(bn_inputs);
set_name(UniqueName("BatchNorm2D"));
auto bn_param = reinterpret_cast<BatchNormParameter *>(op_param);
bn_param->channel_ = outp;
bn_param->momentum_ = momentum;
bn_param->epsilon_ = epsilon;
SetOpParam(op_param);
set_primitive(schema::PrimitiveType_FusedBatchNorm);
std::vector<int> dims = {outp};
auto scale = Node::CreateWeights(dims, kNumberTypeFloat32, KHWC, Param::Mode::ONES, "scale");
expr()->set_params(C1NUM, scale);
auto offset = Node::CreateWeights(dims, kNumberTypeFloat32, KHWC, Param::Mode::ZEROS, "offset");
expr()->set_params(C2NUM, offset);
auto mean = Node::CreateWeights(dims, kNumberTypeFloat32, KHWC, Param::Mode::ZEROS, "mean");
expr()->set_params(C3NUM, mean);
auto var = Node::CreateWeights(dims, kNumberTypeFloat32, KHWC, Param::Mode::ONES, "var");
expr()->set_params(C4NUM, var);
SetOutputs(bn_outputs);
SetLearn();
}
BatchNorm2dGradM::BatchNorm2dGradM(BatchNorm2dM *bn_node) : Node() {
auto op_param = calloc(1, sizeof(BNGradParameter));
if (op_param == nullptr) {
MS_LOG(ERROR) << " cannot allocate BNGradParameter";
return;
}
expr()->SetSize(C6NUM);
set_name(bn_node->name() + "/" + kGradName + "/bnGrad");
auto bn_grad_param = reinterpret_cast<BNGradParameter *>(op_param);
auto bn_param = reinterpret_cast<BatchNormParameter *>(bn_node->OpParam());
bn_param->is_training_ = true;
bn_grad_param->epsilon_ = bn_param->epsilon_;
bn_grad_param->is_training_ = true;
SetOpParam(op_param);
set_primitive(schema::PrimitiveType_BatchNormGrad);
EXPR e(this);
e.SetSize(0);
// Dgamma
expr_.emplace_back(e);
// Doffset
expr_.emplace_back(e);
}
std::vector<EXPR *> BatchNorm2dM::Grad(EXPR *yt) {
auto bn_grad_node = new (std::nothrow) BatchNorm2dGradM(this);
if (bn_grad_node == nullptr) {
MS_LOG(ERROR) << "Cannot allocate batchnorm grad";
return {};
}
PushOp(bn_grad_node);
auto bn_grad = (*bn_grad_node)({yt, input(0), output(1), output(3), output(4), output(2)});
return bn_grad;
}
int BatchNorm2dM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
auto bn_param = reinterpret_cast<const BatchNormParameter *>(OpParam());
auto prim = new (std::nothrow) schema::FusedBatchNormT;
if (prim == nullptr) {
MS_LOG(ERROR) << "Cannot allocate primitive";
return RET_ERROR;
}
prim->epsilon = bn_param->epsilon_;
prim->momentum = bn_param->momentum_;
prim->mode = (bn_param->is_training_ == false) ? 0 : 1;
cnode->primitive->value.value = prim;
return RET_OK;
}
void BatchNorm2dM::SetLearn() {
AddLearn(input(C1NUM)->node());
AddLearn(input(C2NUM)->node());
}
int BatchNorm2dGradM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
auto param = reinterpret_cast<const BNGradParameter *>(OpParam());
auto prim = new (std::nothrow) schema::BatchNormGradT;
if (prim == nullptr) {
MS_LOG(ERROR) << "Cannot allocate primitive";
return RET_ERROR;
}
prim->epsilon = param->epsilon_;
prim->is_training = param->is_training_;
cnode->primitive->value.value = prim;
return RET_OK;
}
static ImportReg reg(schema::PrimitiveType_FusedBatchNorm, ReturnNode<BatchNorm2dM>);
namespace NN {
Node *BatchNorm2D(int outp, float momentum, float epsilon) {
auto node = new (std::nothrow) BatchNorm2dM(outp, momentum, epsilon);
if (node == nullptr) {
MS_LOG(ERROR) << "Cannot allocate node";
return nullptr;
}
return node;
}
} // namespace NN
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,43 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_BATCHNORM_H_
#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_BATCHNORM_H_
#include <vector>
#include <memory>
#include "src/expression/net.h"
#include "inner/model_generated.h"
namespace mindspore {
namespace lite {
class BatchNorm2dM : public Node {
public:
BatchNorm2dM() = default;
BatchNorm2dM(int outp, float momentum, float epsilon);
std::vector<EXPR *> Grad(EXPR *expr) override;
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
void SetLearn() override;
};
class BatchNorm2dGradM : public Node {
public:
explicit BatchNorm2dGradM(BatchNorm2dM *bn_node);
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_BATCHNORM_H_

View File

@ -0,0 +1,93 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/expression/ops/biasadd.h"
#include "src/expression/ops/transpose.h"
#include "nnacl/arithmetic.h"
#include "src/expression/import.h"
namespace mindspore {
namespace lite {
BiasAddM::BiasAddM(Format data_format) {
auto op_param = calloc(1, sizeof(ArithmeticParameter));
if (op_param == nullptr) {
MS_LOG(ERROR) << " cannot allocate ConvParameter";
return;
}
auto bias_param = reinterpret_cast<ArithmeticParameter *>(op_param);
SetOpParam(bias_param);
set_primitive(schema::PrimitiveType_BiasAdd);
}
std::vector<EXPR *> BiasAddM::construct(const std::vector<EXPR *> &inputs) {
auto x = Node::construct(inputs);
AddLearn(inputs.at(C1NUM)->node());
return x;
}
void BiasAddM::SetLearn() { AddLearn(input(C1NUM)->node()); }
std::vector<EXPR *> BiasAddM::Grad(EXPR *yt) {
auto in = yt;
if (yt->format() != NHWC && yt->dims().size() == C4NUM) {
in = TransposeM::TransposeCHW2HWC(yt);
in->node()->set_name(kGradName + "/" + name() + "/" + in->node()->name());
PushOp(in->node());
}
auto grad_node = new (std::nothrow) BiasAddGradM(*this);
if (grad_node == nullptr) {
MS_LOG(ERROR) << "Cannon allocate Bias Grad";
return {};
}
PushOp(grad_node);
auto bias_grad = (*grad_node)({in});
return {in, bias_grad.front()};
}
int BiasAddM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
auto prim = new (std::nothrow) schema::BiasAddT;
if (prim == nullptr) {
MS_LOG(ERROR) << "cannot allocate prim";
return RET_ERROR;
}
prim->format = static_cast<schema::Format>(KHWC);
cnode->primitive->value.value = prim;
return RET_OK;
}
BiasAddGradM::BiasAddGradM(const BiasAddM &bias) {
auto op_param = calloc(1, sizeof(OpParameter));
if (op_param == nullptr) {
MS_LOG(ERROR) << "Cannot allocate op_param";
}
SetOpParam(op_param);
set_primitive(schema::PrimitiveType_BiasAddGrad);
}
int BiasAddGradM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
auto prim = new (std::nothrow) schema::BiasAddGradT;
if (prim == nullptr) {
MS_LOG(ERROR) << "Cannot allocate primitive";
return RET_ERROR;
}
cnode->primitive->value.value = prim;
return RET_OK;
}
static ImportReg reg(schema::PrimitiveType_BiasAdd, ReturnNode<BiasAddM>);
namespace NN {}
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,44 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_BIASADD_H_
#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_BIASADD_H_
#include <vector>
#include <memory>
#include "src/expression/net.h"
#include "inner/model_generated.h"
namespace mindspore {
namespace lite {
class BiasAddM : public Node {
public:
BiasAddM() = default;
explicit BiasAddM(Format data_format);
std::vector<EXPR *> construct(const std::vector<EXPR *> &inputs) override;
std::vector<EXPR *> Grad(EXPR *yt) override;
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
void SetLearn() override;
};
class BiasAddGradM : public Node {
public:
explicit BiasAddGradM(const BiasAddM &bias);
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_BIASADD_H_

View File

@ -0,0 +1,241 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/expression/ops/conv.h"
#include <memory>
#include "src/expression/ops/biasadd.h"
#include "src/expression/ops/depend.h"
#include "src/expression/ops/transpose.h"
#include "nnacl/conv_parameter.h"
#include "inner/model_generated.h"
#include "src/expression/import.h"
#include "src/expression/ops.h"
#include "src/cxx_api/expression/node_impl.h"
namespace mindspore {
namespace lite {
ConvM::ConvM(const ConvConfig &cfg) : Node() {
auto op_param = calloc(1, sizeof(ConvParameter));
if (op_param == nullptr) {
MS_LOG(ERROR) << " cannot allocate ConvParameter";
return;
}
SetOpParam(op_param);
ConvParameter *conv_param = reinterpret_cast<ConvParameter *>(OpParam());
conv_param->input_channel_ = cfg.in_channel_;
conv_param->output_channel_ = cfg.out_channel_;
conv_param->kernel_h_ = cfg.kernel_size_[0];
conv_param->kernel_w_ = cfg.kernel_size_[1];
conv_param->stride_h_ = cfg.stride_[0];
conv_param->stride_w_ = cfg.stride_[1];
auto pad_mode = GetMode(cfg.pad_mode_);
if (pad_mode == -1) {
MS_LOG(ERROR) << "bad pad mode";
return;
}
conv_param->pad_mode_ = static_cast<PadMode>(pad_mode);
conv_param->pad_u_ = cfg.padding_[C0NUM];
conv_param->pad_d_ = cfg.padding_[C1NUM];
conv_param->pad_l_ = cfg.padding_[C2NUM];
conv_param->pad_r_ = cfg.padding_[C3NUM];
conv_param->dilation_h_ = cfg.dilation_[C0NUM];
conv_param->dilation_w_ = cfg.dilation_[C1NUM];
conv_param->group_ = cfg.group_;
conv_param->out_format_ = NHWC;
conv_param->act_type_ = ActType_No;
expr()->SetSize(C2NUM);
set_primitive(schema::PrimitiveType_Conv2DFusion);
set_name(UniqueName("Conv"));
Param::Mode mode = Param::String2Enum(cfg.weight_init_);
std::vector<int> dims = {conv_param->output_channel_, conv_param->kernel_h_, conv_param->kernel_w_,
conv_param->input_channel_ / conv_param->group_};
auto w = CreateWeights(dims, kNumberTypeFloat32, KHWC, mode, "weights");
expr()->set_params(C1NUM, w);
if (cfg.has_bias) {
bias_ = new (std::nothrow) BiasAddM(KHWC);
if (bias_ == nullptr) {
MS_LOG(ERROR) << "Cannot allocate bias";
return;
}
bias_->update_name(name());
std::vector<int> dim_bias = {conv_param->output_channel_};
wbias_ = CreateWeights(dim_bias, kNumberTypeFloat32, KHWC, Param::Mode::ZEROS, "weights");
AddLearn(wbias_->node());
PushOp(bias_);
}
SetLearn();
}
std::vector<EXPR *> ConvM::construct(const std::vector<EXPR *> &inputs) {
auto in = inputs;
auto x = in.front();
if (x->format() != NHWC && x->dims().size() == C4NUM) {
x = TransposeM::TransposeCHW2HWC(x);
x->node()->set_name(name() + "/" + x->node()->name());
PushOp(x->node());
in.at(0) = x;
}
auto y = Node::construct(in);
if (bias_ != nullptr) {
y = (*bias_)({y.front(), wbias_});
}
return y;
}
void ConvM::SetLearn() { AddLearn(input(C1NUM)->node()); }
int ConvM::GetMode(std::string mode) {
const std::vector<std::string> list = {"pad", "same", "valid"};
auto itr = std::find(list.begin(), list.end(), mode);
if (itr == list.end()) {
MS_LOG(ERROR) << "illegal mode" << mode;
return -1;
}
return std::distance(list.begin(), itr);
}
std::vector<EXPR *> ConvM::Grad(EXPR *yt) {
// Generate Input Grad
EXPR *in = yt;
if (yt->format() != NHWC && yt->dims().size() == C4NUM) {
in = TransposeM::TransposeCHW2HWC(yt);
in->node()->set_name(kGradName + "/" + name() + "/" + in->node()->name());
PushOp(in->node());
}
auto inGrad = new (std::nothrow) ConvInputGradM(this);
if (inGrad == nullptr) {
MS_LOG(ERROR) << "Cannot allocate convolution input grad";
return {};
}
PushOp(inGrad);
auto ig = (*inGrad)({in, input(1), inGrad->input(2)});
// Execution Control Flow !
auto depend = NN::Depend();
if (depend == nullptr) {
MS_LOG(ERROR) << "Cannot allocate depend";
return {};
}
PushOp(depend);
depend->update_name(name());
auto de = (*depend)({inGrad->expr()});
// Generate Filter Grad
auto filterGrad = new (std::nothrow) ConvFilterGradM(this);
if (filterGrad == nullptr) {
MS_LOG(ERROR) << "Cannot allocate convolution filter grad";
return {};
}
PushOp(filterGrad);
filterGrad->update_name(name());
auto fg = (*filterGrad)({in, input(0), filterGrad->input(2), de[0]});
std::vector<EXPR *> res = {ig[0], fg[0]};
return res;
}
int ConvM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
auto conv_param = reinterpret_cast<const ConvParameter *>(OpParam());
auto prim = new (std::nothrow) schema::Conv2DFusionT;
if (prim == nullptr) {
MS_LOG(ERROR) << "Cannot allocate primitive";
return RET_ERROR;
}
prim->activation_type = static_cast<schema::ActivationType>(conv_param->act_type_);
prim->format = static_cast<schema::Format>(conv_param->out_format_);
prim->stride = {conv_param->stride_h_, conv_param->stride_w_};
prim->kernel_size = {conv_param->kernel_h_, conv_param->kernel_w_};
prim->dilation = {conv_param->dilation_h_, conv_param->dilation_w_};
prim->out_channel = conv_param->output_channel_;
prim->in_channel = conv_param->input_channel_;
prim->group = conv_param->group_;
prim->pad_mode = static_cast<schema::PadMode>(conv_param->pad_mode_);
prim->pad_list = {conv_param->pad_u_, conv_param->pad_d_, conv_param->pad_l_, conv_param->pad_r_};
prim->mode = 1;
cnode->primitive->value.value = prim;
return RET_OK;
}
ConvInputGradM::ConvInputGradM(ConvM *conv_node) : Node() {
CloneOpParam<ConvParameter>(conv_node->OpParam());
set_primitive(schema::PrimitiveType_Conv2DBackpropInputFusion);
set_name(kGradName + "/conv2DBackpropInput");
expr()->SetSize(C3NUM);
auto const x = conv_node->input(0);
CreateConstTensor(C2NUM, {static_cast<int32_t>(x->dims().size())}, kNumberTypeInt32, KHWC, "shape", x->dims().data());
}
int ConvInputGradM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
auto conv_param = reinterpret_cast<const ConvParameter *>(OpParam());
auto prim = new (std::nothrow) schema::Conv2DBackpropInputFusionT;
if (prim == nullptr) {
MS_LOG(ERROR) << "Cannot allocate primitive";
return RET_ERROR;
}
prim->activation_type = static_cast<schema::ActivationType>(conv_param->act_type_);
prim->format = static_cast<schema::Format>(conv_param->out_format_);
prim->stride = {conv_param->stride_h_, conv_param->stride_w_};
prim->kernel_size = {conv_param->kernel_h_, conv_param->kernel_w_};
prim->dilation = {conv_param->dilation_h_, conv_param->dilation_w_};
prim->out_channel = conv_param->output_channel_;
prim->in_channel = conv_param->input_channel_;
prim->group = conv_param->group_;
prim->pad_mode = static_cast<schema::PadMode>(conv_param->pad_mode_);
prim->pad_list = {conv_param->pad_u_, conv_param->pad_d_, conv_param->pad_l_, conv_param->pad_r_};
cnode->primitive->value.value = prim;
return RET_OK;
}
ConvFilterGradM::ConvFilterGradM(ConvM *conv_node) : Node() {
CloneOpParam<ConvParameter>(conv_node->OpParam());
set_primitive(schema::PrimitiveType_Conv2DBackpropFilterFusion);
set_name(kGradName + "/conv2DBackpropFilter");
expr()->SetSize(C4NUM);
auto w = conv_node->input(1);
CreateConstTensor(C2NUM, {static_cast<int32_t>(w->dims().size())}, kNumberTypeInt32, KHWC, "shape", w->dims().data());
}
int ConvFilterGradM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
auto conv_param = reinterpret_cast<const ConvParameter *>(OpParam());
auto prim = new (std::nothrow) schema::Conv2DBackpropFilterFusionT;
if (prim == nullptr) {
MS_LOG(ERROR) << "Cannot allocate primitive";
return RET_ERROR;
}
prim->activation_type = static_cast<schema::ActivationType>(conv_param->act_type_);
prim->format = static_cast<schema::Format>(conv_param->out_format_);
prim->stride = {conv_param->stride_h_, conv_param->stride_w_};
prim->kernel_size = {conv_param->kernel_h_, conv_param->kernel_w_};
prim->dilation = {conv_param->dilation_h_, conv_param->dilation_w_};
prim->out_channel = conv_param->output_channel_;
prim->in_channel = conv_param->input_channel_;
prim->group = conv_param->group_;
prim->pad_mode = static_cast<schema::PadMode>(conv_param->pad_mode_);
prim->pad_list = {conv_param->pad_u_, conv_param->pad_d_, conv_param->pad_l_, conv_param->pad_r_};
cnode->primitive->value.value = prim;
return RET_OK;
}
static ImportReg reg(schema::PrimitiveType_Conv2DFusion, ReturnNode<ConvM>);
namespace NN {
Node *Conv2D(const ConvConfig &cfg) {
auto c = new (std::nothrow) ConvM(cfg);
if (c == nullptr) {
MS_LOG(ERROR) << "Cannot allocate Convolution object";
return nullptr;
}
return c;
}
} // namespace NN
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,58 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_CONV_H_
#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_CONV_H_
#include <vector>
#include <memory>
#include <string>
#include "src/expression/cfg.h"
#include "src/expression/node.h"
namespace mindspore {
namespace lite {
class ConvM : public Node {
public:
ConvM() = default;
explicit ConvM(const ConvConfig &cfg);
std::vector<EXPR *> construct(const std::vector<EXPR *> &inputs) override;
Param *weight() override { return input(1)->node()->data(); }
std::vector<EXPR *> Grad(EXPR *expr) override;
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
void SetLearn() override;
private:
int GetMode(std::string mode);
Node *bias_ = nullptr;
EXPR *wbias_ = nullptr;
};
class ConvInputGradM : public Node {
public:
explicit ConvInputGradM(ConvM *conv_node);
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
};
class ConvFilterGradM : public Node {
public:
explicit ConvFilterGradM(ConvM *conv_node);
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_CONV_H_

View File

@ -0,0 +1,151 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/expression/ops/dense.h"
#include <memory>
#include "include/api/cfg.h"
#include "src/expression/ops/biasadd.h"
#include "src/expression/ops/depend.h"
#include "src/expression/ops.h"
#include "nnacl/matmul_parameter.h"
#include "src/expression/import.h"
#include "inner/model_generated.h"
#include "src/cxx_api/expression/node_impl.h"
namespace mindspore {
namespace lite {
DenseM::DenseM(const DenseConfig &cfg) : Node() {
auto op_param = calloc(1, sizeof(MatMulParameter));
if (op_param == nullptr) {
MS_LOG(ERROR) << " cannot allocate MatMulParameter";
return;
}
set_name(UniqueName("Dense"));
SetOpParam(op_param);
expr()->SetSize(C2NUM);
set_primitive(schema::PrimitiveType_MatMulFusion);
auto param = reinterpret_cast<MatMulParameter *>(opParam_.get());
param->row_ = cfg.out_channels_;
param->col_ = cfg.in_channels_;
param->a_transpose_ = false;
param->b_transpose_ = true;
std::vector<int> dims = {param->row_, param->col_};
auto w = Node::CreateWeights(dims, kNumberTypeFloat32, KHWC, Param::Mode::NORMAL, "weights");
expr()->set_params(C1NUM, w);
if (cfg.has_bias_) {
wbias_ = CreateWeights({cfg.out_channels_}, kNumberTypeFloat32, KHWC, Param::Mode::ZEROS, "bias_weights");
bias_ = new (std::nothrow) BiasAddM(KHWC);
if (bias_ == nullptr) {
MS_LOG(ERROR) << "Cannot allocate bias";
return;
}
bias_->update_name(name());
AddLearn(wbias_->node());
PushOp(bias_);
}
SetLearn();
}
std::vector<EXPR *> DenseM::construct(const std::vector<EXPR *> &inputs) {
auto x = Node::construct(inputs);
if (bias_ != nullptr) {
x = (*bias_)({x.front(), wbias_});
}
return x;
}
std::vector<EXPR *> DenseM::Grad(EXPR *yt) {
auto src_param = reinterpret_cast<MatMulParameter *>(opParam_.get());
bool ta = src_param->a_transpose_;
bool tb = src_param->b_transpose_;
// dx grad op
auto dxGrad = new (std::nothrow) DenseM();
if (dxGrad == nullptr) {
MS_LOG(ERROR) << "Cannot allocate dxGrad ";
return {};
}
PushOp(dxGrad);
dxGrad->CloneOpParam<MatMulParameter>(opParam_);
dxGrad->set_primitive(schema::PrimitiveType_MatMulFusion);
auto dxGradParam = reinterpret_cast<MatMulParameter *>(dxGrad->OpParam());
dxGradParam->a_transpose_ = (ta && tb);
dxGradParam->b_transpose_ = (ta || !tb);
dxGrad->set_name(name() + kGradName + "/dxGrad");
EXPR *dx = nullptr;
if (ta) {
dx = (*dxGrad)({input(1), yt}).front();
} else {
dx = (*dxGrad)({yt, input(1)}).front();
}
// Control execution flow
auto depend = NN::Depend();
if (depend == nullptr) {
MS_LOG(ERROR) << "Cannot allocate depend ";
return {};
}
PushOp(depend);
auto de = (*depend)({dxGrad->expr()}).front();
// dw grad op
auto dwGrad = new (std::nothrow) DenseM();
if (dwGrad == nullptr) {
MS_LOG(ERROR) << "Cannot allocate dwGrad ";
return {};
}
PushOp(dwGrad);
dwGrad->CloneOpParam<MatMulParameter>(opParam_);
dwGrad->set_primitive(schema::PrimitiveType_MatMulFusion);
auto dwGradParam = reinterpret_cast<MatMulParameter *>(dwGrad->OpParam());
dwGradParam->a_transpose_ = (!ta || tb);
dwGradParam->b_transpose_ = ta && tb;
dwGrad->set_name(name() + kGradName + "/dwGrad");
EXPR *dw = nullptr;
if (tb) {
dw = (*dwGrad)({yt, input(0), de}).front();
} else {
dw = (*dwGrad)({input(0), yt, de}).front();
}
return {dx, dw};
}
int DenseM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
auto dense_param = reinterpret_cast<const MatMulParameter *>(OpParam());
auto prim = new (std::nothrow) schema::MatMulFusionT;
if (prim == nullptr) {
MS_LOG(ERROR) << "Cannot allocate primitive";
return RET_ERROR;
}
prim->transpose_a = dense_param->a_transpose_;
prim->transpose_b = dense_param->b_transpose_;
cnode->primitive->value.value = prim;
return RET_OK;
}
void DenseM::SetLearn() { AddLearn(input(C1NUM)->node()); }
static ImportReg reg(schema::PrimitiveType_MatMulFusion, ReturnNode<DenseM>);
namespace NN {
Node *Dense(const DenseConfig &cfg) {
auto l = new (std::nothrow) DenseM(cfg);
if (l == nullptr) {
MS_LOG(ERROR) << "Cannot allocate Dense object";
}
return l;
}
} // namespace NN
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,44 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_DENSE_H_
#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_DENSE_H_
#include <vector>
#include <memory>
#include "src/expression/node.h"
#include "src/expression/cfg.h"
#include "inner/model_generated.h"
namespace mindspore {
namespace lite {
class DenseM : public Node {
public:
DenseM() = default;
explicit DenseM(const DenseConfig &cfg);
std::vector<EXPR *> construct(const std::vector<EXPR *> &inputs) override;
std::vector<EXPR *> Grad(EXPR *expr) override;
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
void SetLearn() override;
private:
Param *weight() override { return input(1)->node()->data(); }
Node *bias_{nullptr};
EXPR *wbias_{nullptr};
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_DENSE_H_

View File

@ -0,0 +1,43 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/expression/ops/depend.h"
#include "inner/model_generated.h"
namespace mindspore {
namespace lite {
DependM::DependM() : Node() {
auto param = calloc(1, sizeof(OpParameter));
if (param == nullptr) {
MS_LOG(ERROR) << "Cannot allocate parameter";
return;
}
SetOpParam(param);
set_primitive(schema::PrimitiveType_Depend);
set_name(UniqueName("Depend"));
}
namespace NN {
Node *Depend() {
auto d = new (std::nothrow) DependM();
if (d == nullptr) {
MS_LOG(ERROR) << "Cannot allocate depend object";
return nullptr;
}
return d;
}
} // namespace NN
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,37 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_DEPEND_H_
#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_DEPEND_H_
#include <vector>
#include "src/expression/node.h"
#include "src/expression/ops.h"
#include "inner/model_generated.h"
namespace mindspore {
namespace lite {
class DependM : public Node {
public:
DependM();
};
namespace NN {
Node *Depend();
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_DEPEND_H_

View File

@ -0,0 +1,91 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/expression/ops/dropout.h"
#include <vector>
#include "nnacl/fp32_grad/dropout_parameter.h"
#include "inner/model_generated.h"
#include "src/expression/import.h"
#include "src/expression/ops.h"
#include "src/cxx_api/expression/node_impl.h"
namespace mindspore {
namespace lite {
DropOutM::DropOutM(float ratio) {
auto param = reinterpret_cast<DropoutParameter *>(calloc(1, sizeof(DropoutParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "Cannot allocate parameter";
return;
}
param->ratio_ = ratio;
SetOpParam(param);
set_primitive(schema::PrimitiveType_Dropout);
set_name(UniqueName("DropOut"));
}
std::vector<EXPR *> DropOutM::Grad(EXPR *yt) {
auto inGrad = new (std::nothrow) DropOutGradM(this);
if (inGrad == nullptr) {
MS_LOG(ERROR) << "Cannot allocate drop grad";
return {};
}
return (*inGrad)({yt, expr()});
}
int DropOutM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
auto param = reinterpret_cast<const DropoutParameter *>(OpParam());
auto prim = new (std::nothrow) schema::DropoutT;
if (prim == nullptr) {
MS_LOG(ERROR) << "Cannot allocate primitive";
return RET_ERROR;
}
prim->keep_prob = param->ratio_;
cnode->primitive->value.value = prim;
return RET_OK;
}
DropOutGradM::DropOutGradM(DropOutM *node) {
CloneOpParam<DropoutParameter>(node->OpParam());
set_primitive(schema::PrimitiveType_DropoutGrad);
set_name(kGradName + "/DropOutGrad");
}
int DropOutGradM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
auto param = reinterpret_cast<const DropoutParameter *>(OpParam());
auto prim = new (std::nothrow) schema::DropoutGradT;
if (prim == nullptr) {
MS_LOG(ERROR) << "Cannot allocate primitive";
return RET_ERROR;
}
prim->keep_prob = param->ratio_;
cnode->primitive->value.value = prim;
return RET_OK;
}
static ImportReg reg(schema::PrimitiveType_Dropout, ReturnNode<DropOutM>);
namespace NN {
Node *DropOut(float ratio) {
auto node = new (std::nothrow) DropOutM(ratio);
if (node == nullptr) {
MS_LOG(ERROR) << "Cannot allocate dropout node";
return nullptr;
}
return node;
}
} // namespace NN
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,42 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_DROPOUT_H_
#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_DROPOUT_H_
#include <vector>
#include <memory>
#include "src/expression/node.h"
namespace mindspore {
namespace lite {
class DropOutM : public Node {
public:
DropOutM() = default;
explicit DropOutM(float ratio);
std::vector<EXPR *> Grad(EXPR *expr) override;
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
};
class DropOutGradM : public Node {
public:
explicit DropOutGradM(DropOutM *node);
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_DROPOUT_H_

View File

@ -0,0 +1,71 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/expression/ops/flatten.h"
#include <vector>
#include "inner/model_generated.h"
#include "src/expression/import.h"
#include "src/expression/ops.h"
#include "src/cxx_api/expression/node_impl.h"
#include "nnacl/op_base.h"
namespace mindspore {
namespace lite {
FlattenM::FlattenM(int dummy) {
auto param = reinterpret_cast<OpParameter *>(calloc(C1NUM, sizeof(OpParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "Cannot allocate parameter";
return;
}
SetOpParam(param);
set_primitive(schema::PrimitiveType_Flatten);
set_name(UniqueName("Flatten"));
}
std::vector<EXPR *> FlattenM::construct(const std::vector<EXPR *> &inputs) {
auto in = inputs;
auto y = Node::construct(in);
return y;
}
std::vector<EXPR *> FlattenM::Grad(EXPR *yt) {
auto shape_of_x = input(0)->dims();
auto reshape = NN::Reshape(shape_of_x);
PushOp(reshape);
return (*reshape)({yt});
}
int FlattenM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
auto prim = new (std::nothrow) schema::DropoutT;
if (prim == nullptr) {
MS_LOG(ERROR) << "Cannot allocate primitive";
return RET_ERROR;
}
cnode->primitive->value.value = prim;
return RET_OK;
}
static ImportReg reg(schema::PrimitiveType_Flatten, ReturnNode<FlattenM>);
namespace NN {
Node *Flatten() {
auto node = new (std::nothrow) FlattenM(0);
return node;
}
} // namespace NN
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,36 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_FLATTEN_H_
#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_FLATTEN_H_
#include <memory>
#include <vector>
#include "src/expression/node.h"
namespace mindspore {
namespace lite {
class FlattenM : public Node {
public:
FlattenM() = default;
explicit FlattenM(int dummy);
std::vector<EXPR *> Grad(EXPR *expr) override;
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
std::vector<EXPR *> construct(const std::vector<EXPR *> &inputs) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_FLATTEN_H_

View File

@ -0,0 +1,215 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/expression/ops/pooling.h"
#include "src/expression/ops.h"
#include "src/expression/import.h"
#include "src/cxx_api/expression/node_impl.h"
#include "src/expression/ops/transpose.h"
namespace mindspore {
namespace lite {
PoolingM::PoolingM(const PoolingConfig &cfg) : Node() {
auto op_param = calloc(1, sizeof(PoolingParameter));
if (op_param == nullptr) {
MS_LOG(ERROR) << " cannot allocate PoolingParameter";
return;
}
SetOpParam(op_param);
PoolingParameter *pool_param = reinterpret_cast<PoolingParameter *>(OpParam());
pool_param->window_h_ = cfg.kernel_size_[0];
pool_param->window_w_ = cfg.kernel_size_[1];
pool_param->stride_h_ = cfg.stride_[0];
pool_param->stride_w_ = cfg.stride_[1];
auto pad_mode = GetMode(cfg.pad_mode_);
if (pad_mode == -1) {
MS_LOG(ERROR) << "bad pad mode";
return;
}
pool_param->pad_mode_ = static_cast<PadMode>(pad_mode + Pad_pad);
pool_param->round_mode_ = RoundMode_Floor;
pool_param->act_type_ = ActType_No;
}
std::vector<EXPR *> PoolingM::construct(const std::vector<EXPR *> &inputs) {
auto in = inputs;
auto x = in.front();
if (x->format() != NHWC && x->dims().size() == C4NUM) {
x = TransposeM::TransposeCHW2HWC(x);
x->node()->set_name(name() + "/" + x->node()->name());
PushOp(x->node());
in.at(0) = x;
}
auto y = Node::construct(in);
return y;
}
int PoolingM::GetMode(std::string mode) {
const std::vector<std::string> list = {"same", "valid"};
auto itr = std::find(list.begin(), list.end(), mode);
if (itr == list.end()) {
MS_LOG(ERROR) << "illegal mode" << mode;
return -1;
}
return std::distance(list.begin(), itr);
}
void PoolingM::UpdateRoundMode(const PoolingParameter *param, schema::RoundMode *round_mode) {
switch (param->round_mode_) {
case RoundMode_Floor:
*round_mode = schema::RoundMode_FLOOR;
break;
case RoundMode_Ceil:
*round_mode = schema::RoundMode_CEIL;
break;
default:
*round_mode = schema::RoundMode_FLOOR;
break;
}
}
template <typename T>
int PoolingM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
auto param = reinterpret_cast<const PoolingParameter *>(OpParam());
auto prim = new (std::nothrow) T;
if (prim == nullptr) {
MS_LOG(ERROR) << "Cannot allocate primitive";
return RET_ERROR;
}
prim->kernel_size = {param->window_h_, param->window_w_};
prim->strides = {param->stride_h_, param->stride_w_};
prim->pad = {param->pad_u_, param->pad_d_, param->pad_l_, param->pad_r_};
prim->pad_mode = static_cast<schema::PadMode>(param->pad_mode_);
UpdateRoundMode(param, &prim->round_mode);
prim->global = param->global_;
prim->activation_type = schema::ActivationType_NO_ACTIVATION;
cnode->primitive->value.value = prim;
return RET_OK;
}
template <typename T>
int PoolingM::UnPopulateGrad(const std::unique_ptr<schema::CNodeT> &cnode) {
auto param = reinterpret_cast<const PoolingParameter *>(OpParam());
auto prim = new (std::nothrow) T;
if (prim == nullptr) {
MS_LOG(ERROR) << "Cannot allocate primitive";
return RET_ERROR;
}
prim->kernel_size = {param->window_h_, param->window_w_};
prim->strides = {param->stride_h_, param->stride_w_};
prim->pad_mode = static_cast<schema::PadMode>(param->pad_mode_);
cnode->primitive->value.value = prim;
return RET_OK;
}
// Max pooling Definition
MaxPoolM::MaxPoolM(const PoolingConfig &cfg) : PoolingM(cfg) {
auto param = reinterpret_cast<PoolingParameter *>(OpParam());
param->pool_mode_ = PoolMode_MaxPool;
set_primitive(schema::PrimitiveType_MaxPoolFusion);
set_name(UniqueName("MaxPool"));
}
int MaxPoolM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
return PoolingM::UnPopulate<schema::MaxPoolFusionT>(cnode);
}
std::vector<EXPR *> MaxPoolM::Grad(EXPR *yt) {
auto in = yt;
if (yt->format() != NHWC && yt->dims().size() == C4NUM) {
in = TransposeM::TransposeCHW2HWC(yt);
in->node()->set_name(kGradName + "/" + name() + "/" + in->node()->name());
PushOp(in->node());
}
auto pool_grad = new (std::nothrow) MaxPoolGradM(this);
PushOp(pool_grad);
return (*pool_grad)({input(0), output(0), in});
}
static ImportReg maxPoolReg(schema::PrimitiveType_MaxPoolFusion, ReturnNode<MaxPoolM>);
// Avg pooling Definition
AvgPoolM::AvgPoolM(const PoolingConfig &cfg) : PoolingM(cfg) {
auto param = reinterpret_cast<PoolingParameter *>(OpParam());
param->pool_mode_ = PoolMode_AvgPool;
set_primitive(schema::PrimitiveType_AvgPoolFusion);
set_name(UniqueName("AvgPool"));
}
int AvgPoolM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
return PoolingM::UnPopulate<schema::AvgPoolFusionT>(cnode);
}
std::vector<EXPR *> AvgPoolM::Grad(EXPR *yt) {
auto in = yt;
if (yt->format() != NHWC && yt->dims().size() == C4NUM) {
in = TransposeM::TransposeCHW2HWC(yt);
in->node()->set_name(kGradName + "/" + name() + "/" + in->node()->name());
PushOp(in->node());
}
auto pool_grad = new (std::nothrow) AvgPoolGradM(this);
PushOp(pool_grad);
return (*pool_grad)({input(0), output(0), in});
}
static ImportReg avgPoolReg(schema::PrimitiveType_AvgPoolFusion, ReturnNode<AvgPoolM>);
// Max Pool Grad Definition
MaxPoolGradM::MaxPoolGradM(MaxPoolM *node) {
Node();
CloneOpParam<PoolingParameter>(node->OpParam());
set_primitive(schema::PrimitiveType_MaxPoolGrad);
set_name(kGradName + "/" + node->name() + "/MaxPoolGrad");
}
int MaxPoolGradM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
return PoolingM::UnPopulateGrad<schema::MaxPoolGradT>(cnode);
}
// Avg Pool Grad Definition
AvgPoolGradM::AvgPoolGradM(AvgPoolM *node) {
Node();
CloneOpParam<PoolingParameter>(node->OpParam());
set_primitive(schema::PrimitiveType_AvgPoolGrad);
set_name(kGradName + "/" + node->name() + "/AvgPoolGrad");
}
int AvgPoolGradM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
return PoolingM::UnPopulateGrad<schema::AvgPoolGradT>(cnode);
}
namespace NN {
Node *MaxPool2D(const PoolingConfig &cfg) {
auto c = new (std::nothrow) MaxPoolM(cfg);
if (c == nullptr) {
MS_LOG(ERROR) << "Cannot allocate max pool object";
return nullptr;
}
return c;
}
Node *AvgPool2D(const PoolingConfig &cfg) {
auto c = new (std::nothrow) AvgPoolM(cfg);
if (c == nullptr) {
MS_LOG(ERROR) << "Cannot allocate average pool object";
return nullptr;
}
return c;
}
} // namespace NN
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,74 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_POOLING_H_
#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_POOLING_H_
#include <vector>
#include <string>
#include <memory>
#include "src/expression/node.h"
#include "inner/model_generated.h"
#include "src/expression/cfg.h"
#include "nnacl/pooling_parameter.h"
namespace mindspore {
namespace lite {
class PoolingM : public Node {
public:
PoolingM() = default;
explicit PoolingM(const PoolingConfig &cfg);
template <typename T>
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode);
template <typename T>
int UnPopulateGrad(const std::unique_ptr<schema::CNodeT> &cnode);
std::vector<EXPR *> construct(const std::vector<EXPR *> &inputs);
private:
void UpdateRoundMode(const PoolingParameter *param, enum schema::RoundMode *round_mode);
int GetMode(std::string mode);
};
class MaxPoolM : public PoolingM {
public:
MaxPoolM() = default;
explicit MaxPoolM(const PoolingConfig &cfg);
std::vector<EXPR *> Grad(EXPR *expr) override;
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
};
class AvgPoolM : public PoolingM {
public:
AvgPoolM() = default;
explicit AvgPoolM(const PoolingConfig &cfg);
std::vector<EXPR *> Grad(EXPR *expr) override;
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
};
class MaxPoolGradM : public PoolingM {
public:
explicit MaxPoolGradM(MaxPoolM *node);
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
};
class AvgPoolGradM : public PoolingM {
public:
explicit AvgPoolGradM(AvgPoolM *node);
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_POOLING_H_

View File

@ -0,0 +1,126 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/expression/ops/reduce.h"
#include <functional>
#include "src/expression/ops/tile.h"
#include "src/expression/ops/reshape.h"
#include "src/expression/ops/arithmetic.h"
#include "src/expression/ops.h"
#include "src/expression/ops_utils.h"
#include "src/expression/import.h"
#include "nnacl/reduce_parameter.h"
#include "src/cxx_api/expression/node_impl.h"
namespace mindspore {
namespace lite {
ReduceM::ReduceM(schema::ReduceMode mode, bool keep_dims, const std::vector<int> &axis) : Node() {
expr()->SetSize(C2NUM);
ReduceParameter *param = reinterpret_cast<ReduceParameter *>(calloc(1, sizeof(ReduceParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "Cannot allocate parameter";
return;
}
param->mode_ = mode;
param->keep_dims_ = keep_dims;
param->reduce_to_end_ = false;
param->coeff = 1.f;
SetOpParam(param);
set_name(UniqueName("Reduce"));
set_primitive(schema::PrimitiveType_ReduceFusion);
Node::CreateConstTensor(C1NUM, {static_cast<int32_t>(axis.size())}, kNumberTypeInt32, KHWC, "axis", axis.data());
}
int ReduceM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
auto reduce_param = reinterpret_cast<const ReduceParameter *>(OpParam());
auto prim = new (std::nothrow) schema::ReduceFusionT;
if (prim == nullptr) {
MS_LOG(ERROR) << "Cannot allocate primitive";
return RET_ERROR;
}
prim->keep_dims = reduce_param->keep_dims_;
prim->mode = static_cast<schema::ReduceMode>(reduce_param->mode_);
prim->coeff = reduce_param->coeff;
prim->reduce_to_end = reduce_param->reduce_to_end_;
cnode->primitive->value.value = prim;
return RET_OK;
}
std::vector<EXPR *> ReduceM::Grad(EXPR *yt) {
auto shape_of_x = input(0)->dims();
std::vector<int> shape_of_axis;
auto data = input(1)->node()->data()->data().data();
int size = input(1)->dims().at(0);
auto int_data = reinterpret_cast<int *>(data);
for (int i = 0; i < size; i++) {
shape_of_axis.push_back(int_data[i]);
}
// assume no dynamic shape
ShapeReduce reduce_shape;
auto output_shape_kept_dims = ShapeReduce()(shape_of_x, shape_of_axis);
auto tile_scaling = VectorDiv()(shape_of_x, output_shape_kept_dims);
auto reshape = NN::Reshape(output_shape_kept_dims);
PushOp(reshape);
reshape->set_name(name() + "/reshape");
auto g = (*reshape)({yt}).front();
auto tile = NN::Tile(tile_scaling);
PushOp(tile);
tile->set_name(name() + "/tile");
auto sum_grad = (*tile)({g}).front();
auto reduce_param = reinterpret_cast<const ReduceParameter *>(OpParam());
if (reduce_param->mode_ == schema::ReduceMode_ReduceSum) {
return {sum_grad};
} else if (reduce_param->mode_ == schema::ReduceMode_ReduceMean) {
auto shape_of_y = output(0)->dims();
auto shape_x_mul = std::accumulate(shape_of_x.begin(), shape_of_x.end(), 1, std::multiplies<int>());
auto shape_y_mul = std::accumulate(shape_of_y.begin(), shape_of_y.end(), 1, std::multiplies<int>());
auto div_shape = static_cast<float>(shape_x_mul) / static_cast<float>(shape_y_mul);
auto div_op = NN::Div();
PushOp(div_op);
auto d = div_op->CreateConstTensor(C1NUM, {1}, kNumberTypeFloat32, KHWC, "div_shape", &div_shape);
auto dx = (*div_op)({sum_grad, d->expr()});
return dx;
} else {
return {};
}
}
static ImportReg reg(schema::PrimitiveType_ReduceFusion, ReturnNode<ReduceM>);
namespace NN {
Node *ReduceSum(bool keep_dims, const std::vector<int> &axis) {
auto node = new (std::nothrow) ReduceM(schema::ReduceMode_ReduceSum, keep_dims, axis);
if (node == nullptr) {
MS_LOG(ERROR) << "Cannot allocate reduce sum node";
return nullptr;
}
node->set_name(Node::UniqueName("ReduceSum"));
return node;
}
Node *ReduceMean(bool keep_dims, const std::vector<int> &axis) {
auto node = new (std::nothrow) ReduceM(schema::ReduceMode_ReduceMean, keep_dims, axis);
if (node == nullptr) {
MS_LOG(ERROR) << "Cannot allocate reduce mean node";
return nullptr;
}
node->set_name(Node::UniqueName("ReduceMean"));
return node;
}
} // namespace NN
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,42 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_REDUCE_H_
#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_REDUCE_H_
#include <vector>
#include <memory>
#include "src/expression/node.h"
#include "inner/model_generated.h"
namespace mindspore {
namespace lite {
class ReduceM : public Node {
public:
ReduceM() = default;
ReduceM(schema::ReduceMode mode, bool keep_dims, const std::vector<int> &axis);
Param *weight() override { return input(1)->node()->data(); }
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
std::vector<EXPR *> Grad(EXPR *expr) override;
};
namespace NN {
Node *ReduceMean(bool keep_dims, const std::vector<int> &axis);
Node *ReduceSum(bool keep_dims, const std::vector<int> &axis);
} // namespace NN
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_REDUCE_H_

View File

@ -0,0 +1,74 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/expression/ops/reshape.h"
#include "src/expression/ops.h"
#include "nnacl/reshape_parameter.h"
#include "src/expression/import.h"
namespace mindspore {
namespace lite {
ReshapeM::ReshapeM(const std::vector<int> &shape) : Node() {
auto op_param = calloc(1, sizeof(ReshapeParameter));
if (op_param == nullptr) {
MS_LOG(ERROR) << " cannot allocate ReshapeParameter";
return;
}
set_name(UniqueName("Reshape"));
expr()->SetSize(C2NUM);
SetOpParam(op_param);
set_primitive(schema::PrimitiveType_Reshape);
ReshapeParameter *reshape_param = reinterpret_cast<ReshapeParameter *>(opParam_.get());
reshape_param->shape_dim_ = shape.size();
for (int i = 0; i < reshape_param->shape_dim_; i++) {
reshape_param->shape_[i] = shape.at(i);
}
Node::CreateConstTensor(C1NUM, {static_cast<int32_t>(shape.size())}, kNumberTypeInt32, KHWC, "shape", shape.data());
}
int ReshapeM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
auto prim = new (std::nothrow) schema::ReshapeT;
if (prim == nullptr) {
MS_LOG(ERROR) << "Cannot allocate primitive";
return RET_ERROR;
}
cnode->primitive->value.value = prim;
return RET_OK;
}
std::vector<EXPR *> ReshapeM::Grad(EXPR *yt) {
auto shape_of_x = input(0)->dims();
auto reshape = NN::Reshape(shape_of_x);
PushOp(reshape);
return (*reshape)({yt});
}
static ImportReg reg(schema::PrimitiveType_Reshape, ReturnNode<ReshapeM>);
namespace NN {
Node *Reshape(const std::vector<int> &shape) {
auto node = new (std::nothrow) ReshapeM(shape);
if (node == nullptr) {
MS_LOG(ERROR) << "Cannot allocate reshape node";
return nullptr;
}
node->set_name(Node::UniqueName("Reshape"));
return node;
}
} // namespace NN
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,37 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_RESHAPE_H_
#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_RESHAPE_H_
#include <vector>
#include <memory>
#include "src/expression/node.h"
#include "inner/model_generated.h"
namespace mindspore {
namespace lite {
class ReshapeM : public Node {
public:
ReshapeM() = default;
explicit ReshapeM(const std::vector<int> &shape);
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
std::vector<EXPR *> Grad(EXPR *expr) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_RESHAPE_H_

View File

@ -0,0 +1,119 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/expression/ops/softmax.h"
#include "nnacl/softmax_parameter.h"
#include "inner/model_generated.h"
#include "src/expression/import.h"
#include "src/expression/ops/reshape.h"
#include "src/expression/ops/reduce.h"
#include "src/expression/ops/arithmetic.h"
#include "src/expression/ops/transpose.h"
#include "src/cxx_api/expression/node_impl.h"
#include "src/expression/ops.h"
namespace mindspore {
namespace lite {
SoftmaxM::SoftmaxM(int axis) {
auto param = reinterpret_cast<SoftmaxParameter *>(calloc(1, sizeof(SoftmaxParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "Cannot allocate parameter";
return;
}
param->axis_ = axis;
SetOpParam(param);
set_primitive(schema::PrimitiveType_Softmax);
set_name(UniqueName("Softmax"));
}
std::vector<int> SoftmaxM::getTransposeAxis(const std::vector<int> &shape, int axis) {
int rank = shape.size();
if (axis < 0) {
axis += rank;
}
std::vector<int> reverse_axis(rank);
std::iota(reverse_axis.begin(), reverse_axis.end(), 0);
reverse_axis.at(axis) = rank - 1;
reverse_axis.at(rank - 1) = axis;
return reverse_axis;
}
std::vector<EXPR *> SoftmaxM::Grad(EXPR *yt) {
auto x = input(0);
auto out = output(0);
auto shape_of_x = x->dims();
auto param = reinterpret_cast<const SoftmaxParameter *>(OpParam());
auto reverse_axis = getTransposeAxis(shape_of_x, param->axis_);
auto transpose_out = NN::Transpose(reverse_axis);
transpose_out->set_name(kGradName + "/" + name() + "/" + transpose_out->name() + "/out/");
PushOp(transpose_out);
auto y_trn = (*transpose_out)({out}).front();
auto transpose_dout = NN::Transpose(reverse_axis);
transpose_dout->set_name(kGradName + "/" + name() + "/" + transpose_dout->name() + "/dout/");
PushOp(transpose_dout);
auto yt_trn = (*transpose_dout)({yt}).front();
auto mul0 = NN::Mul();
mul0->set_name(kGradName + "/" + name() + "/" + mul0->name() + "0");
PushOp(mul0);
auto tmp0 = (*mul0)({y_trn, yt_trn}).front();
auto sum_func = NN::ReduceSum(true, {-1});
sum_func->set_name(kGradName + "/" + name() + "/" + sum_func->name());
PushOp(sum_func);
auto tmp1 = (*sum_func)({tmp0}).front();
auto sub = NN::Sub();
sub->set_name(kGradName + "/" + name() + "/" + sub->name());
PushOp(sub);
auto tmp2 = (*sub)({yt_trn, tmp1}).front();
auto mul1 = NN::Mul();
mul1->set_name(kGradName + "/" + name() + "/" + mul1->name() + "1");
PushOp(mul1);
auto tmp3 = (*mul1)({y_trn, tmp2});
auto transpose_dx = NN::Transpose(reverse_axis);
transpose_dx->set_name(kGradName + "/" + name() + "/" + transpose_dx->name() + "/dx");
PushOp(transpose_dx);
auto dx = (*transpose_dx)({tmp3});
return dx;
}
int SoftmaxM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
auto param = reinterpret_cast<const SoftmaxParameter *>(OpParam());
auto prim = new (std::nothrow) schema::SoftmaxT;
if (prim == nullptr) {
MS_LOG(ERROR) << "Cannot allocate primitive";
return RET_ERROR;
}
prim->axis.push_back(param->axis_);
cnode->primitive->value.value = prim;
return RET_OK;
}
static ImportReg reg(schema::PrimitiveType_Softmax, ReturnNode<SoftmaxM>);
namespace NN {
Node *Softmax(int axis) {
auto node = new (std::nothrow) SoftmaxM(axis);
return node;
}
} // namespace NN
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,39 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_SOFTMAX_H_
#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_SOFTMAX_H_
#include <vector>
#include <memory>
#include "src/expression/node.h"
namespace mindspore {
namespace lite {
class SoftmaxM : public Node {
public:
SoftmaxM() = default;
explicit SoftmaxM(int axis);
std::vector<EXPR *> Grad(EXPR *expr) override;
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
private:
std::vector<int> getTransposeAxis(const std::vector<int> &shape, int axis);
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_SOFTMAX_H_

View File

@ -0,0 +1,93 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/expression/ops/softmaxCE.h"
#include "include/api/net.h"
#include "src/cxx_api/expression/node_impl.h"
#include "src/expression/ops/reduce.h"
namespace mindspore {
namespace NN {
Node *SoftmaxCrossEntropy(const SoftMaxCrossEntropyCfg &cfg) {
auto lite_node = lite::NN::SoftmaxCrossEntropy(cfg);
return NodeImpl::Connect(lite_node);
}
} // namespace NN
namespace lite {
SoftmaxCrossEntropyM::SoftmaxCrossEntropyM() {
auto param = calloc(1, sizeof(OpParameter));
if (param == nullptr) {
MS_LOG(ERROR) << "Cannot allocate parameter";
return;
}
expr()->SetSize(C2NUM);
SetOpParam(param);
set_name("SoftmaxCrossEntropy");
set_primitive(schema::PrimitiveType_SoftmaxCrossEntropyWithLogits);
EXPR e(this);
e.SetSize(0);
expr_.emplace_back(e);
}
Node *SoftmaxCrossEntropyM::GetReductionNode(const std::string &mode, const std::vector<int> &axis) {
if (mode == "mean") {
return NN::ReduceMean(false, axis);
} else if (mode == "sum") {
return NN::ReduceSum(false, axis);
} else {
return nullptr;
}
}
SoftmaxCrossEntropyM::SoftmaxCrossEntropyM(const SoftMaxCrossEntropyCfg &cfg) : SoftmaxCrossEntropyM() {
std::vector<int> axis = {0};
reduce_ = GetReductionNode(cfg.reduction, axis);
if (reduce_ != nullptr) {
PushOp(reduce_);
}
}
std::vector<EXPR *> SoftmaxCrossEntropyM::construct(const std::vector<EXPR *> &inputs) {
auto y = Node::construct(inputs);
if (reduce_ != nullptr) {
y = (*reduce_)({y.front()});
}
return y;
}
std::vector<EXPR *> SoftmaxCrossEntropyM::Grad(EXPR *expr) { return {this->expr(1)}; }
int SoftmaxCrossEntropyM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
auto prim = new (std::nothrow) schema::SoftmaxCrossEntropyWithLogitsT;
if (prim == nullptr) {
MS_LOG(ERROR) << "Cannot allocate primitive";
return RET_ERROR;
}
cnode->primitive->value.value = prim;
return RET_OK;
}
namespace NN {
Node *SoftmaxCrossEntropy(const SoftMaxCrossEntropyCfg &cfg) {
auto s = new (std::nothrow) SoftmaxCrossEntropyM(cfg);
if (s == nullptr) {
MS_LOG(ERROR) << "Cannot allocate softmax node";
}
return s;
}
} // namespace NN
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,47 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_SOFTMAXCE_H_
#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_SOFTMAXCE_H_
#include <vector>
#include <memory>
#include <string>
#include "src/expression/node.h"
#include "inner/model_generated.h"
#include "include/api/net.h"
namespace mindspore {
namespace lite {
class SoftmaxCrossEntropyM : public Node {
public:
SoftmaxCrossEntropyM();
explicit SoftmaxCrossEntropyM(const SoftMaxCrossEntropyCfg &cfg);
std::vector<EXPR *> Grad(EXPR *expr) override;
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
std::vector<EXPR *> construct(const std::vector<EXPR *> &inputs) override;
private:
Node *GetReductionNode(const std::string &mode, const std::vector<int> &axis);
Node *reduce_ = nullptr;
};
namespace NN {
Node *SoftmaxCrossEntropy(const SoftMaxCrossEntropyCfg &cfg);
} // namespace NN
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_SOFTMAXCE_H_

View File

@ -0,0 +1,62 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/expression/ops/tile.h"
#include <memory>
#include "src/expression/ops.h"
#include "nnacl/base/tile_base.h"
namespace mindspore {
namespace lite {
TileM::TileM(const std::vector<int> &multiples) : Node() {
expr()->SetSize(C2NUM);
TileParameter *param = reinterpret_cast<TileParameter *>(calloc(1, sizeof(TileParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << " cannot allocate ConvParameter";
return;
}
SetOpParam(param);
set_name(UniqueName("Tile"));
set_primitive(schema::PrimitiveType_TileFusion);
Node::CreateConstTensor(C1NUM, {static_cast<int32_t>(multiples.size())}, kNumberTypeInt32, KHWC, "axis",
multiples.data());
}
int TileM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
auto tile_param = reinterpret_cast<const TileParameter *>(OpParam());
auto prim = new (std::nothrow) schema::TileFusionT;
if (prim == nullptr) {
MS_LOG(ERROR) << "Cannot allocate primitive";
return RET_ERROR;
}
for (size_t i = 0; i < tile_param->dims_size_; i++) {
prim->dims.push_back(tile_param->dims_[i]);
}
cnode->primitive->value.value = prim;
return RET_OK;
}
namespace NN {
Node *Tile(const std::vector<int> &multiples) {
auto node = new (std::nothrow) TileM(multiples);
if (node == nullptr) {
MS_LOG(ERROR) << "Cannot allocate tile node";
}
return node;
}
} // namespace NN
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,40 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_TILE_H_
#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_TILE_H_
#include <vector>
#include <memory>
#include "src/expression/node.h"
#include "inner/model_generated.h"
namespace mindspore {
namespace lite {
class TileM : public Node {
public:
TileM() = default;
explicit TileM(const std::vector<int> &multiples);
Param *weight() override { return input(1)->node()->data(); }
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
};
namespace NN {
Node *Tile(const std::vector<int> &multiples);
}
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_TILE_H_

View File

@ -0,0 +1,88 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/expression/ops/transpose.h"
#include <memory>
#include "nnacl/transpose.h"
#include "inner/model_generated.h"
#include "src/expression/import.h"
namespace mindspore {
namespace lite {
TransposeM::TransposeM(const std::vector<int> &vector) {
auto param = calloc(1, sizeof(TransposeParameter));
if (param == nullptr) {
MS_LOG(ERROR) << "Cannot allocate transpose parameter";
return;
}
SetOpParam(param);
expr()->SetSize(C2NUM);
set_primitive(schema::PrimitiveType_Transpose);
std::vector<int> dims = {static_cast<int>(vector.size())};
set_name(UniqueName("Transpose"));
CreateConstTensor(C1NUM, dims, kNumberTypeInt32, KHWC, "axis", vector.data());
}
std::vector<int> TransposeM::Invert(const std::vector<int> &vector) {
std::vector<int> res;
for (size_t i = 0; i < vector.size(); i++) {
int idx = static_cast<int>(i);
auto val = std::find_if(vector.begin(), vector.end(), [idx](int x) { return (x == idx) ? true : false; });
if (val == vector.end()) {
MS_LOG(ERROR) << "Wrong index for " << idx;
return {};
}
res.push_back(std::distance(vector.begin(), val));
}
return res;
}
std::vector<EXPR *> TransposeM::Grad(EXPR *yt) {
auto tensor = input(1)->node();
auto data = tensor->data();
auto vec = data->Extract<int>();
auto invert = Invert(vec);
auto tran = new (std::nothrow) TransposeM(invert);
if (tran == nullptr) {
MS_LOG(ERROR) << "Cannot allocate transpose grad";
return {};
}
tran->set_name(kGradName + "/" + name() + "/" + tran->name());
PushOp(tran);
auto grad = (*tran)({yt});
return grad;
}
int TransposeM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
auto prim = new (std::nothrow) schema::TransposeT;
if (prim == nullptr) {
MS_LOG(ERROR) << "Cannot allocate primitive";
return RET_ERROR;
}
cnode->primitive->value.value = prim;
return RET_OK;
}
static ImportReg reg(schema::PrimitiveType_Transpose, ReturnNode<TransposeM>);
namespace NN {
Node *Transpose(const std::vector<int> &permute) {
auto node = new (std::nothrow) TransposeM(permute);
return node;
}
} // namespace NN
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,59 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_TRANSPOSE_H_
#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_TRANSPOSE_H_
#include <vector>
#include <memory>
#include "src/expression/node.h"
#include "inner/model_generated.h"
namespace mindspore {
namespace lite {
class TransposeM : public Node {
public:
TransposeM() = default;
explicit TransposeM(const std::vector<int> &vector);
static EXPR *TransposeCHW2HWC(EXPR *in) {
std::vector<int> res = {0, 2, 3, 1};
auto trans = new (std::nothrow) TransposeM(res);
if (trans == nullptr) {
return nullptr;
}
return (*trans)({in}).front();
}
static EXPR *TransposeHWC2CHW(EXPR *in) {
std::vector<int> res = {0, 3, 1, 2};
auto trans = new (std::nothrow) TransposeM(res);
if (trans == nullptr) {
return nullptr;
}
return (*trans)({in}).front();
}
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
std::vector<EXPR *> Grad(EXPR *yt) override;
private:
std::vector<int> Invert(const std::vector<int> &vec);
};
namespace NN {
Node *Transpose(const std::vector<int> &permute);
} // namespace NN
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_TRANSPOSE_H_

View File

@ -0,0 +1,275 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/expression/ops_utils.h"
#include <set>
#include <algorithm>
namespace mindspore {
namespace lite {
enum class State {
SAME,
X_ONE,
Y_ONE,
};
bool CompareShape(const std::vector<int> &x_shape, const std::vector<int> &y_shape) {
if (x_shape.size() != y_shape.size()) {
return false;
}
for (size_t i = 0; i < x_shape.size(); ++i) {
if (x_shape.at(i) != y_shape.at(i)) {
return false;
}
}
return true;
}
void ComputeReduceIndex(const std::vector<int> &reverse_x, const std::vector<int> &reverse_y,
std::vector<int> *grad_x_reduce_idx, std::vector<int> *grad_y_reduce_idy) {
MS_ASSERT(grad_x_reduce_idx != nullptr);
MS_ASSERT(grad_y_reduce_idy != nullptr);
const size_t n = reverse_x.size();
if (reverse_y.size() < n) {
MS_LOG_ERROR << "The size of reverse_y is less than the size of reverse_x.";
}
for (size_t i = 0; i < n; ++i) {
State curr = State::SAME;
const int x_i = reverse_x[i];
const int y_i = reverse_y[i];
const int reduce_idx = (n - 1 - i);
if (x_i == y_i) {
curr = State::SAME;
} else if (x_i == 1) {
grad_x_reduce_idx->push_back(reduce_idx);
curr = State::X_ONE;
} else if (y_i == 1) {
grad_y_reduce_idy->push_back(reduce_idx);
curr = State::Y_ONE;
} else {
MS_LOG_ERROR << "not compatible shape input for BroadcastGradientArgs";
}
if (curr == State::SAME && x_i == 1) {
grad_x_reduce_idx->push_back(reduce_idx);
grad_y_reduce_idy->push_back(reduce_idx);
continue;
}
}
std::reverse(grad_x_reduce_idx->begin(), grad_x_reduce_idx->end());
std::reverse(grad_y_reduce_idy->begin(), grad_y_reduce_idy->end());
}
std::vector<std::vector<int>> BroadcastGradientArgs::operator()() {
std::vector<std::vector<int>> input_dim(kInNum);
input_dim[0] = dim0_;
input_dim[1] = dim1_;
auto same_shape = CompareShape(dim0_, dim1_);
if (same_shape) {
return {{}, {}};
}
std::vector<int> reverse_x;
std::vector<int> reverse_y;
(void)std::transform(dim0_.rbegin(), dim0_.rend(), std::back_inserter(reverse_x), [](const int &v) { return v; });
(void)std::transform(dim1_.rbegin(), dim1_.rend(), std::back_inserter(reverse_y), [](const int &v) { return v; });
if (reverse_x.size() > reverse_y.size()) {
reverse_y.resize(reverse_x.size(), 1);
} else {
reverse_x.resize(reverse_y.size(), 1);
}
std::vector<int> grad_x_reduce_idx;
std::vector<int> grad_y_reduce_idy;
ComputeReduceIndex(reverse_x, reverse_y, &grad_x_reduce_idx, &grad_y_reduce_idy);
return {grad_x_reduce_idx, grad_y_reduce_idy};
}
void DynamicBroadcastGradientArgs::AddElementToGradReduceIdx(std::vector<std::vector<int>> *grad_reduce_idx,
std::vector<bool> current_is_one, bool none_is_one,
const size_t largest_rank, size_t j) {
for (size_t i = 0; i < kInNum; ++i) {
if (current_is_one[i] && !none_is_one) {
(void)(*grad_reduce_idx)[i].emplace_back(largest_rank - 1 - j);
}
}
}
void DynamicBroadcastGradientArgs::UpdatePreIsOne(std::vector<bool> *prev_is_one, std::vector<bool> current_is_one) {
for (size_t i = 0; i < kInNum; ++i) {
(*prev_is_one)[i] = current_is_one[i];
}
}
std::vector<std::vector<int>> DynamicBroadcastGradientArgs::GetGradientIndices(
const std::vector<std::vector<int>> &reverse_shape, const size_t largest_rank) {
std::vector<std::vector<int>> grad_reduce_idx(kInNum);
// indices of j-th component of each input.
std::vector<bool> prev_is_one(kInNum);
std::vector<bool> current_is_one(kInNum);
for (size_t i = 0; i < kInNum; ++i) {
prev_is_one[i] = false;
current_is_one[i] = false;
}
bool set_one = false;
for (size_t j = 0; j < largest_rank; ++j) {
int output_dim = -1;
bool output_dim_set = false;
bool none_is_one = true;
// Find which indices are 1.
for (size_t i = 0; i < kInNum; ++i) {
if (reverse_shape[i][j] == 1) {
current_is_one[i] = true;
none_is_one = false;
} else {
current_is_one[i] = false;
if (!output_dim_set || reverse_shape[i][j] == static_cast<int>(output_dim)) {
output_dim = reverse_shape[i][j];
output_dim_set = true;
} else {
std::cout << "Input[0] and input[1] Cannot broadcast!";
}
}
}
// All dimensions are 1.
if (!output_dim_set) {
for (size_t i = 0; i < kInNum; ++i) {
(void)grad_reduce_idx[i].emplace_back(largest_rank - 1 - j);
}
continue;
} else if (std::equal(current_is_one.begin(), current_is_one.end(), prev_is_one.begin()) && set_one) {
AddElementToGradReduceIdx(&grad_reduce_idx, current_is_one, none_is_one, largest_rank, j);
} else {
AddElementToGradReduceIdx(&grad_reduce_idx, current_is_one, none_is_one, largest_rank, j);
}
set_one = true;
UpdatePreIsOne(&prev_is_one, current_is_one);
}
return grad_reduce_idx;
}
std::vector<std::vector<int>> DynamicBroadcastGradientArgs::CalculateOutput(const std::vector<std::vector<int>> &x) {
std::vector<std::vector<int>> grad_reduce_idx(kInNum);
bool all_equal = true;
size_t largest_rank = 0;
for (size_t i = 0; i < kInNum; ++i) {
if (x[i] != x[0]) {
all_equal = false;
}
if (x[i].size() > largest_rank) {
largest_rank = x[i].size();
}
}
if (all_equal) {
return grad_reduce_idx;
}
// Reverse input the shapes
std::vector<std::vector<int>> reverse_shape(kInNum);
for (size_t i = 0; i < kInNum; ++i) {
reverse_shape[i] = x[i];
std::reverse(reverse_shape[i].begin(), reverse_shape[i].end());
}
// 1-extend and align all vectors.
for (size_t i = 0; i < kInNum; ++i) {
if (reverse_shape[i].size() < largest_rank) {
reverse_shape[i].resize(largest_rank, 1);
}
}
grad_reduce_idx = GetGradientIndices(reverse_shape, largest_rank);
return grad_reduce_idx;
}
std::vector<std::vector<int>> DynamicBroadcastGradientArgs::SetOutputValue(
const std::vector<std::vector<int>> &grad_reduce_idx, const std::vector<std::vector<int>> &input_dim) {
std::vector<std::vector<int>> output(kInNum);
for (size_t index = 0; index < kInNum; ++index) {
auto idx_num = grad_reduce_idx[index].size();
for (size_t k = 0; k < idx_num; ++k) {
output[index].push_back(grad_reduce_idx[index][idx_num - 1 - k]);
}
if (idx_num == 0) {
auto input_num = input_dim[index].size();
for (size_t k = 0; k < input_num; ++k) {
output[index].push_back(k);
}
}
}
return output;
}
std::vector<std::vector<int>> DynamicBroadcastGradientArgs::operator()() {
std::vector<std::vector<int>> input_dim(kInNum);
input_dim[0] = dim0_;
input_dim[1] = dim1_;
auto grad_reduce_idx = CalculateOutput(input_dim);
auto output = SetOutputValue(grad_reduce_idx, input_dim);
return output;
}
std::vector<int> VectorDiv::operator()(const std::vector<int> &x, const std::vector<int> &d) {
if (d.size() != x.size()) {
MS_LOG(ERROR) << "x and divider must have same size";
return {};
}
std::vector<int> res;
for (size_t i = 0; i < d.size(); i++) {
auto x_value = x.at(i);
auto d_value = d.at(i);
if (d_value == 0) {
MS_LOG(ERROR) << "Divisor is zero";
return {};
}
if ((x_value % d_value) != 0) {
MS_LOG(ERROR) << "x and d and not dividable";
}
auto r = x_value / d_value;
res.push_back(r);
}
return res;
}
std::vector<int> ShapeReduce::operator()(const std::vector<int> &x_shape, const std::vector<int> &axis) {
int x_rank = x_shape.size();
std::set<int> axis_set;
auto min = -x_rank;
auto max = x_rank - 1;
for (auto &elem : axis) {
if (elem > max || elem < min) {
MS_LOG(ERROR) << "illegal axis value";
return {};
}
axis_set.insert(elem);
}
std::vector<int> res;
for (int i = 0; i < x_rank; i++) {
if (axis_set.count(i) || axis_set.count(i - x_rank)) {
res.push_back(1);
} else {
res.push_back(x_shape.at(i));
}
}
return res;
}
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,69 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_UTILS_H_
#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_UTILS_H_
#include "include/api/cfg.h"
#include "src/expression/net.h"
#include "vector"
namespace mindspore {
namespace lite {
class BroadcastGradientArgs {
public:
BroadcastGradientArgs(const std::vector<int> &dim0, const std::vector<int> &dim1) : dim0_(dim0), dim1_(dim1) {}
std::vector<std::vector<int>> operator()();
private:
static const int kInNum = 2;
const std::vector<int> &dim0_;
const std::vector<int> &dim1_;
};
class DynamicBroadcastGradientArgs {
public:
DynamicBroadcastGradientArgs(const std::vector<int> &dim0, const std::vector<int> &dim1) : dim0_(dim0), dim1_(dim1) {}
std::vector<std::vector<int>> operator()();
private:
void AddElementToGradReduceIdx(std::vector<std::vector<int>> *grad_reduce_idx, std::vector<bool> current_is_one,
bool none_is_one, const size_t largest_rank, size_t j);
void UpdatePreIsOne(std::vector<bool> *prev_is_one, std::vector<bool> current_is_one);
std::vector<std::vector<int>> GetGradientIndices(const std::vector<std::vector<int>> &reverse_shape,
const size_t largest_rank);
std::vector<std::vector<int>> CalculateOutput(const std::vector<std::vector<int>> &x);
std::vector<std::vector<int>> SetOutputValue(const std::vector<std::vector<int>> &grad_reduce_idx,
const std::vector<std::vector<int>> &input_dim);
static const int kInNum = 2;
const std::vector<int> &dim0_;
const std::vector<int> &dim1_;
};
class VectorDiv {
public:
VectorDiv() {}
std::vector<int> operator()(const std::vector<int> &x, const std::vector<int> &d);
};
class ShapeReduce {
public:
ShapeReduce() {}
std::vector<int> operator()(const std::vector<int> &x_shape, const std::vector<int> &axis);
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_UTILS_H_

View File

@ -0,0 +1,70 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/expression/param.h"
#include <random>
#include <algorithm>
#include <string>
#include "include/errorcode.h"
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
constexpr float kZero = 0.0f;
constexpr float kOne = 1.0f;
namespace mindspore {
namespace lite {
int Param::Fill(Mode mode) {
std::default_random_engine engine{static_cast<unsigned int>(0)};
std::vector<float> data(size_);
switch (mode) {
case NORMAL: {
constexpr float scale = 0.01;
std::normal_distribution<float> n{0, 1};
std::generate_n(data.begin(), size_, [&]() { return n(engine); });
(void)std::transform(data.begin(), data.end(), data.begin(), [](float x) { return x * scale; });
break;
}
case UNIFORM: {
constexpr float scale = 0.07;
std::uniform_real_distribution<float> u{-1.0, 1.0};
std::generate_n(data.begin(), size_, [&]() { return u(engine) * scale; });
break;
}
case ZEROS:
std::fill_n(data.begin(), size_, kZero);
break;
case ONES:
std::fill_n(data.begin(), size_, kOne);
break;
case NOT_SUPPORTED:
return RET_ERROR;
}
Copy(data);
return RET_OK;
}
Param::Mode Param::String2Enum(std::string mode) {
(void)std::transform(mode.begin(), mode.end(), mode.begin(), ::tolower);
if (mode == "normal") return NORMAL;
if (mode == "uniform") return UNIFORM;
if (mode == "ones") return ONES;
if (mode == "zeors") return ZEROS;
return NOT_SUPPORTED;
}
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,60 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_PARAM_H_
#define MINDSPORE_LITE_SRC_EXPRESSION_PARAM_H_
#include <vector>
#include <iostream>
#include <fstream>
#include <string>
namespace mindspore {
namespace lite {
class Param {
public:
enum Mode { NORMAL, UNIFORM, ONES, ZEROS, NOT_SUPPORTED };
int Fill(Mode type);
static Mode String2Enum(std::string);
std::vector<uint8_t> &data() { return data_; }
size_t Load(std::string file_name, size_t offset = 0) { return data_.size() * sizeof(float); }
size_t Load(std::ifstream &s, int offset = 0) { return data_.size() * sizeof(float); }
void SetSize(size_t size) { size_ = size; }
template <typename T>
void Copy(const T *data, size_t size) {
auto cast_data = reinterpret_cast<const uint8_t *>(data);
data_ = decltype(data_)(cast_data, cast_data + size * sizeof(T) / sizeof(uint8_t));
}
template <typename T>
void Copy(const std::vector<T> data) {
Copy<T>(data.data(), data.size());
}
template <typename T>
std::vector<T> Extract() {
T *num = reinterpret_cast<T *>(data_.data());
std::vector<T> res(num, num + data_.size() / sizeof(T));
return res;
}
private:
size_t size_;
std::vector<uint8_t> data_;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_EXPRESSION_PARAM_H_

View File

@ -0,0 +1,30 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/expression/sequential.h"
namespace mindspore {
namespace lite {
void Sequential::Add(Node *node) { PushOp(node); }
std::vector<EXPR *> Sequential::construct(const std::vector<EXPR *> &inputs) {
auto x = inputs;
for (auto &node : ops_) {
x = (*node)({x.front()});
}
return x;
}
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,32 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_SEQUENTIAL_H_
#define MINDSPORE_LITE_SRC_EXPRESSION_SEQUENTIAL_H_
#include <vector>
#include "src/expression/net.h"
namespace mindspore {
namespace lite {
class Sequential : public Net {
public:
std::vector<EXPR *> construct(const std::vector<EXPR *> &inputs) override;
void Add(Node *node) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_EXPRESSION_SEQUENTIAL_H_

View File

@ -25,6 +25,7 @@ using mindspore::lite::RET_OK;
namespace mindspore::kernel { namespace mindspore::kernel {
Convolution1x1CPUKernel::~Convolution1x1CPUKernel() { Convolution1x1CPUKernel::~Convolution1x1CPUKernel() {
FreeTmpBuffer(); FreeTmpBuffer();
if (matmul_param_ != nullptr) { if (matmul_param_ != nullptr) {
delete matmul_param_; delete matmul_param_;
matmul_param_ = nullptr; matmul_param_ = nullptr;
@ -185,7 +186,7 @@ int Convolution1x1CPUKernel::DoConv1x1Hw(int task_id) {
float *thread_input_ptr = input_ptr_ + task_id * thread_stride_ * matmul_param_->deep_; float *thread_input_ptr = input_ptr_ + task_id * thread_stride_ * matmul_param_->deep_;
float *thread_pack_input = pack_input_ + task_id * row_tile_ * matmul_param_->deep_; float *thread_pack_input = pack_input_ + task_id * row_tile_ * matmul_param_->deep_;
float *thread_output_ptr; float *thread_output_ptr = nullptr;
if (out_tensors()[0]->format() != NC4HW4) { if (out_tensors()[0]->format() != NC4HW4) {
thread_output_ptr = output_ptr_ + task_id * thread_stride_ * matmul_param_->col_; thread_output_ptr = output_ptr_ + task_id * thread_stride_ * matmul_param_->col_;
} else { } else {

View File

@ -127,10 +127,12 @@ int ConvolutionDepthwiseCPUKernel::MallocWeightBiasData() {
} }
} }
CHECK_LESS_RETURN(MAX_MALLOC_SIZE, channel * sizeof(float)); CHECK_LESS_RETURN(MAX_MALLOC_SIZE, channel * sizeof(float));
bias_data_ = malloc(channel * sizeof(float));
if (bias_data_ == nullptr) { if (bias_data_ == nullptr) {
MS_LOG(ERROR) << "Malloc buffer failed."; bias_data_ = malloc(channel * sizeof(float));
return RET_ERROR; if (bias_data_ == nullptr) {
MS_LOG(ERROR) << "Malloc buffer failed.";
return RET_ERROR;
}
} }
memset(bias_data_, 0, channel * sizeof(float)); memset(bias_data_, 0, channel * sizeof(float));
return RET_OK; return RET_OK;

View File

@ -1,2 +1,2 @@
Note: This is the mindspore Lite inference framework size threshold. Offline review is required before modify this value!!! Note: This is the mindspore Lite inference framework size threshold. Offline review is required before modify this value!!!
1034600 1040696

View File

@ -48,5 +48,7 @@ vae
unified_api code_example unified_api code_example
train_lenet code_example train_lenet code_example
train_lenet_java code_example train_lenet_java code_example
lenet expression
mobilenetv2 expression noarm32
# LAST # LAST
#test_resize inputShapes 16,10,10,1:16,10,10,1 0.5 #test_resize inputShapes 16,10,10,1:16,10,10,1 0.5

View File

@ -1,2 +1,2 @@
Note: This is the mindspore Lite inference framework size threshold. Modifying this threshold requires meeting review. Note: This is the mindspore Lite inference framework size threshold. Modifying this threshold requires meeting review.
1034600 1040696

View File

@ -1,6 +1,6 @@
#!/bin/bash #!/bin/bash
source ./scripts/base_functions.sh source ./scripts/base_functions.sh
version=1.5.0 version=1.6.1
# Run Export on x86 platform and create output test files: # Run Export on x86 platform and create output test files:
docker_image=mindspore_build:210301 docker_image=mindspore_build:210301
@ -69,6 +69,9 @@ function Run_Converter() {
model_prefix="${line_array[0]}_head" model_prefix="${line_array[0]}_head"
model_name=${line_array[0]}'_head' model_name=${line_array[0]}'_head'
fi fi
if [[ "${expression}" == "1" ]]; then
model_prefix="${model_name}_fwd"
fi
if [[ ${check_convert} == "1" ]]; then if [[ ${check_convert} == "1" ]]; then
ms_file=$ms_models_path'/'$model_name'.ms' ms_file=$ms_models_path'/'$model_name'.ms'
if [ -f "$ms_file" ]; then if [ -f "$ms_file" ]; then
@ -80,7 +83,7 @@ function Run_Converter() {
{ {
echo ${model_name} >> "${run_converter_log_file}" echo ${model_name} >> "${run_converter_log_file}"
echo './converter_lite --fmk=MINDIR --modelFile='${models_path}'/'${model_prefix}'.mindir --outputFile='${ms_models_path}'/'${model_name}' --trainModel=true' ${WEIGHT_QUANT} >> "${run_converter_log_file}" echo './converter_lite --fmk=MINDIR --modelFile='${models_path}'/'${model_prefix}'.mindir --outputFile='${ms_models_path}'/'${model_name}' --trainModel=true' ${WEIGHT_QUANT} >> "${run_converter_log_file}"
./converter_lite --fmk=MINDIR --modelFile=${models_path}/${model_prefix}.mindir --outputFile=${ms_models_path}/${model_name} --trainModel=true ${WEIGHT_QUANT} ./converter_lite --fmk=MINDIR --modelFile=${models_path}/${model_prefix}.mindir --outputFile=${ms_models_path}/${model_name} --trainModel=true ${no_opt} ${WEIGHT_QUANT}
if [ $? = 0 ]; then if [ $? = 0 ]; then
converter_result='converter mindspore '${model_name}' pass';echo ${converter_result} >> ${run_converter_result_file} converter_result='converter mindspore '${model_name}' pass';echo ${converter_result} >> ${run_converter_result_file}
else else
@ -118,9 +121,12 @@ function should_run_example() {
continue continue
fi fi
if [[ $model_name == "$1" ]]; then if [[ $model_name == "$1" ]]; then
if [[ ${line_array[1]} == "code_example" ]]; then if [[ "${line_array[1]}" == "code_example" ]]; then
ret=1 ret=1
fi fi
if [[ "${line_array[2]}" == "expression" ]]; then
ret=2
fi
fi fi
done < ${models_ms_train_config} done < ${models_ms_train_config}
return $ret return $ret
@ -146,6 +152,9 @@ function parse_line() {
enable_transfer=0 enable_transfer=0
suffix_print="" suffix_print=""
check_convert=0 check_convert=0
do_api=false
no_opt="--NoFusion=false"
expression=0
model_name=${line_array[0]}_train model_name=${line_array[0]}_train
while [[ $i < ${#line_array[@]} ]] ; do while [[ $i < ${#line_array[@]} ]] ; do
case ${line_array[i]} in case ${line_array[i]} in
@ -189,6 +198,12 @@ function parse_line() {
i=$(($i+1)) i=$(($i+1))
inputShapes=${line_array[i]} inputShapes=${line_array[i]}
;; ;;
"expression")
model_name="${line_array[0]}_expr"
do_api=true
no_opt="--NoFusion=true"
expression=1
;;
*) *)
check=`echo "${line_array[i]}" | grep -E '^\-?[0-9]*\.?[0-9]+$'` check=`echo "${line_array[i]}" | grep -E '^\-?[0-9]*\.?[0-9]+$'`
if [ "${check}" != "" ] ; then if [ "${check}" != "" ] ; then
@ -214,12 +229,14 @@ function Run_x86() {
continue continue
fi fi
local model_prefix=${line_array[0]} local model_prefix=${line_array[0]}
local bb_model_file=""
local log_suffix="_train" local log_suffix="_train"
parse_line x86 parse_line x86
if [[ "$?" == "1" ]]; then continue; fi if [[ "$?" == "1" ]]; then continue; fi
if [[ "${expression}" == "1" ]]; then
model_prefix=${model_prefix}_expr
fi
local model_file="${ms_models_path}/${model_name}.ms" local model_file="${ms_models_path}/${model_name}.ms"
local bb_model_file=""
local export_file="${ms_models_path}/${model_name}_tod"
local inference_file="${ms_models_path}/${model_name}_infer" local inference_file="${ms_models_path}/${model_name}_infer"
if [[ "${enable_transfer}" == "1" ]]; then if [[ "${enable_transfer}" == "1" ]]; then
model_file="${ms_models_path}/${model_prefix}_head.ms" model_file="${ms_models_path}/${model_prefix}_head.ms"
@ -247,7 +264,8 @@ function Run_x86() {
--exportFile=${export_file} \ --exportFile=${export_file} \
--virtualBatch=${virtual_batch} \ --virtualBatch=${virtual_batch} \
--inputShapes=${inputShapes} \ --inputShapes=${inputShapes} \
--lossName=${loss_name} " >> ${run_x86_log_file} --lossName=${loss_name} \
--unifiedApi=${do_api}" >> ${run_x86_log_file}
${run_valgrind} ./tools/benchmark_train/benchmark_train \ ${run_valgrind} ./tools/benchmark_train/benchmark_train \
--modelFile=${model_file} \ --modelFile=${model_file} \
--bbModelFile=${bb_model_file} \ --bbModelFile=${bb_model_file} \
@ -257,7 +275,8 @@ function Run_x86() {
--exportFile=${export_file} \ --exportFile=${export_file} \
--virtualBatch=${virtual_batch} \ --virtualBatch=${virtual_batch} \
--inputShapes=${inputShapes} \ --inputShapes=${inputShapes} \
--lossName=${loss_name} >> "${run_x86_log_file}" --lossName=${loss_name} \
--unifiedApi=${do_api} >> "${run_x86_log_file}"
if [ $? = 0 ]; then if [ $? = 0 ]; then
run_result='x86'${log_suffix}': '${model_name}''${suffix_print}' pass'; echo ${run_result} >> ${run_benchmark_train_result_file} run_result='x86'${log_suffix}': '${model_name}''${suffix_print}' pass'; echo ${run_result} >> ${run_benchmark_train_result_file}
else else
@ -327,24 +346,27 @@ function Run_arm() {
echo "cd ${tmp_dir}" > ${adb_cmd_file} echo "cd ${tmp_dir}" > ${adb_cmd_file}
echo 'chmod 777 benchmark_train' >> ${adb_cmd_file} echo 'chmod 777 benchmark_train' >> ${adb_cmd_file}
adb -s ${device_id} shell < ${adb_cmd_file} adb -s ${device_id} shell < ${adb_cmd_file}
local fail=0 local fail=0
# Run mindir converted train models: # Run mindir converted train models:
while read line; do while read line; do
local line_array local line_array
local bb_model_file=""
LFS=" " read -r -a line_array <<< ${line} LFS=" " read -r -a line_array <<< ${line}
if [[ ${line_array[0]} == \#* || ${line_array[0]} == "" ]]; then if [[ ${line_array[0]} == \#* || ${line_array[0]} == "" ]]; then
continue continue
fi fi
parse_line $1
if [[ "$?" == "1" ]]; then continue; fi
local model_prefix=${line_array[0]} local model_prefix=${line_array[0]}
if [[ ${expression} == "1" ]]; then
model_prefix=${model_prefix}_expr
fi
local run_result="" local run_result=""
local log_suffix="_train" local log_suffix="_train"
parse_line $1
if [[ "$?" == "1" ]]; then continue; fi if [[ "$?" == "1" ]]; then continue; fi
local export_file="${tmp_dir}/${model_name}_tod" local export_file="${tmp_dir}/${model_name}_tod"
local inference_file="${tmp_dir}/${model_name}_infer" local inference_file="${tmp_dir}/${model_name}_infer"
local model_file="${model_name}.ms" local model_file="${model_name}.ms"
local bb_model_file=""
if [[ "${enable_transfer}" == "1" ]]; then if [[ "${enable_transfer}" == "1" ]]; then
model_file="${model_prefix}_head.ms" model_file="${model_prefix}_head.ms"
bb_model_file="${model_prefix}_bb.ms" bb_model_file="${model_prefix}_bb.ms"
@ -388,7 +410,8 @@ function Run_arm() {
--exportFile=${export_file} \ --exportFile=${export_file} \
--virtualBatch=${virtual_batch} \ --virtualBatch=${virtual_batch} \
--inputShapes=${inputShapes} \ --inputShapes=${inputShapes} \
--lossName=${loss_name} --lossName=${loss_name} \
--unifiedApi=${do_api}
ENDM ENDM
) )
echo "${adb_cmd}" >> ${run_arm_log_file} echo "${adb_cmd}" >> ${run_arm_log_file}
@ -443,11 +466,16 @@ function Run_CodeExamples() {
should_run_example "unified_api" should_run_example "unified_api"
should_run=$? should_run=$?
local expression_flag=""
if [[ "$should_run" == "2" ]]; then
expression_flag="-x"
should_run=1
fi
if [[ "$should_run" == "1" ]]; then if [[ "$should_run" == "1" ]]; then
cd ${basepath}/../../examples/unified_api || exit 1 cd ${basepath}/../../examples/unified_api || exit 1
chmod 777 ./prepare_and_run.sh chmod 777 ./prepare_and_run.sh
chmod 777 ./*/*.sh chmod 777 ./*/*.sh
./prepare_and_run.sh -D ${datasets_path}/mnist -r ${tarball_path} -t ${target} -m ${models_path}/code_example.mindir -e 1 >> ${run_code_examples_log_file} ./prepare_and_run.sh -D ${datasets_path}/mnist -r ${tarball_path} -t ${target} -m ${models_path}/code_example.mindir -e 1 ${expression_flag} >> ${run_code_examples_log_file}
if [ "$?" != "0" ]; then if [ "$?" != "0" ]; then
echo "Unified API prepare_and_run.sh failed" echo "Unified API prepare_and_run.sh failed"
exit 1 exit 1
@ -492,6 +520,10 @@ while getopts "r:c:m:d:i:e:vt:q:D:M:l:" opt; do
models_path=${OPTARG} models_path=${OPTARG}
echo "models_path is ${OPTARG}" echo "models_path is ${OPTARG}"
;; ;;
c)
models_ms_train_config=${OPTARG}
echo "models_ms_train_config ${models_ms_train_config}"
;;
i) i)
train_io_path=${OPTARG} train_io_path=${OPTARG}
echo "train_io_path is ${OPTARG}" echo "train_io_path is ${OPTARG}"
@ -539,7 +571,9 @@ config_folder="config_level0"
if [[ ${level} == "level1" ]];then if [[ ${level} == "level1" ]];then
config_folder="config_level1" config_folder="config_level1"
fi fi
models_ms_train_config=${basepath}/../${config_folder}/models_ms_train.cfg if [[ "${models_ms_train_config}" == "" ]]; then
models_ms_train_config=${basepath}/../${config_folder}/models_ms_train.cfg
fi
if [[ $train_io_path == "" ]]; then if [[ $train_io_path == "" ]]; then
echo "train_io path is empty" echo "train_io path is empty"

View File

@ -5,10 +5,23 @@ set(COMMON_SRC
${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/utils.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/utils.cc
) )
set(TEST_SRC
${CMAKE_CURRENT_SOURCE_DIR}/main.cc
${CMAKE_CURRENT_SOURCE_DIR}/net_train.cc
)
if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "full")
set(TEST_SRC
${TEST_SRC}
${CMAKE_CURRENT_SOURCE_DIR}/net_runner.cc
)
endif()
add_executable(benchmark_train add_executable(benchmark_train
${CMAKE_CURRENT_SOURCE_DIR}/main.cc ${TEST_SRC}
${CMAKE_CURRENT_SOURCE_DIR}/net_train.cc
${COMMON_SRC}) ${COMMON_SRC})
if(WIN32) if(WIN32)
add_dependencies(benchmark_train fbs_src mindspore-lite_static mindspore-lite-train_static) add_dependencies(benchmark_train fbs_src mindspore-lite_static mindspore-lite-train_static)
else() else()

View File

@ -0,0 +1,371 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/benchmark_train/net_runner.h"
#include "tools/benchmark_train/net_train.h"
#include <getopt.h>
#include <malloc.h>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <fstream>
#include <utility>
#include <chrono>
#include "include/api/types.h"
#include "include/api/context.h"
#include "include/api/serialization.h"
#include "include/api/callback/loss_monitor.h"
#include "include/api/metrics/accuracy.h"
#include "include/api/callback/ckpt_saver.h"
#include "include/api/callback/train_accuracy.h"
#include "include/api/callback/lr_scheduler.h"
#include "include/dataset/datasets.h"
#include "include/dataset/vision_lite.h"
#include "include/dataset/transforms.h"
#include "include/api/cfg.h"
#include "include/api/net.h"
using mindspore::AccuracyMetrics;
using mindspore::Model;
using mindspore::TrainAccuracy;
using mindspore::TrainCallBack;
using mindspore::TrainCallBackData;
using mindspore::dataset::Dataset;
using mindspore::dataset::Mnist;
using mindspore::dataset::SequentialSampler;
using mindspore::dataset::TensorOperation;
using mindspore::dataset::transforms::TypeCast;
using mindspore::dataset::vision::Normalize;
using mindspore::dataset::vision::Resize;
constexpr int kNCHWCDim = 2;
constexpr int kPrintTimes = 100;
constexpr float kBetta1 = 0.9f;
constexpr float kBetta2 = 0.999f;
class Rescaler : public mindspore::TrainCallBack {
public:
explicit Rescaler(float scale) : scale_(scale) {
if (scale_ == 0) {
scale_ = 1.0;
}
}
~Rescaler() override = default;
void StepBegin(const mindspore::TrainCallBackData &cb_data) override {
auto inputs = cb_data.model_->GetInputs();
auto *input_data = reinterpret_cast<float *>(inputs.at(0).MutableData());
for (int k = 0; k < inputs.at(0).ElementNum(); k++) input_data[k] /= scale_;
}
private:
float scale_ = 1.0;
};
/* This is an example of a user defined Callback to measure memory and latency of execution */
class Measurement : public mindspore::TrainCallBack {
public:
explicit Measurement(unsigned int epochs)
: time_avg_(std::chrono::duration<double, std::milli>(0)), epochs_(epochs) {}
~Measurement() override = default;
void EpochBegin(const mindspore::TrainCallBackData &cb_data) override {
start_time_ = std::chrono::high_resolution_clock::now();
}
mindspore::CallbackRetValue EpochEnd(const mindspore::TrainCallBackData &cb_data) override {
end_time_ = std::chrono::high_resolution_clock::now();
auto time = std::chrono::duration<double, std::milli>(end_time_ - start_time_);
time_avg_ += time;
return mindspore::kContinue;
}
void End(const mindspore::TrainCallBackData &cb_data) override {
if (epochs_ > 0) {
std::cout << "AvgRunTime: " << time_avg_.count() / epochs_ << " ms" << std::endl;
}
struct mallinfo info = mallinfo();
std::cout << "Total allocation: " << info.arena + info.hblkhd << std::endl;
}
private:
std::chrono::time_point<std::chrono::high_resolution_clock> start_time_;
std::chrono::time_point<std::chrono::high_resolution_clock> end_time_;
std::chrono::duration<double, std::milli> time_avg_;
unsigned int epochs_;
};
NetRunner::~NetRunner() {
if (model_ != nullptr) {
delete model_;
}
if (graph_ != nullptr) {
delete graph_;
}
}
mindspore::Status NetRunner::InitAndFigureInputs() {
auto context = std::make_shared<mindspore::Context>();
auto cpu_context = std::make_shared<mindspore::CPUDeviceInfo>();
cpu_context->SetEnableFP16(enable_fp16_);
context->MutableDeviceInfo().push_back(cpu_context);
graph_ = new (std::nothrow) mindspore::Graph(mindspore::Graph::Type::kExpressionGraph);
if (graph_ == nullptr) {
std::cout << "Cannot allocate graph" << std::endl;
return mindspore::kLiteMemoryFailed;
}
auto status = mindspore::Serialization::Load(ms_file_, mindspore::kMindIR, graph_);
if (status != mindspore::kSuccess) {
std::cout << "Error " << status << " during serialization of graph " << ms_file_;
return status;
}
auto net = std::make_unique<mindspore::Net>(*graph_);
auto input_shape = net->InputShape(0);
auto label_shape = net->OutputShape(0);
auto inputM = mindspore::NN::Input(input_shape);
auto labelM = mindspore::NN::Input(label_shape);
auto label = labelM->Create("label");
auto input = inputM->Create("input");
auto cfg = std::make_shared<mindspore::TrainCfg>();
if (enable_fp16_) {
cfg.get()->optimization_level_ = mindspore::kO2;
}
model_ = new (std::nothrow) mindspore::Model();
if (model_ == nullptr) {
std::cout << "model allocation failed" << std::endl;
return mindspore::kLiteMemoryFailed;
}
mindspore::SoftMaxCrossEntropyCfg softmax_ce_cfg;
softmax_ce_cfg.reduction = "none";
auto netWithLoss = mindspore::NN::GraphWithLoss(graph_, mindspore::NN::SoftmaxCrossEntropy(softmax_ce_cfg));
mindspore::AdamConfig AdamCfg;
AdamCfg.beta1_ = kBetta1;
AdamCfg.beta2_ = kBetta2;
AdamCfg.eps_ = 1e-8;
AdamCfg.learning_rate_ = 1e-2;
auto optimizer = mindspore::NN::Adam(net->trainable_params(), AdamCfg);
status = model_->Build(mindspore::GraphCell(*netWithLoss), optimizer, {input, label}, context, cfg);
if (status != mindspore::kSuccess) {
std::cout << "Error " << status << " during build of model " << ms_file_ << std::endl;
return status;
}
delete graph_;
graph_ = nullptr;
auto inputs = model_->GetInputs();
if (inputs.size() < 1) {
return mindspore::kLiteError;
}
auto nhwc_input_dims = inputs.at(0).Shape();
batch_size_ = nhwc_input_dims.at(0);
h_ = nhwc_input_dims.at(1);
w_ = nhwc_input_dims.at(kNCHWCDim);
return mindspore::kSuccess;
}
int NetRunner::CompareOutput(const std::vector<mindspore::MSTensor> &outputs) {
std::cout << "================ Comparing Forward Output data ================" << std::endl;
float total_bias = 0;
int total_size = 0;
bool has_error = false;
int i = 1;
for (auto &tensor : outputs) {
std::cout << "output is tensor " << tensor.Name() << "\n";
auto output = tensor.Data();
size_t size;
std::string output_file = flags_->data_file_ + std::to_string(i) + ".bin";
auto bin_buf = std::unique_ptr<float[]>(mindspore::lite::NetTrain::ReadFileBuf(output_file.c_str(), &size));
if (bin_buf == nullptr) {
MS_LOG(ERROR) << "ReadFile return nullptr";
std::cout << "ReadFile return nullptr" << std::endl;
return mindspore::kLiteNullptr;
}
if (size != tensor.DataSize()) {
MS_LOG(ERROR) << "Output buffer and output file differ by size. Tensor size: " << tensor.DataSize()
<< ", read size: " << size;
std::cout << "Output buffer and output file differ by size. Tensor size: " << tensor.DataSize()
<< ", read size: " << size << std::endl;
return mindspore::kLiteError;
}
float bias = mindspore::lite::NetTrain::CompareData<float>(bin_buf.get(), tensor.ElementNum(),
reinterpret_cast<const float *>(output.get()));
if (bias >= 0) {
total_bias += bias;
total_size++;
} else {
has_error = true;
break;
}
i++;
}
if (!has_error) {
float mean_bias;
if (total_size != 0) {
mean_bias = total_bias / total_size * kPrintTimes;
} else {
mean_bias = 0;
}
std::cout << "Mean bias of all nodes/tensors: " << mean_bias << "%"
<< " threshold is:" << this->flags_->accuracy_threshold_ << std::endl;
std::cout << "=======================================================" << std::endl << std::endl;
if (mean_bias > this->flags_->accuracy_threshold_) {
MS_LOG(INFO) << "Mean bias of all nodes/tensors is too big: " << mean_bias << "%";
std::cout << "Mean bias of all nodes/tensors is too big: " << mean_bias << "%" << std::endl;
return mindspore::kLiteError;
} else {
return mindspore::kSuccess;
}
} else {
MS_LOG(ERROR) << "Error in CompareData";
std::cout << "Error in CompareData" << std::endl;
std::cout << "=======================================================" << std::endl << std::endl;
return mindspore::kSuccess;
}
}
void NetRunner::CheckSum(const mindspore::MSTensor &tensor, std::string node_type, int id, std::string in_out) {
constexpr int kPrintLen = 4;
int tensor_size = tensor.ElementNum();
const void *data = tensor.Data().get();
const float *fdata = reinterpret_cast<const float *>(data);
mindspore::DataType type = tensor.DataType();
std::cout << node_type << " " << in_out << id << std::endl;
std::cout << "tensor name: " << tensor.Name() << std::endl;
if ((tensor_size) == 0 || (data == nullptr)) {
std::cout << "Empty tensor" << std::endl;
return;
}
switch (type) {
case mindspore::DataType::kNumberTypeFloat32:
std::cout << "sum=" << mindspore::lite::TensorSum<float>(data, tensor_size) << std::endl;
std::cout << "data: ";
for (int i = 0; i <= kPrintLen && i < tensor_size; i++) {
std::cout << static_cast<float>(fdata[i]) << ", ";
}
std::cout << std::endl;
break;
case mindspore::DataType::kNumberTypeInt32:
std::cout << "sum=" << mindspore::lite::TensorSum<int>(data, tensor_size) << std::endl;
break;
default:
std::cout << "unsupported type:" << static_cast<int>(type) << std::endl;
break;
}
}
int NetRunner::InitCallbackParameter() {
// after callback
after_call_back_ = [&](const std::vector<mindspore::MSTensor> &after_inputs,
const std::vector<mindspore::MSTensor> &after_outputs,
const mindspore::MSCallBackParam &call_param) {
if (after_inputs.empty()) {
MS_LOG(INFO) << "The num of after inputs is empty";
}
if (after_outputs.empty()) {
MS_LOG(INFO) << "The num of after outputs is empty";
}
if (flags_->layer_checksum_) {
for (size_t i = 0; i < after_inputs.size(); i++) {
CheckSum(after_inputs.at(i), call_param.node_type, i, "in");
}
for (size_t i = 0; i < after_outputs.size(); i++) {
CheckSum(after_outputs.at(i), call_param.node_type, i, "out");
}
std::cout << std::endl;
}
return true;
};
return false;
}
int NetRunner::RunOnce() {
auto inputs = model_->GetInputs();
std::vector<mindspore::MSTensor> output;
auto status = LoadInput(&inputs);
if (status != mindspore::kSuccess) {
std::cout << "cannot load data";
return status;
}
model_->SetTrainMode(true);
model_->RunStep(nullptr, nullptr);
model_->SetTrainMode(false);
model_->Predict(inputs, &output, nullptr, nullptr);
return CompareOutput(output);
}
int NetRunner::LoadInput(std::vector<mindspore::MSTensor> *ms_inputs) {
auto status = ReadInputFile(ms_inputs);
if (status != mindspore::kSuccess) {
std::cout << "Read Input File error, " << status << std::endl;
MS_LOG(ERROR) << "Read Input File error, " << status;
return status;
}
return mindspore::kSuccess;
}
int NetRunner::ReadInputFile(std::vector<mindspore::MSTensor> *ms_inputs) {
if (ms_inputs->empty()) {
std::cout << "no inputs to input" << std::endl;
return mindspore::kLiteError;
}
for (size_t i = 0; i < ms_inputs->size(); i++) {
auto cur_tensor = ms_inputs->at(i);
if (cur_tensor == nullptr) {
std::cout << "empty tensor " << i << std::endl;
MS_LOG(ERROR) << "empty tensor " << i;
}
size_t size;
std::string file_name = flags_->in_data_file_ + std::to_string(i + 1) + ".bin";
auto bin_buf = std::unique_ptr<float[]>(mindspore::lite::NetTrain::ReadFileBuf(file_name.c_str(), &size));
if (bin_buf == nullptr) {
MS_LOG(ERROR) << "ReadFile return nullptr";
std::cout << "ReadFile return nullptr" << std::endl;
return mindspore::kLiteNullptr;
}
auto tensor_data_size = cur_tensor.DataSize();
if (size != tensor_data_size) {
std::cout << "Input binary file size error, required: " << tensor_data_size << ", in fact: " << size
<< " ,file_name: " << file_name.c_str() << std::endl;
MS_LOG(ERROR) << "Input binary file size error, required: " << tensor_data_size << ", in fact: " << size
<< " ,file_name: " << file_name.c_str();
return mindspore::kLiteError;
}
auto input_data = cur_tensor.MutableData();
memcpy(input_data, bin_buf.get(), tensor_data_size);
}
return mindspore::kSuccess;
}
int NetRunner::Main() {
ms_file_ = flags_->model_file_;
InitCallbackParameter();
auto status = InitAndFigureInputs();
if (status != mindspore::kSuccess) {
std::cout << "failed to initialize network" << std::endl;
return status.StatusCode();
}
return RunOnce();
}
int CallBack(mindspore::lite::NetTrainFlags *flags) {
NetRunner nr(flags);
return nr.Main();
}
int init = mindspore::lite::NetTrain::SetNr(CallBack);

View File

@ -0,0 +1,81 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_BENCHMARK_TRAIN_NET_RUNNER_H_
#define MINDSPORE_LITE_TOOLS_BENCHMARK_TRAIN_NET_RUNNER_H_
#include <tuple>
#include <iomanip>
#include <map>
#include <vector>
#include <memory>
#include <string>
#include "include/api/model.h"
#include "include/api/graph.h"
#include "include/api/types.h"
#include "include/api/status.h"
#include "include/api/metrics/accuracy.h"
#include "include/dataset/datasets.h"
using mindspore::AccuracyMetrics;
using mindspore::dataset::Dataset;
namespace mindspore::lite {
class NetTrainFlags;
}
class NetRunner {
public:
int Main();
explicit NetRunner(mindspore::lite::NetTrainFlags *flags) : flags_(flags) {}
bool ReadArgs(int argc, int8_t *argv[]);
virtual ~NetRunner();
private:
void Usage();
mindspore::Status InitAndFigureInputs();
void CheckSum(const mindspore::MSTensor &tensor, std::string node_type, int id, std::string in_out);
int InitCallbackParameter();
int TrainLoop();
float CalculateAccuracy(int max_tests = 0);
float GetLoss() const;
int RunOnce();
int CompareOutput(const std::vector<mindspore::MSTensor> &outputs);
int LoadInput(std::vector<mindspore::MSTensor> *ms_inputs);
int ReadInputFile(std::vector<mindspore::MSTensor> *ms_inputs);
mindspore::Model *model_ = nullptr;
mindspore::Graph *graph_ = nullptr;
std::shared_ptr<Dataset> train_ds_;
std::shared_ptr<Dataset> test_ds_;
std::shared_ptr<AccuracyMetrics> acc_metrics_;
std::string ms_file_ = "";
std::string data_dir_ = "";
unsigned int epochs_ = 10;
bool verbose_ = false;
bool enable_fp16_ = false;
int virtual_batch_ = -1;
int save_checkpoint_ = 0;
int batch_size_ = 32;
int h_ = 32;
int w_ = 32;
mindspore::lite::NetTrainFlags *flags_{nullptr};
mindspore::MSKernelCallBack after_call_back_;
};
#endif // MINDSPORE_LITE_TOOLS_BENCHMARK_TRAIN_NET_RUNNER_H_

View File

@ -24,6 +24,7 @@
#ifdef ENABLE_NEON #ifdef ENABLE_NEON
#include <arm_neon.h> #include <arm_neon.h>
#endif #endif
#include "tools/benchmark_train/net_runner.h"
#include "src/common/common.h" #include "src/common/common.h"
#include "include/ms_tensor.h" #include "include/ms_tensor.h"
#include "include/context.h" #include "include/context.h"
@ -49,14 +50,20 @@ constexpr int kCPUBindFlag2 = 2;
constexpr int kCPUBindFlag1 = 1; constexpr int kCPUBindFlag1 = 1;
static const int kTHOUSAND = 1000; static const int kTHOUSAND = 1000;
namespace { std::function<int(NetTrainFlags *)> NetTrain::nr_cb_ = nullptr;
float *ReadFileBuf(const char *file, size_t *size) {
if (file == nullptr) { int NetTrain::SetNr(std::function<int(NetTrainFlags *)> param) {
nr_cb_ = param;
return 0;
}
float *NetTrain::ReadFileBuf(const std::string file, size_t *size) {
if (file.empty()) {
MS_LOG(ERROR) << "file is nullptr"; MS_LOG(ERROR) << "file is nullptr";
return nullptr; return nullptr;
} }
MS_ASSERT(size != nullptr); MS_ASSERT(size != nullptr);
std::string real_path = RealPath(file); std::string real_path = RealPath(file.c_str());
std::ifstream ifs(real_path); std::ifstream ifs(real_path);
if (!ifs.good()) { if (!ifs.good()) {
MS_LOG(ERROR) << "file: " << real_path << " is not exist"; MS_LOG(ERROR) << "file: " << real_path << " is not exist";
@ -83,7 +90,6 @@ float *ReadFileBuf(const char *file, size_t *size) {
return buf.release(); return buf.release();
} }
} // namespace
int NetTrain::GenerateRandomData(mindspore::tensor::MSTensor *tensor) { int NetTrain::GenerateRandomData(mindspore::tensor::MSTensor *tensor) {
auto input_data = tensor->MutableData(); auto input_data = tensor->MutableData();
@ -832,7 +838,9 @@ int RunNetTrain(int argc, const char **argv) {
std::cerr << flags.Usage() << std::endl; std::cerr << flags.Usage() << std::endl;
return RET_OK; return RET_OK;
} }
if (flags.unified_api_) {
return NetTrain::RunNr(&flags);
}
NetTrain net_trainer(&flags); NetTrain net_trainer(&flags);
auto status = net_trainer.Init(); auto status = net_trainer.Init();
if (status != RET_OK) { if (status != RET_OK) {

View File

@ -53,8 +53,8 @@ constexpr float relativeTolerance = 1e-5;
constexpr float absoluteTolerance = 1e-8; constexpr float absoluteTolerance = 1e-8;
template <typename T> template <typename T>
float TensorSum(void *data, int size) { float TensorSum(const void *data, int size) {
T *typed_data = reinterpret_cast<T *>(data); const T *typed_data = reinterpret_cast<const T *>(data);
float sum = 0.f; float sum = 0.f;
for (int i = 0; i < size; i++) { for (int i = 0; i < size; i++) {
sum += static_cast<float>(typed_data[i]); sum += static_cast<float>(typed_data[i]);
@ -85,6 +85,7 @@ class MS_API NetTrainFlags : public virtual FlagParser {
AddFlag(&NetTrainFlags::virtual_batch_, "virtualBatch", "use virtual batch", false); AddFlag(&NetTrainFlags::virtual_batch_, "virtualBatch", "use virtual batch", false);
AddFlag(&NetTrainFlags::resize_dims_in_, "inputShapes", AddFlag(&NetTrainFlags::resize_dims_in_, "inputShapes",
"Shape of input data, the format should be NHWC. e.g. 1,32,32,32:1,1,32,32,1", ""); "Shape of input data, the format should be NHWC. e.g. 1,32,32,32:1,1,32,32,1", "");
AddFlag(&NetTrainFlags::unified_api_, "unifiedApi", "do unified api test", false);
} }
~NetTrainFlags() override = default; ~NetTrainFlags() override = default;
@ -117,6 +118,7 @@ class MS_API NetTrainFlags : public virtual FlagParser {
std::vector<std::vector<int>> resize_dims_; std::vector<std::vector<int>> resize_dims_;
std::string loss_name_ = ""; std::string loss_name_ = "";
std::string inference_file_ = ""; std::string inference_file_ = "";
bool unified_api_ = false;
}; };
class MS_API NetTrain { class MS_API NetTrain {
@ -126,49 +128,19 @@ class MS_API NetTrain {
int Init(); int Init();
int RunNetTrain(); int RunNetTrain();
static float *ReadFileBuf(const std::string file, size_t *size);
private: static int SetNr(std::function<int(NetTrainFlags *)> param);
// call GenerateInputData or ReadInputFile to init inputTensors static int RunNr(NetTrainFlags *flags) {
int LoadInput(Vector<tensor::MSTensor *> *ms_inputs); if (nr_cb_ != nullptr) {
void CheckSum(mindspore::tensor::MSTensor *tensor, std::string node_type, int id, std::string in_out); return nr_cb_(flags);
// call GenerateRandomData to fill inputTensors
int GenerateInputData(std::vector<mindspore::tensor::MSTensor *> *ms_inputs);
int GenerateRandomData(mindspore::tensor::MSTensor *tensor);
int ReadInputFile(std::vector<mindspore::tensor::MSTensor *> *ms_inputs);
int CreateAndRunNetwork(const std::string &filename, const std::string &bb_filename, int train_session, int epochs,
bool check_accuracy = true);
std::unique_ptr<session::LiteSession> CreateAndRunNetworkForInference(const std::string &filename,
const Context &context);
std::unique_ptr<session::LiteSession> CreateAndRunNetworkForTrain(const std::string &filename,
const std::string &bb_filename,
const Context &context, const TrainCfg &train_cfg,
int epochs);
int InitCallbackParameter();
int PrintResult(const std::vector<std::string> &title, const std::map<std::string, std::pair<int, float>> &result);
template <typename T>
void PrintInputData(tensor::MSTensor *input) {
MS_ASSERT(input != nullptr);
static int i = 0;
auto inData = reinterpret_cast<T *>(input->MutableData());
size_t tensorSize = input->ElementsNum();
size_t len = (tensorSize < 20) ? tensorSize : 20;
std::cout << "InData" << i++ << ": ";
for (size_t j = 0; j < len; j++) {
std::cout << inData[j] << " ";
} }
std::cout << std::endl; MS_LOG(WARNING) << "unified api was not tested";
std::cout << "unified api was not tested";
return RET_OK;
} }
// tensorData need to be converter first // tensorData need to be converter first
template <typename T> template <typename T>
float CompareData(const float *refOutput, int size, T *msTensorData) { static float CompareData(const float *refOutput, int size, const T *msTensorData) {
size_t errorCount = 0; size_t errorCount = 0;
float meanError = 0; float meanError = 0;
std::cout << "Out tensor size is: " << size << std::endl; std::cout << "Out tensor size is: " << size << std::endl;
@ -219,6 +191,45 @@ class MS_API NetTrain {
return meanError; return meanError;
} }
private:
// call GenerateInputData or ReadInputFile to init inputTensors
int LoadInput(Vector<tensor::MSTensor *> *ms_inputs);
void CheckSum(mindspore::tensor::MSTensor *tensor, std::string node_type, int id, std::string in_out);
// call GenerateRandomData to fill inputTensors
int GenerateInputData(std::vector<mindspore::tensor::MSTensor *> *ms_inputs);
int GenerateRandomData(mindspore::tensor::MSTensor *tensor);
int ReadInputFile(std::vector<mindspore::tensor::MSTensor *> *ms_inputs);
int CreateAndRunNetwork(const std::string &filename, const std::string &bb_filename, int train_session, int epochs,
bool check_accuracy = true);
std::unique_ptr<session::LiteSession> CreateAndRunNetworkForInference(const std::string &filename,
const Context &context);
std::unique_ptr<session::LiteSession> CreateAndRunNetworkForTrain(const std::string &filename,
const std::string &bb_filename,
const Context &context, const TrainCfg &train_cfg,
int epochs);
int InitCallbackParameter();
int PrintResult(const std::vector<std::string> &title, const std::map<std::string, std::pair<int, float>> &result);
template <typename T>
void PrintInputData(tensor::MSTensor *input) {
MS_ASSERT(input != nullptr);
static int i = 0;
auto inData = reinterpret_cast<T *>(input->MutableData());
size_t tensorSize = input->ElementsNum();
size_t len = (tensorSize < 20) ? tensorSize : 20;
std::cout << "InData" << i++ << ": ";
for (size_t j = 0; j < len; j++) {
std::cout << inData[j] << " ";
}
std::cout << std::endl;
}
int MarkPerformance(const std::unique_ptr<session::LiteSession> &session); int MarkPerformance(const std::unique_ptr<session::LiteSession> &session);
int MarkAccuracy(const std::unique_ptr<session::LiteSession> &session, bool enforce_accuracy = true); int MarkAccuracy(const std::unique_ptr<session::LiteSession> &session, bool enforce_accuracy = true);
int CompareOutput(const session::LiteSession &lite_session); int CompareOutput(const session::LiteSession &lite_session);
@ -242,8 +253,8 @@ class MS_API NetTrain {
} }
} }
#endif #endif
NetTrainFlags *flags_; NetTrainFlags *flags_{nullptr};
static std::function<int(NetTrainFlags *)> nr_cb_;
// callback parameters // callback parameters
uint64_t op_begin_ = 0; uint64_t op_begin_ = 0;
int op_call_times_total_ = 0; int op_call_times_total_ = 0;

View File

@ -101,6 +101,7 @@ Flags::Flags() {
"Whether to export MindIR pb. " "Whether to export MindIR pb. "
"true | false", "true | false",
"false"); "false");
AddFlag(&Flags::noFusionStr, "NoFusion", "Avoid fusion optimization true|false", "false");
} }
int Flags::InitInputOutputDataType() { int Flags::InitInputOutputDataType() {
@ -359,6 +360,18 @@ int Flags::InitPreInference() {
return RET_OK; return RET_OK;
} }
int Flags::InitNoFusion() {
if (this->noFusionStr == "true") {
this->disableFusion = true;
} else if (this->noFusionStr == "false") {
this->disableFusion = false;
} else {
std::cerr << "INPUT ILLEGAL: NoFusion must be true|false " << std::endl;
return RET_INPUT_PARAM_INVALID;
}
return RET_OK;
}
int Flags::InitExportMindIR() { int Flags::InitExportMindIR() {
if (this->exportMindIR == "true") { if (this->exportMindIR == "true") {
this->export_mindir = true; this->export_mindir = true;
@ -495,12 +508,16 @@ int Flags::Init(int argc, const char **argv) {
std::cerr << "Init encrypt failed." << std::endl; std::cerr << "Init encrypt failed." << std::endl;
return RET_INPUT_PARAM_INVALID; return RET_INPUT_PARAM_INVALID;
} }
ret = InitPreInference(); ret = InitPreInference();
if (ret != RET_OK) { if (ret != RET_OK) {
std::cerr << "Init pre inference failed." << std::endl; std::cerr << "Init pre inference failed." << std::endl;
return RET_INPUT_PARAM_INVALID; return RET_INPUT_PARAM_INVALID;
} }
ret = InitNoFusion();
if (ret != RET_OK) {
std::cerr << "Init no fusion failed." << std::endl;
return RET_INPUT_PARAM_INVALID;
}
ret = InitExportMindIR(); ret = InitExportMindIR();
if (ret != RET_OK) { if (ret != RET_OK) {
@ -509,6 +526,7 @@ int Flags::Init(int argc, const char **argv) {
} }
return RET_OK; return RET_OK;
} }
Flags::~Flags() { Flags::~Flags() {
dec_key.clear(); dec_key.clear();
encKeyStr.clear(); encKeyStr.clear();
@ -591,7 +609,7 @@ std::string GetStrFromConfigFile(const std::string &file, const std::string &tar
} }
#ifdef _WIN32 #ifdef _WIN32
char *real_path = _fullpath(resolved_path.get(), file.c_str(), kPathLengthUpperLimit); auto *real_path = _fullpath(resolved_path.get(), file.c_str(), kPathLengthUpperLimit);
#else #else
char *real_path = realpath(file.c_str(), resolved_path.get()); char *real_path = realpath(file.c_str(), resolved_path.get());
#endif #endif

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_CONVERTER_FLAGS_H #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_CONVERTER_FLAGS_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_CONVERTER_FLAGS_H #define MINDSPORE_LITE_TOOLS_CONVERTER_CONVERTER_FLAGS_H_
#include <string> #include <string>
#include <vector> #include <vector>
@ -73,6 +73,8 @@ class Flags : public virtual mindspore::lite::FlagParser {
int InitSaveFP16(); int InitSaveFP16();
int InitNoFusion();
void InitAclDefaultOption(); void InitAclDefaultOption();
int InitExportMindIR(); int InitExportMindIR();
@ -90,6 +92,7 @@ class Flags : public virtual mindspore::lite::FlagParser {
TypeId outputDataType; TypeId outputDataType;
std::string saveFP16Str = "off"; std::string saveFP16Str = "off";
bool saveFP16 = false; bool saveFP16 = false;
std::string noFusionStr = "false";
std::string inputDataTypeStr; std::string inputDataTypeStr;
std::string outputDataTypeStr; std::string outputDataTypeStr;
ParallelSplitConfig parallel_split_config_{}; ParallelSplitConfig parallel_split_config_{};
@ -134,4 +137,4 @@ std::string GetStrFromConfigFile(const std::string &file, const std::string &tar
} // namespace converter } // namespace converter
} // namespace mindspore } // namespace mindspore
#endif #endif // MINDSPORE_LITE_TOOLS_CONVERTER_CONVERTER_FLAGS_H_