remove lite expression
This commit is contained in:
parent
3f313cf04a
commit
9414aee225
|
@ -24,38 +24,23 @@
|
|||
#include "include/api/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
class NetData;
|
||||
class Net;
|
||||
|
||||
class MS_API Graph {
|
||||
public:
|
||||
class GraphData;
|
||||
enum Type : uint32_t {
|
||||
kExpressionGraph = 0, ///< graph as expression - can auto grad
|
||||
kExecutableGraph = 1, ///< graph is loaded as is
|
||||
kUnknownTypeGraph = 0xffffffff
|
||||
};
|
||||
Graph();
|
||||
explicit Graph(const std::shared_ptr<GraphData> &graph_data);
|
||||
explicit Graph(std::shared_ptr<GraphData> &&graph_data);
|
||||
explicit Graph(std::nullptr_t);
|
||||
~Graph();
|
||||
explicit Graph(Type executable);
|
||||
explicit Graph(Net *net);
|
||||
|
||||
enum ModelType ModelType() const;
|
||||
bool operator==(std::nullptr_t) const;
|
||||
bool operator!=(std::nullptr_t) const;
|
||||
bool IsExecutable() { return graph_type_ == kExecutableGraph; }
|
||||
|
||||
private:
|
||||
friend class GraphCell;
|
||||
friend class ModelImpl;
|
||||
friend class NetImpl;
|
||||
friend class Model;
|
||||
std::shared_ptr<GraphData> graph_data_;
|
||||
std::shared_ptr<NetData> net_data_;
|
||||
Type graph_type_ = kExecutableGraph;
|
||||
};
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_INCLUDE_API_GRAPH_H
|
||||
|
|
|
@ -33,9 +33,6 @@
|
|||
namespace mindspore {
|
||||
class ModelImpl;
|
||||
class Metrics;
|
||||
class Net;
|
||||
class Node;
|
||||
class Expr;
|
||||
|
||||
namespace dataset {
|
||||
class Dataset;
|
||||
|
@ -112,17 +109,6 @@ class MS_API Model {
|
|||
Status Build(GraphCell graph, const std::shared_ptr<Context> &model_context = nullptr,
|
||||
const std::shared_ptr<TrainCfg> &train_cfg = nullptr);
|
||||
|
||||
/// \brief Build train model
|
||||
///
|
||||
/// \param[in] graph A forward network
|
||||
/// \param[in] optimizer An optimizer node
|
||||
/// \param[in] inputs Inputs expression for the trained network (ex: input, label )
|
||||
/// \param[in] model_context A context used to store options during execution.
|
||||
/// \param[in] train_cfg A config used by training
|
||||
/// \return Status
|
||||
Status Build(GraphCell graph, Node *optimizer, std::vector<Expr *> inputs,
|
||||
const std::shared_ptr<Context> &model_context, const std::shared_ptr<TrainCfg> &train_cfg);
|
||||
|
||||
/// \brief Build a Transfer Learning model where the backbone weights are fixed and the head weights are trainable
|
||||
///
|
||||
/// \param[in] backbone The static, non-learnable part of the graph
|
||||
|
|
|
@ -1,142 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022-2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_INCLUDE_API_NET_H
|
||||
#define MINDSPORE_INCLUDE_API_NET_H
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <unordered_set>
|
||||
#include <string>
|
||||
#include "include/api/types.h"
|
||||
#include "include/api/data_type.h"
|
||||
#include "include/api/cfg.h"
|
||||
|
||||
namespace mindspore {
|
||||
/// \brief Register node or sub network
|
||||
#define REG(_name) Register(_name, #_name)
|
||||
|
||||
class Expr;
|
||||
class NodeImpl;
|
||||
class NetImpl;
|
||||
class NodeSet;
|
||||
class Graph;
|
||||
class NetData;
|
||||
|
||||
class MS_API NetBase {
|
||||
public:
|
||||
NetBase() = default;
|
||||
virtual std::vector<Expr *> operator()(const std::vector<Expr *> &inputs) = 0;
|
||||
virtual uint32_t type() = 0;
|
||||
};
|
||||
|
||||
class MS_API Node : public NetBase {
|
||||
public:
|
||||
Node();
|
||||
virtual ~Node();
|
||||
/// \brief Create output expression from node
|
||||
|
||||
/// \param[in] name Name of input (like "labels" etc.)
|
||||
///
|
||||
/// \return Expression
|
||||
Expr *Create(std::string name);
|
||||
/// \brief Run node on inputs. This operator is used in Net::construct()
|
||||
///
|
||||
/// \param[in] inputs Inputs expression for the node.
|
||||
/// \return Output node expression vector
|
||||
std::vector<Expr *> operator()(const std::vector<Expr *> &inputs) override;
|
||||
uint32_t type() final;
|
||||
|
||||
private:
|
||||
friend NodeImpl;
|
||||
std::shared_ptr<NodeImpl> impl_ = nullptr;
|
||||
};
|
||||
|
||||
class MS_API Net : public NetBase, public std::enable_shared_from_this<Net> {
|
||||
public:
|
||||
Net();
|
||||
virtual ~Net();
|
||||
explicit Net(std::string name);
|
||||
explicit Net(const Graph &g);
|
||||
/// \brief Define the relation between network inputs and outputs
|
||||
///
|
||||
/// \param[in] inputs expression vector
|
||||
///
|
||||
/// \return expression vector
|
||||
|
||||
virtual std::vector<Expr *> construct(const std::vector<Expr *> &inputs);
|
||||
/// \brief Addition operation
|
||||
///
|
||||
/// \param[in] inputs Two elements to add
|
||||
///
|
||||
/// \return expression vector (single element)
|
||||
|
||||
/// \brief Execution operator. Connect inputs to outputs via user defined construct
|
||||
///
|
||||
/// \return expression vector
|
||||
|
||||
std::vector<Expr *> operator()(const std::vector<Expr *> &inputs);
|
||||
void Register(Net *net, std::string &&name);
|
||||
void Register(Node *node, std::string &&name);
|
||||
/// \brief Find the trainable params for the trained network
|
||||
///
|
||||
/// \return NodeSet for all trainable nodes
|
||||
std::shared_ptr<NodeSet> trainable_params();
|
||||
virtual void Add(NetBase *element);
|
||||
/// \brief Input shape
|
||||
///
|
||||
/// \param[in] idx input index
|
||||
///
|
||||
/// \return Specific input shape vector
|
||||
const std::vector<int> InputShape(int idx);
|
||||
/// \brief Output shape
|
||||
///
|
||||
/// \param[in] idx Output index
|
||||
///
|
||||
/// \return Specific output shape vector
|
||||
const std::vector<int> OutputShape(int idx);
|
||||
uint32_t type() final;
|
||||
|
||||
private:
|
||||
friend NetImpl;
|
||||
friend NetData;
|
||||
std::shared_ptr<NetImpl> impl_;
|
||||
};
|
||||
|
||||
class MS_API SoftMaxCrossEntropyCfg {
|
||||
public:
|
||||
std::string reduction = "mean"; /**< Specifies reduction mode. The optional values are "none", "mean", "sum" */
|
||||
};
|
||||
|
||||
class MS_API AdamConfig {
|
||||
public:
|
||||
float learning_rate_ = 1e-3;
|
||||
float beta1_ = 0.9;
|
||||
float beta2_ = 0.999;
|
||||
float eps_ = 1e-08;
|
||||
bool use_nesterov_ = false;
|
||||
};
|
||||
|
||||
namespace NN {
|
||||
MS_API Net *NetWithLoss(Net *net, Node *loss);
|
||||
MS_API Graph *GraphWithLoss(Graph *g, Node *loss);
|
||||
MS_API Node *Adam(std::shared_ptr<NodeSet> learn, const AdamConfig &cfg);
|
||||
MS_API Node *SoftmaxCrossEntropy(const SoftMaxCrossEntropyCfg &cfg);
|
||||
MS_API std::unique_ptr<Node> Input(std::vector<int> dims, DataType data_type = DataType::kNumberTypeFloat32,
|
||||
int fmt = NHWC);
|
||||
}; // namespace NN
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_INCLUDE_API_NET_H
|
|
@ -290,33 +290,9 @@ if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "full")
|
|||
)
|
||||
endif()
|
||||
|
||||
|
||||
file(GLOB CXX_API_EXPRESSION
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/litert/cxx_api/expression/*.cc
|
||||
)
|
||||
|
||||
file(GLOB EXPRESSION_OPS
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/expression/ops/*.cc
|
||||
)
|
||||
|
||||
set(EXPRESSION_SRC
|
||||
${CXX_API_EXPRESSION}
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/expression/export.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/expression/expr.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/expression/import.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/expression/net.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/expression/node.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/expression/ops.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/expression/ops_utils.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/expression/param.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/expression/sequential.cc
|
||||
${EXPRESSION_OPS}
|
||||
)
|
||||
|
||||
set(TRAIN_SRC
|
||||
${API_TRAIN_SRC}
|
||||
${TRAIN_SRC_WITH_MD}
|
||||
${EXPRESSION_SRC}
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/executor/kernel_exec.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/litert/cxx_api/metrics/accuracy.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/litert/cxx_api/train/model_build.cc
|
||||
|
@ -401,13 +377,6 @@ if(NOT MSLITE_ENABLE_COREML)
|
|||
${CMAKE_CURRENT_SOURCE_DIR}/litert/delegate/coreml/stub/coreml_delegate_stub.cc)
|
||||
endif()
|
||||
|
||||
if(MSVC)
|
||||
set(LITE_SRC
|
||||
${LITE_SRC}
|
||||
${EXPRESSION_SRC}
|
||||
)
|
||||
endif()
|
||||
|
||||
add_subdirectory(litert/kernel/cpu)
|
||||
add_subdirectory(common)
|
||||
|
||||
|
|
|
@ -1,68 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_CFG_H_
|
||||
#define MINDSPORE_LITE_SRC_EXPRESSION_CFG_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class ConvConfig {
|
||||
public:
|
||||
ConvConfig() = default;
|
||||
int in_channel_ = 3; /**< The channel number of the input of the Conv2d layer */
|
||||
int out_channel_ = 3; /**< The channel number of the output tensor of the Conv2d layer */
|
||||
std::vector<int64_t> kernel_size_ = {3, 3}; /**< Specifies the height and width of the 2D convolution kernel. */
|
||||
std::vector<int64_t> stride_ = {1, 1}; /**< The movement stride of the 2D convolution kernel */
|
||||
std::vector<int64_t> padding_ = {0, 0, 0, 0}; /**< The top, bottom, left, and right padding input */
|
||||
std::vector<int64_t> dilation_ = {1, 1}; /**< diletion height and width*/
|
||||
int group_ = 1; // < Splits filter into groups, `in_channels` and `out_channels` must be
|
||||
// divisible by `group`. If the group is equal to `in_channels` and `out_channels`,
|
||||
// this 2D convolution layer also can be called 2D depthwise convolution layer */
|
||||
bool has_bias = false; /** < Whether the Conv2d layer has a bias parameter */
|
||||
std::string weight_init_ =
|
||||
"normal"; /**< Initialization method of weight parameter ("normal","uniform", "ones", "zeros") */
|
||||
std::string pad_mode_ = "same"; /**< Specifies padding mode. The optional values are "same", "valid", "pad" */
|
||||
|
||||
private:
|
||||
std::string bias_init_ = "zeros";
|
||||
std::string data_format;
|
||||
};
|
||||
|
||||
class DenseConfig {
|
||||
public:
|
||||
int in_channels_; /**< The number of channels in the input space */
|
||||
int out_channels_; /**< The number of channels in the output space */
|
||||
bool has_bias_ = false; /** Specifies whether the layer uses a bias vector **/
|
||||
private:
|
||||
std::string weight_init_ = "normal";
|
||||
std::string bias_init_ = "zeros";
|
||||
std::string activation_ = "none";
|
||||
};
|
||||
|
||||
class PoolingConfig {
|
||||
public:
|
||||
PoolingConfig() = default;
|
||||
std::vector<int64_t> kernel_size_ = {1, 1}; /**< Specifies the height and width of the 2D kernel. */
|
||||
std::vector<int64_t> stride_ = {1, 1}; /**< The movement stride of the 2D kernel */
|
||||
std::string pad_mode_ = "same"; /**< Specifies padding mode. The optional values are "same", "valid" */
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_EXPRESSION_CFG_H_
|
|
@ -1,76 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <utility>
|
||||
#include "src/expression/export.h"
|
||||
#include "src/expression/ops.h"
|
||||
#include "src/common/utils.h"
|
||||
#include "nnacl/conv_parameter.h"
|
||||
#include "include/errorcode.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
constexpr static int kFmkVal = 3;
|
||||
|
||||
int ExportSession::Init(const std::string model_name, std::string version) {
|
||||
meta_graph_ = new (std::nothrow) mindspore::schema::MetaGraphT();
|
||||
if (meta_graph_ == nullptr) {
|
||||
MS_LOG(ERROR) << "cannot allocate meta_graph";
|
||||
return RET_ERROR;
|
||||
}
|
||||
meta_graph_->fmkType = kFmkVal;
|
||||
meta_graph_->name = model_name;
|
||||
meta_graph_->version = version;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
bool ExportSession::IsToDependOnly(EXPR *expr) {
|
||||
auto itr = outmap_.find(expr);
|
||||
if (itr != outmap_.end() && !itr->second.empty()) {
|
||||
for (auto expr : itr->second) {
|
||||
auto node = expr->node();
|
||||
if (node->primitive() != schema::PrimitiveType_Depend) return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
int ExportSession::SetInputOutput(const std::vector<EXPR *> &inputs, const std::vector<EXPR *> &outputs) {
|
||||
for (auto &in : inputs) {
|
||||
auto id = GetOutput(in);
|
||||
meta_graph_->inputIndex.push_back(id);
|
||||
}
|
||||
for (auto &out : outputs) {
|
||||
auto id = GetOutput(out);
|
||||
meta_graph_->outputIndex.push_back(id);
|
||||
}
|
||||
auto sub_graph = std::make_unique<mindspore::schema::SubGraphT>();
|
||||
if (sub_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "cannot allocate SubGraphT";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto model_name = meta_graph_->name;
|
||||
sub_graph->name = model_name + "_subgraph";
|
||||
sub_graph->inputIndices = meta_graph_->inputIndex;
|
||||
sub_graph->outputIndices = meta_graph_->outputIndex;
|
||||
for (size_t i = 0; i < meta_graph_->nodes.size(); i++) sub_graph->nodeIndices.push_back(i);
|
||||
for (size_t i = 0; i < meta_graph_->allTensors.size(); i++) sub_graph->tensorIndices.push_back(i);
|
||||
meta_graph_->subGraph.emplace_back(std::move(sub_graph));
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -1,52 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_EXPORT_H_
|
||||
#define MINDSPORE_LITE_SRC_EXPRESSION_EXPORT_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <string>
|
||||
#include <list>
|
||||
#include <map>
|
||||
#include <iostream>
|
||||
#include "src/expression/expr.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace schema {
|
||||
struct MetaGraphT;
|
||||
}
|
||||
namespace lite {
|
||||
class ExportSession {
|
||||
public:
|
||||
explicit ExportSession(std::map<EXPR *, std::list<EXPR *>> &outmap) : outmap_(outmap) {}
|
||||
int Init(const std::string model_name, std::string version);
|
||||
void UpdateOutput(EXPR *expr, int id) { output_tensors_[expr] = id; }
|
||||
int GetOutput(EXPR *expr) { return output_tensors_.at(expr); }
|
||||
schema::MetaGraphT *&meta_graph() { return meta_graph_; }
|
||||
int SetInputOutput(const std::vector<EXPR *> &inputs, const std::vector<EXPR *> &outputs);
|
||||
bool IsToDependOnly(EXPR *expr);
|
||||
|
||||
private:
|
||||
schema::MetaGraphT *meta_graph_{nullptr};
|
||||
std::unordered_map<EXPR *, int> output_tensors_; // output tensors per EXPR
|
||||
std::map<EXPR *, std::list<EXPR *>> &outmap_;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_EXPRESSION_EXPORT_H_
|
|
@ -1,98 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
#include "src/expression/expr.h"
|
||||
#include "src/expression/node.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
std::string EXPR::name() { return node_->name(); }
|
||||
void EXPR::Travers(std::function<bool(EXPR *e, EXPR *itr)> cb) {
|
||||
if (!visited) {
|
||||
visited = true;
|
||||
for (auto &itr : params_) {
|
||||
if (cb(this, itr)) {
|
||||
itr->Travers(cb);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void EXPR::Replace(EXPR **old, EXPR **n, std::vector<Node *> *to_delete) {
|
||||
if (!visited) {
|
||||
visited = true;
|
||||
for (auto &itr : params_)
|
||||
if (itr == *old) {
|
||||
to_delete->push_back(itr->node());
|
||||
itr = *n;
|
||||
}
|
||||
for (auto &itr : params_) itr->Replace(old, n, to_delete);
|
||||
}
|
||||
}
|
||||
|
||||
void EXPR::Replace(const std::vector<EXPR *> &vec, std::vector<EXPR *> *old, std::vector<EXPR *> *n) {
|
||||
std::vector<Node *> to_delete;
|
||||
for (auto &e : vec) {
|
||||
for (std::size_t i = 0; i < old->size(); i++) {
|
||||
e->Replace(&old->at(i), &n->at(i), &to_delete);
|
||||
}
|
||||
}
|
||||
for (auto &itr : to_delete) delete itr;
|
||||
for (auto e : vec) e->Clear();
|
||||
}
|
||||
|
||||
void EXPR::Clear() {
|
||||
EXPR *item = this;
|
||||
if (visited == false) return;
|
||||
visited = false;
|
||||
while (item->params_.size() == 1) {
|
||||
item = item->params_.front();
|
||||
if (item->visited == false) return;
|
||||
item->visited = false;
|
||||
}
|
||||
for (auto &itr : item->params_) itr->Clear();
|
||||
}
|
||||
|
||||
void EXPR::Clear(std::vector<EXPR *> vec) {
|
||||
for (auto e : vec) e->Clear();
|
||||
}
|
||||
|
||||
void EXPR::CreateOutputMap(std::vector<EXPR *> vec, std::map<EXPR *, std::list<EXPR *>> *outmap) {
|
||||
for (auto e : vec) {
|
||||
e->Travers([&](EXPR *e, EXPR *itr) {
|
||||
(*outmap)[itr].push_back(e);
|
||||
return true;
|
||||
});
|
||||
}
|
||||
Clear(vec);
|
||||
}
|
||||
|
||||
void EXPR::PrintDot(std::vector<EXPR *> vec) {
|
||||
std::cout << "digraph \"expr\" { " << std::endl;
|
||||
for (auto e : vec) {
|
||||
e->Travers([](EXPR *e, EXPR *itr) {
|
||||
std::cout << "\"" << itr->node_->name() << "\"->"
|
||||
<< "\"" << e->node_->name() << "\"" << std::endl;
|
||||
return true;
|
||||
});
|
||||
}
|
||||
std::cout << "}" << std::endl;
|
||||
Clear(vec);
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -1,70 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_EXPR_H_
|
||||
#define MINDSPORE_LITE_SRC_EXPRESSION_EXPR_H_
|
||||
|
||||
#include <vector>
|
||||
#include <list>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <functional>
|
||||
#include <string>
|
||||
#include "include/api/format.h"
|
||||
#include "mindapi/base/type_id.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class Node;
|
||||
|
||||
class EXPR {
|
||||
public:
|
||||
explicit EXPR(Node *node) : node_(node) { SetSize(1); }
|
||||
static void PrintDot(std::vector<EXPR *> vec);
|
||||
static void Replace(const std::vector<EXPR *> &vec, std::vector<EXPR *> *old, std::vector<EXPR *> *n);
|
||||
static void CreateOutputMap(std::vector<EXPR *> vec, std::map<EXPR *, std::list<EXPR *>> *outmap);
|
||||
static void Clear(std::vector<EXPR *> vec);
|
||||
void Travers(std::function<bool(EXPR *e, EXPR *itr)> cb);
|
||||
std::string name();
|
||||
EXPR *GetInput(int idx) { return params_.at(idx); }
|
||||
void set_node(Node *node) { node_ = node; }
|
||||
Node *node() { return node_; }
|
||||
bool visited = false;
|
||||
void set_params(std::vector<EXPR *> params) { params_ = params; }
|
||||
void set_params(int idx, EXPR *expr) { params_[idx] = expr; }
|
||||
void add_params(EXPR *e) { params_.push_back(e); }
|
||||
std::vector<EXPR *> ¶ms() { return params_; }
|
||||
EXPR *params(int i) { return params_[i]; }
|
||||
void SetSize(int n) { params_.resize(n); }
|
||||
void SetDims(std::vector<int> dims) { dims_ = dims; }
|
||||
std::vector<int> &dims() { return dims_; }
|
||||
void set_format(int fmt) { format_ = fmt; }
|
||||
int format() { return format_; }
|
||||
void set_data_type(TypeId data_type) { data_type_ = data_type; }
|
||||
TypeId data_type() { return data_type_; }
|
||||
|
||||
private:
|
||||
void Replace(EXPR **old, EXPR **n, std::vector<Node *> *to_delete);
|
||||
std::vector<EXPR *> params_;
|
||||
Node *node_{nullptr};
|
||||
void Clear();
|
||||
std::vector<int> dims_;
|
||||
int format_ = NHWC;
|
||||
TypeId data_type_ = kNumberTypeFloat32;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_SRC_EXPRESSION_EXPR_H_
|
|
@ -1,180 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <vector>
|
||||
#include "src/expression/import.h"
|
||||
#include "common/ops/populate/populate_register.h"
|
||||
#include "src/expression/ops.h"
|
||||
#include "src/expression/ops/activation.h"
|
||||
#include "src/expression/ops/batchnorm.h"
|
||||
#include "src/expression/ops/biasadd.h"
|
||||
#include "src/expression/ops/conv.h"
|
||||
#include "src/expression/ops/dense.h"
|
||||
#include "src/expression/ops/pooling.h"
|
||||
#include "src/expression/ops/reshape.h"
|
||||
#include "src/expression/ops/transpose.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
std::unordered_map<mindspore::schema::PrimitiveType, import_func> ImportReg::import_map_;
|
||||
|
||||
import_func ImportReg::GetImportFunc(mindspore::schema::PrimitiveType type) {
|
||||
auto f = import_map_.find(type);
|
||||
if (f == import_map_.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
return f->second;
|
||||
}
|
||||
|
||||
OpParameter *Import::GetAttr(const schema::Primitive *prim) {
|
||||
auto parameter_gen = PopulateRegistry::GetInstance()->GetParameterCreator(prim->value_type(), SCHEMA_CUR);
|
||||
if (parameter_gen == nullptr) {
|
||||
MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " << schema::EnumNamePrimitiveType(prim->value_type());
|
||||
return nullptr;
|
||||
}
|
||||
auto parameter = parameter_gen(prim);
|
||||
if (parameter == nullptr) {
|
||||
MS_LOG(ERROR) << "parameter is nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
return parameter;
|
||||
}
|
||||
|
||||
std::unique_ptr<Node> Import::CreateNode(const schema::CNode *cnode) {
|
||||
auto param = GetAttr(cnode->primitive());
|
||||
auto type = cnode->primitive()->value_type();
|
||||
auto fn = ImportReg::GetImportFunc(type);
|
||||
if (fn == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot find importer for " << schema::EnumNamePrimitiveType(type);
|
||||
return nullptr;
|
||||
}
|
||||
auto node = fn();
|
||||
if (node == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate node" << cnode->name()->str();
|
||||
return nullptr;
|
||||
}
|
||||
node->SetOpParam(param);
|
||||
node->set_name(cnode->name()->str());
|
||||
node->set_primitive(type);
|
||||
return std::unique_ptr<Node>(node);
|
||||
}
|
||||
|
||||
Net *Import::ImportMs(std::string file_name) {
|
||||
std::ifstream infile;
|
||||
infile.open(file_name, std::ios::binary | std::ios::in);
|
||||
if (!infile.good()) {
|
||||
MS_LOG(ERROR) << "cannot read " << file_name << std::endl;
|
||||
return nullptr;
|
||||
}
|
||||
infile.seekg(0, std::ios::end);
|
||||
int length = infile.tellg();
|
||||
infile.seekg(0, std::ios::beg);
|
||||
auto data_ptr = std::make_unique<int8_t[]>(length);
|
||||
auto *data = data_ptr.get();
|
||||
infile.read(reinterpret_cast<char *>(data), length);
|
||||
infile.close();
|
||||
flatbuffers::Verifier verifier = flatbuffers::Verifier(reinterpret_cast<const uint8_t *>(data), length);
|
||||
bool res = schema::VerifyMetaGraphBuffer(verifier);
|
||||
if (res != true) {
|
||||
MS_LOG(ERROR) << "fault file: " << file_name << "(" << length << ")\n";
|
||||
return nullptr;
|
||||
} else {
|
||||
MS_LOG(INFO) << "verify pass file: " << file_name << "(" << length << ")\n";
|
||||
}
|
||||
buffer_ = data_ptr.get();
|
||||
auto metaGraph = schema::GetMetaGraph(data_ptr.release());
|
||||
return ImportMs(metaGraph);
|
||||
}
|
||||
|
||||
Net *Import::ImportMs(const schema::MetaGraph *metaGraph) {
|
||||
if (metaGraph == nullptr) {
|
||||
MS_LOG(ERROR) << "null input";
|
||||
return nullptr;
|
||||
}
|
||||
std::string NetName = "Network";
|
||||
if (metaGraph->name() != nullptr) NetName = metaGraph->name()->str();
|
||||
auto net = std::make_unique<Net>(NetName);
|
||||
std::unordered_map<int, EXPR *> outputs;
|
||||
// save inputs
|
||||
for (size_t i = 0; i < metaGraph->inputIndex()->size(); i++) {
|
||||
auto tensor_id = metaGraph->inputIndex()->Get(i);
|
||||
const schema::Tensor *tensor = metaGraph->allTensors()->Get(tensor_id);
|
||||
auto input = new (std::nothrow) InputM(tensor);
|
||||
if (input == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate input";
|
||||
return nullptr;
|
||||
}
|
||||
auto e = input->expr();
|
||||
outputs[tensor_id] = e;
|
||||
net->PushInput(e);
|
||||
}
|
||||
for (size_t i = 0; i < metaGraph->nodes()->size(); i++) {
|
||||
auto Cnode = metaGraph->nodes()->Get(i);
|
||||
std::vector<EXPR *> param_tensors;
|
||||
for (size_t j = 0; j < Cnode->inputIndex()->size(); j++) {
|
||||
int tensor_id = Cnode->inputIndex()->Get(j);
|
||||
const schema::Tensor *tensor = metaGraph->allTensors()->Get(tensor_id);
|
||||
auto iter = outputs.find(tensor_id);
|
||||
if (iter == outputs.end()) {
|
||||
// create value node if not exist
|
||||
if (tensor->nodeType() != NodeType::NodeType_CNode) {
|
||||
auto valnode = new (std::nothrow) InputM(tensor);
|
||||
if (valnode == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate valnode";
|
||||
return nullptr;
|
||||
}
|
||||
outputs[tensor_id] = valnode->expr();
|
||||
param_tensors.push_back(valnode->expr());
|
||||
net->PushOp(valnode);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "did not found input tensor " << tensor_id;
|
||||
return nullptr;
|
||||
}
|
||||
} else {
|
||||
param_tensors.push_back(iter->second);
|
||||
}
|
||||
}
|
||||
// create expression from node //
|
||||
auto node = CreateNode(Cnode);
|
||||
if (node != nullptr) {
|
||||
node->SetOutputs(Cnode->outputIndex()->size());
|
||||
std::vector<EXPR *> e = (*node)(param_tensors);
|
||||
for (size_t j = 0; j < Cnode->outputIndex()->size(); j++) {
|
||||
int tensor_id = Cnode->outputIndex()->Get(j);
|
||||
outputs[tensor_id] = e.at(j);
|
||||
}
|
||||
} else {
|
||||
MS_LOG(ERROR) << "failed to create node " << Cnode->name();
|
||||
return nullptr;
|
||||
}
|
||||
auto node_ptr = node.release();
|
||||
net->PushOp(node_ptr);
|
||||
node_ptr->SetLearn();
|
||||
}
|
||||
for (size_t i = 0; i < metaGraph->outputIndex()->size(); i++) {
|
||||
auto tensor_id = metaGraph->outputIndex()->Get(i);
|
||||
auto iter = outputs.find(tensor_id);
|
||||
if (iter == outputs.end()) {
|
||||
MS_LOG(ERROR) << "could not find source for tensor " << tensor_id;
|
||||
return nullptr;
|
||||
} else {
|
||||
net->PushOutput(iter->second);
|
||||
}
|
||||
}
|
||||
return net.release();
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -1,61 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_IMPORT_H_
|
||||
#define MINDSPORE_LITE_SRC_EXPRESSION_IMPORT_H_
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <memory>
|
||||
#include "nnacl/op_base.h"
|
||||
#include "src/expression/net.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
using import_func = std::function<Node *()>;
|
||||
|
||||
template <typename T>
|
||||
Node *ReturnNode() {
|
||||
return new (std::nothrow) T();
|
||||
}
|
||||
|
||||
class ImportReg {
|
||||
public:
|
||||
explicit ImportReg(mindspore::schema::PrimitiveType type, import_func func) { import_map_[type] = func; }
|
||||
static import_func GetImportFunc(mindspore::schema::PrimitiveType type);
|
||||
|
||||
private:
|
||||
static std::unordered_map<mindspore::schema::PrimitiveType, import_func> import_map_;
|
||||
};
|
||||
|
||||
class Import {
|
||||
private:
|
||||
int8_t *buffer_ = nullptr;
|
||||
OpParameter *GetAttr(const schema::Primitive *prim);
|
||||
std::unique_ptr<Node> CreateNode(const schema::CNode *cnode);
|
||||
|
||||
public:
|
||||
Net *ImportMs(const schema::MetaGraph *meta_graph);
|
||||
Net *ImportMs(std::string file_name);
|
||||
~Import() {
|
||||
delete[] buffer_;
|
||||
buffer_ = nullptr;
|
||||
}
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_EXPRESSION_IMPORT_H_
|
|
@ -1,268 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/expression/net.h"
|
||||
#include <vector>
|
||||
#include "src/litert/cxx_api/expression/net_impl.h"
|
||||
#include "src/expression/ops.h"
|
||||
#include "src/expression/export.h"
|
||||
#include "src/expression/ops/addn.h"
|
||||
#include "src/expression/ops/arithmetic.h"
|
||||
#include "src/common/storage.h"
|
||||
#include "tools/common/meta_graph_serializer.h"
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
void Net::update_name(std::string name) {
|
||||
if (!this->name().empty())
|
||||
Node::update_name(name);
|
||||
else
|
||||
set_name(name);
|
||||
for (auto &itr : ops_) {
|
||||
itr->update_name(name);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<EXPR *> Net::operator()(const std::initializer_list<EXPR *> &&inputs) {
|
||||
std::vector<EXPR *> vec = inputs;
|
||||
std::vector<EXPR *> x;
|
||||
if (impl_ == nullptr) {
|
||||
x = construct(inputs);
|
||||
} else {
|
||||
x = impl_->construct(vec);
|
||||
}
|
||||
return x;
|
||||
}
|
||||
|
||||
std::vector<EXPR *> Net::operator()(const std::vector<EXPR *> &inputs) {
|
||||
std::vector<EXPR *> x;
|
||||
if (impl_ == nullptr) {
|
||||
x = construct(inputs);
|
||||
} else {
|
||||
x = impl_->construct(inputs);
|
||||
}
|
||||
input_ = inputs;
|
||||
output_ = x;
|
||||
real_output_ = x;
|
||||
return x;
|
||||
}
|
||||
|
||||
std::vector<EXPR *> Net::construct(const std::vector<EXPR *> &inputs) {
|
||||
if (!output_.empty()) {
|
||||
if (input_.size() != inputs.size()) {
|
||||
MS_LOG(ERROR) << "input size mismatch, should be " << input_.size() << " got " << inputs.size();
|
||||
return {};
|
||||
}
|
||||
auto in_ptr = inputs;
|
||||
EXPR::Replace(output_, &input_, &in_ptr);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "no network construction function";
|
||||
}
|
||||
return output_;
|
||||
}
|
||||
|
||||
void Net::TopoSortUtil(Node *node, std::stack<Node *> *stack) {
|
||||
visited_.insert(node);
|
||||
for (size_t i = 0; i < node->OutputsNum(); i++) {
|
||||
auto expr = node->expr(i);
|
||||
auto itr = outmap_.find(expr);
|
||||
if (itr != outmap_.end()) {
|
||||
for (auto &e : itr->second)
|
||||
if (visited_.find(e->node()) == visited_.end()) {
|
||||
TopoSortUtil(e->node(), stack);
|
||||
}
|
||||
}
|
||||
}
|
||||
stack->push(node);
|
||||
}
|
||||
|
||||
std::vector<Node *> Net::Sort() {
|
||||
std::stack<Node *> stack;
|
||||
outmap_.clear();
|
||||
EXPR::CreateOutputMap(output_, &outmap_);
|
||||
for (auto &itr : outmap_) {
|
||||
EXPR *e = itr.first;
|
||||
if (visited_.find(e->node()) == visited_.end()) {
|
||||
TopoSortUtil(e->node(), &stack);
|
||||
}
|
||||
}
|
||||
std::vector<Node *> res;
|
||||
while (stack.empty() == false) {
|
||||
res.push_back(stack.top());
|
||||
stack.pop();
|
||||
}
|
||||
visited_.clear();
|
||||
return res;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::MetaGraphT> Net::MakeMs() {
|
||||
auto nodes = Sort();
|
||||
auto s = new (std::nothrow) ExportSession(outmap_);
|
||||
if (s == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate export session";
|
||||
return nullptr;
|
||||
}
|
||||
session_.reset(s);
|
||||
session_->Init(name(), Version());
|
||||
for (auto node : nodes) {
|
||||
auto res = node->MakeEntry(session_.get());
|
||||
if (res != RET_OK) {
|
||||
MS_LOG(ERROR) << "failed in MakeEntry: " << node->name();
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
session_->SetInputOutput(input_, real_output_);
|
||||
auto res = session_->meta_graph();
|
||||
return std::unique_ptr<schema::MetaGraphT>(res);
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::MetaGraphT> Net::MakeMs(const std::string file_name) {
|
||||
auto graph = MakeMs();
|
||||
Save(*graph, file_name);
|
||||
return graph;
|
||||
}
|
||||
|
||||
std::set<Node *> Net::trainable_params() {
|
||||
std::set<Node *> res;
|
||||
for (auto &node : ops_) {
|
||||
res.merge(node->trainable_params());
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
int Net::BuildGrad(Node *optimizer) {
|
||||
std::set<Node *> learn = optimizer->trainable_params();
|
||||
auto NetOrder = Sort();
|
||||
optimizer_.reset(optimizer);
|
||||
optimizer->AddNetOutput(&output_);
|
||||
std::map<std::pair<EXPR *, EXPR *>, EXPR *> backprop;
|
||||
for (auto itr = NetOrder.rbegin(); itr != NetOrder.rend(); itr++) {
|
||||
Node *node = *itr;
|
||||
EXPR *yt = nullptr;
|
||||
if (node->primitive() == schema::PrimitiveType_NONE) continue;
|
||||
if (outmap_.find(node->expr()) == outmap_.end() || outmap_[node->expr()].size() == 0) {
|
||||
yt = node->expr();
|
||||
} else {
|
||||
std::vector<EXPR *> add_params;
|
||||
for (auto &output : outmap_[node->expr()]) {
|
||||
auto link = std::make_pair(node->expr(), output);
|
||||
auto grad = backprop[link];
|
||||
add_params.push_back(grad);
|
||||
}
|
||||
if (add_params.size() == 1) {
|
||||
yt = add_params.front();
|
||||
} else {
|
||||
auto addn = new (std::nothrow) AddN(0);
|
||||
if (addn == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate add operator";
|
||||
return RET_ERROR;
|
||||
}
|
||||
PushOp(addn);
|
||||
addn->update_name(name());
|
||||
yt = (*addn)(add_params).front();
|
||||
}
|
||||
}
|
||||
auto inGrads = node->Grad(yt);
|
||||
for (size_t i = 0; i < node->inputs().size(); i++) {
|
||||
EXPR *inGrad{nullptr};
|
||||
if (i < inGrads.size()) {
|
||||
inGrad = inGrads[i];
|
||||
} else {
|
||||
inGrad = nullptr;
|
||||
}
|
||||
auto input = node->input(i);
|
||||
if (learn.find(input->node()) != learn.end()) {
|
||||
auto opt = optimizer->Clone(inGrad, input);
|
||||
if (opt.size() == 0) {
|
||||
MS_LOG(ERROR) << "failed to create optimizer";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (inGrad == nullptr) {
|
||||
MS_LOG(ERROR) << "illegal null value for grad";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (opt.size() == 0) {
|
||||
MS_LOG(ERROR) << "optimizer for " << input->node()->name() << " failure";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto opt_op = opt.at(0)->node();
|
||||
PushOp(opt_op);
|
||||
opt_op->update_name(node->name());
|
||||
output_.push_back(opt.at(0));
|
||||
}
|
||||
auto link = std::make_pair(input, node->expr());
|
||||
backprop[link] = inGrad;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
std::vector<EXPR *> Net::add(const std::vector<EXPR *> &input) {
|
||||
auto _add = NN::Add();
|
||||
_add->set_name(name() + "/" + _add->name());
|
||||
ops_.push_back(_add);
|
||||
return (*_add)(input);
|
||||
}
|
||||
|
||||
Net *Net::TrainNet(Node *optimizer, Node *loss_fn, const std::vector<EXPR *> &inputs) {
|
||||
auto net = new (std::nothrow) NetWithLoss(this, loss_fn);
|
||||
if (net == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate loss network";
|
||||
return nullptr;
|
||||
}
|
||||
return net->TrainNet(optimizer, inputs);
|
||||
}
|
||||
|
||||
Net *Net::TrainNet(Node *optimizer, const std::vector<EXPR *> &inputs) {
|
||||
auto x = (*this)(inputs);
|
||||
auto res = BuildGrad(optimizer);
|
||||
if (res != RET_OK) {
|
||||
MS_LOG(ERROR) << "Build gradient network failed";
|
||||
return nullptr;
|
||||
}
|
||||
real_output_ = x;
|
||||
return this;
|
||||
}
|
||||
|
||||
int Net::Save(const schema::MetaGraphT &graph, std::string file_name) { return Storage::Save(graph, file_name); }
|
||||
|
||||
const std::vector<int> Net::OutputShape(int idx) {
|
||||
if (static_cast<size_t>(idx) >= real_output_.size()) {
|
||||
MS_LOG(ERROR) << "index (" << idx << ") exceed output size (" << real_output_.size() << ")";
|
||||
return {};
|
||||
}
|
||||
return real_output_.at(idx)->dims();
|
||||
}
|
||||
|
||||
const std::vector<int> Net::InputShape(int idx) {
|
||||
if (static_cast<size_t>(idx) >= input_.size()) {
|
||||
MS_LOG(ERROR) << "index (" << idx << ") exceed input size (" << input_.size() << ")";
|
||||
return {};
|
||||
}
|
||||
return input_.at(idx)->dims();
|
||||
}
|
||||
|
||||
Net::~Net() {
|
||||
if (impl_ != nullptr) {
|
||||
impl_->erase_net();
|
||||
auto pnet = impl_->pnet();
|
||||
if (pnet != nullptr) {
|
||||
impl_->set_pnet(nullptr);
|
||||
}
|
||||
}
|
||||
impl_ = nullptr;
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -1,114 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_NET_H_
|
||||
#define MINDSPORE_LITE_SRC_EXPRESSION_NET_H_
|
||||
#include <stack>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <map>
|
||||
#include <utility>
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
#include <list>
|
||||
#include <vector>
|
||||
#include "src/expression/node.h"
|
||||
#include "inner/model_generated.h"
|
||||
|
||||
namespace mindspore {
|
||||
class Net;
|
||||
class NetImpl;
|
||||
namespace lite {
|
||||
#define REG(_name) Register(_name, #_name)
|
||||
|
||||
class ExportSession;
|
||||
|
||||
class Net : public Node {
|
||||
public:
|
||||
Net() = default;
|
||||
virtual ~Net();
|
||||
explicit Net(std::string name) : Node(name) {}
|
||||
std::vector<EXPR *> construct(const std::vector<EXPR *> &inputs) override;
|
||||
std::vector<EXPR *> operator()(const std::vector<EXPR *> &inputs) override;
|
||||
std::vector<EXPR *> operator()(const std::initializer_list<EXPR *> &&inputs) override;
|
||||
void update_name(std::string name) override;
|
||||
Net *TrainNet(Node *optimizer, Node *loss_fn, const std::vector<EXPR *> &inputs);
|
||||
Net *TrainNet(Node *optimizer, const std::vector<EXPR *> &inputs);
|
||||
void PrintDot() { EXPR::PrintDot(output_); }
|
||||
|
||||
void PushOutput(EXPR *e) { output_.push_back(e); }
|
||||
void PushInput(EXPR *e) { input_.push_back(e); }
|
||||
void SetRealOutput() { real_output_ = output_; }
|
||||
std::set<Node *> trainable_params() override;
|
||||
std::vector<Node *> Sort();
|
||||
int BuildGrad(Node *optimizer);
|
||||
int BuildGrad(Node *optimizer, std::set<Node *> learnable);
|
||||
std::unique_ptr<schema::MetaGraphT> MakeMs();
|
||||
std::unique_ptr<schema::MetaGraphT> MakeMs(std::string file_name);
|
||||
schema::MetaGraph *meta_graph() { return meta_graph_; }
|
||||
int Save(const schema::MetaGraphT &graph, const std::string filename);
|
||||
void set_impl(std::shared_ptr<mindspore::NetImpl> impl) { impl_ = impl; }
|
||||
const std::vector<int> InputShape(int idx);
|
||||
const std::vector<int> OutputShape(int idx);
|
||||
|
||||
protected:
|
||||
std::vector<EXPR *> add(const std::vector<EXPR *> &input);
|
||||
void Register(Node *node, std::string &&name) {
|
||||
if (node != nullptr) {
|
||||
PushOp(node);
|
||||
node->update_name(name);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
friend mindspore::Net;
|
||||
std::unordered_set<Node *> visited_;
|
||||
std::map<EXPR *, std::list<EXPR *>> outmap_; // outputs per expression
|
||||
std::map<EXPR *, std::list<EXPR *>> inmap_; // inputs per expression
|
||||
std::vector<EXPR *> output_; // network output expression
|
||||
std::vector<EXPR *> real_output_; // network output for export
|
||||
std::vector<EXPR *> input_; // network input expression
|
||||
schema::MetaGraph *meta_graph_; // imported meta_graph
|
||||
std::unique_ptr<ExportSession> session_; // export session
|
||||
std::unique_ptr<Node> optimizer_;
|
||||
void TopoSortUtil(Node *v, std::stack<Node *> *stack);
|
||||
void CreateOutputMap(std::vector<EXPR *> vec, std::map<Node *, std::list<Node *>> *outmap);
|
||||
std::shared_ptr<mindspore::NetImpl> impl_;
|
||||
};
|
||||
|
||||
class NetWithLoss : public Net {
|
||||
public:
|
||||
NetWithLoss(Net *net, Node *loss) : net_(net), loss_fn_(loss) {
|
||||
REG(net_);
|
||||
REG(loss_fn_);
|
||||
loss_fn_->set_name("_loss_fn");
|
||||
}
|
||||
std::vector<EXPR *> construct(const std::vector<EXPR *> &inputs) {
|
||||
auto input = inputs[0];
|
||||
auto label = inputs[1];
|
||||
auto x = (*net_)({input});
|
||||
x = (*loss_fn_)({x[0], label});
|
||||
return {x.front()};
|
||||
}
|
||||
|
||||
private:
|
||||
Net *net_{nullptr};
|
||||
Node *loss_fn_{nullptr};
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_EXPRESSION_NET_H_
|
|
@ -1,271 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
#include <functional>
|
||||
#include "src/expression/node.h"
|
||||
#include "src/expression/ops.h"
|
||||
#include "src/expression/export.h"
|
||||
#include "src/litert/infer_manager.h"
|
||||
#include "src/common/utils.h"
|
||||
#include "src/litert/cxx_api/expression/net_impl.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
int Node::name_id;
|
||||
|
||||
std::vector<EXPR *> Node::construct(const std::vector<EXPR *> &inputs) {
|
||||
if (inputs.size() >= expr()->params().size()) {
|
||||
expr()->set_params(inputs);
|
||||
} else {
|
||||
for (std::size_t i = 0; i < inputs.size(); i++) {
|
||||
expr()->set_params(i, inputs[i]);
|
||||
}
|
||||
}
|
||||
auto ret = InferShape();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "error infershape for node " << name();
|
||||
return {};
|
||||
}
|
||||
std::vector<EXPR *> res(expr_.size());
|
||||
(void)std::transform(expr_.begin(), expr_.end(), res.begin(), [](const EXPR &e) { return const_cast<EXPR *>(&e); });
|
||||
return res;
|
||||
}
|
||||
|
||||
std::vector<EXPR *> Node::Grad(EXPR *expr) {
|
||||
MS_LOG(ERROR) << name() << " (" << schema::EnumNamePrimitiveType(primitive()) << ") does not have grad defined";
|
||||
return {};
|
||||
}
|
||||
|
||||
int Node::CreateTensorFromExpr(const std::vector<EXPR *> &expr, std::vector<Tensor *> *tensors, bool is_input) {
|
||||
MS_ASSERT(tensors != nullptr);
|
||||
int ret = RET_OK;
|
||||
for (auto e : expr) {
|
||||
// Tensor -> TensorC
|
||||
if (is_input && e->node()->primitive() == schema::PrimitiveType_Depend) {
|
||||
continue;
|
||||
}
|
||||
auto type = (e->node()->primitive() != schema::PrimitiveType_NONE) ? Category::VAR : Category::CONST_TENSOR;
|
||||
auto t = std::make_unique<Tensor>(e->data_type(), e->dims(), (mindspore::Format)e->format(), type);
|
||||
if (t == nullptr) {
|
||||
ret = RET_NULL_PTR;
|
||||
break;
|
||||
}
|
||||
// copy data if any
|
||||
if (type == Category::CONST_TENSOR) {
|
||||
void *dst = t->MutableData();
|
||||
if (dst == nullptr) {
|
||||
ret = RET_NULL_PTR;
|
||||
break;
|
||||
}
|
||||
if (e->node()->data() && (e->node()->data()->data().size() > 0)) {
|
||||
uint8_t *src = e->node()->data()->data().data();
|
||||
memcpy(dst, src, t->Size());
|
||||
}
|
||||
}
|
||||
tensors->push_back(t.release());
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
void Node::FreeAllTensors(std::vector<Tensor *> *tensors) {
|
||||
MS_ASSERT(tensors != nullptr);
|
||||
for (auto &t : *tensors) {
|
||||
delete t;
|
||||
}
|
||||
tensors->clear();
|
||||
}
|
||||
|
||||
int Node::InferShape() {
|
||||
auto ret = RET_OK;
|
||||
std::vector<Tensor *> in_tensors;
|
||||
std::vector<Tensor *> out_tensors;
|
||||
// build in \ out tensors
|
||||
ret = CreateTensorFromExpr(expr()->params(), &in_tensors, true);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Failed in create in tensors";
|
||||
FreeAllTensors(&in_tensors);
|
||||
return RET_ERROR;
|
||||
}
|
||||
std::vector<EXPR *> expr(expr_.size());
|
||||
(void)std::transform(expr_.begin(), expr_.end(), expr.begin(), [](const EXPR &e) { return const_cast<EXPR *>(&e); });
|
||||
ret = CreateTensorFromExpr(expr, &out_tensors);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Failed in create out tensors";
|
||||
FreeAllTensors(&in_tensors);
|
||||
FreeAllTensors(&out_tensors);
|
||||
return RET_ERROR;
|
||||
}
|
||||
// Do infer Shape
|
||||
ret = KernelInferShape(in_tensors, out_tensors, OpParam());
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "failed in infer shape for " << name();
|
||||
FreeAllTensors(&in_tensors);
|
||||
FreeAllTensors(&out_tensors);
|
||||
return RET_ERROR;
|
||||
}
|
||||
// copy infer shape into expr
|
||||
for (uint32_t i = 0; i < expr_.size(); i++) {
|
||||
auto e = &expr_.at(i);
|
||||
auto o = out_tensors.at(i);
|
||||
e->set_format((o->format()));
|
||||
e->set_data_type(o->data_type());
|
||||
e->SetDims(o->shape());
|
||||
}
|
||||
// cleanup
|
||||
FreeAllTensors(&in_tensors);
|
||||
FreeAllTensors(&out_tensors);
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
EXPR *Node::CreateWeights(std::vector<int> dims, TypeId data_type, int format, Param::Mode mode, std::string name) {
|
||||
auto weights = new (std::nothrow) InputM(dims);
|
||||
if (weights == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate weights";
|
||||
return nullptr;
|
||||
}
|
||||
weights->set_name(this->name() + "/" + name);
|
||||
int size = std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<int>());
|
||||
weights->data()->SetSize(size);
|
||||
weights->data()->Fill(mode);
|
||||
PushOp(weights);
|
||||
return weights->expr();
|
||||
}
|
||||
|
||||
Node *Node::CreateConstTensor(int index, std::vector<int> dims, TypeId data_type, int format, std::string name,
|
||||
const void *data) {
|
||||
auto tensor = NN::Input(dims, data_type, format);
|
||||
int elem_size = DataTypeSize(data_type);
|
||||
tensor->set_name(this->name() + "/" + name);
|
||||
int size = std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<int>()) * elem_size;
|
||||
tensor->data()->SetSize(size);
|
||||
tensor->data()->Copy(reinterpret_cast<const uint8_t *>(data), size);
|
||||
expr()->set_params(index, tensor->expr());
|
||||
PushOp(tensor);
|
||||
return tensor;
|
||||
}
|
||||
|
||||
int Node::MakeEntry(ExportSession *session) {
|
||||
std::vector<uint32_t> input_idx;
|
||||
std::vector<uint32_t> output_idx;
|
||||
std::vector<uint8_t> empty;
|
||||
if (primitive() == schema::PrimitiveType_Depend) return RET_OK;
|
||||
// create node input
|
||||
size_t inputs = InputsNum();
|
||||
for (size_t i = 0; i < inputs; i++) {
|
||||
EXPR *ex = expr()->GetInput(i);
|
||||
if (ex->node()->primitive() == schema::PrimitiveType_Depend) continue;
|
||||
uint32_t id = session->GetOutput(ex);
|
||||
input_idx.push_back(id);
|
||||
}
|
||||
size_t outputs = OutputsNum();
|
||||
size_t last_id = session->meta_graph()->allTensors.size();
|
||||
int type = (primitive() == schema::PrimitiveType_NONE) ? NodeType_ValueNode : NodeType_CNode;
|
||||
auto data = (type == NodeType_ValueNode) ? this->data()->data() : empty;
|
||||
if (data.empty()) type = NodeType_CNode; // input is Cnode !!?
|
||||
int idx = 0;
|
||||
for (size_t i = 0; i < outputs; i++) {
|
||||
if (session->IsToDependOnly(expr(i))) continue;
|
||||
output_idx.push_back(last_id + idx);
|
||||
session->UpdateOutput(expr(i), last_id + idx);
|
||||
auto odims = dims(i);
|
||||
auto data_type = expr(i)->data_type();
|
||||
auto format = expr(i)->format();
|
||||
std::string footer = (i > 0) ? ("-" + std::to_string(i)) : "";
|
||||
auto otensor = CreateTensor(name() + footer, type, data_type, odims, format, data);
|
||||
std::cout << "tensor -" << last_id + idx << ": " << name() + footer << std::endl;
|
||||
idx++;
|
||||
session->meta_graph()->allTensors.emplace_back(std::move(otensor));
|
||||
}
|
||||
if (primitive() != schema::PrimitiveType_NONE) {
|
||||
if (output_idx.size() == 0) {
|
||||
return RET_OK;
|
||||
}
|
||||
auto cnode = CreateCNode(input_idx, output_idx);
|
||||
|
||||
auto ret = UnPopulate(cnode);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "failed to populate cnode";
|
||||
return RET_ERROR;
|
||||
}
|
||||
session->meta_graph()->nodes.emplace_back(std::move(cnode));
|
||||
}
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::CNodeT> Node::CreateCNode(std::vector<uint32_t> inputIndex, std::vector<uint32_t> outputIndex) {
|
||||
auto cnode = std::make_unique<schema::CNodeT>();
|
||||
cnode->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
cnode->primitive->value.type = primitive();
|
||||
cnode->name = name();
|
||||
cnode->inputIndex = inputIndex;
|
||||
cnode->outputIndex = outputIndex;
|
||||
return cnode;
|
||||
}
|
||||
|
||||
int Node::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
|
||||
MS_LOG(ERROR) << "Node " << schema::EnumNamePrimitiveType(primitive()) << " cannot be exported";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
std::unique_ptr<mindspore::schema::TensorT> Node::CreateTensor(std::string name, int type, int data_type,
|
||||
const std::vector<int32_t> dims, int format,
|
||||
const std::vector<uint8_t> &data) {
|
||||
auto tensorT = std::make_unique<mindspore::schema::TensorT>();
|
||||
tensorT->nodeType = type;
|
||||
tensorT->dims = dims;
|
||||
tensorT->format = static_cast<schema::Format>(format);
|
||||
tensorT->name = name;
|
||||
tensorT->refCount = 0;
|
||||
tensorT->offset = 0;
|
||||
tensorT->dataType = data_type;
|
||||
tensorT->data = data;
|
||||
tensorT->enableHuffmanCode = false;
|
||||
if (tensorT->nodeType == mindspore::lite::NodeType_ValueNode) {
|
||||
tensorT->data = data;
|
||||
}
|
||||
return tensorT;
|
||||
}
|
||||
|
||||
int Node::SetOutputs(int num) {
|
||||
EXPR e(this);
|
||||
e.SetSize(0);
|
||||
for (auto i = expr_.size(); i < static_cast<size_t>(num); i++) {
|
||||
expr_.emplace_back(e);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
Node::~Node() {
|
||||
for (auto &op : ops_) {
|
||||
delete op;
|
||||
}
|
||||
ops_.clear();
|
||||
if (impl_ != nullptr) {
|
||||
impl_->set_node(nullptr);
|
||||
auto pnode = impl_->pnode();
|
||||
if (pnode != nullptr) {
|
||||
impl_->set_pnode(nullptr);
|
||||
delete pnode;
|
||||
}
|
||||
}
|
||||
impl_ = nullptr;
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -1,156 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_NODE_H_
|
||||
#define MINDSPORE_LITE_SRC_EXPRESSION_NODE_H_
|
||||
|
||||
#include <stdlib.h>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include "src/expression/export.h"
|
||||
#include "inner/model_generated.h"
|
||||
#include "src/expression/param.h"
|
||||
#include "src/expression/expr.h"
|
||||
#include "src/tensor.h"
|
||||
#include "nnacl/op_base.h"
|
||||
|
||||
namespace mindspore {
|
||||
class NodeImpl;
|
||||
namespace schema {
|
||||
struct TensorT;
|
||||
struct CNodeT;
|
||||
} // namespace schema
|
||||
|
||||
namespace lite {
|
||||
class Node {
|
||||
public:
|
||||
const std::string kGradName = "Gradients";
|
||||
explicit Node(const std::string name) : opParam_(nullptr), name_(name) { expr_.emplace_back(this); }
|
||||
virtual ~Node();
|
||||
Node() : Node("") {}
|
||||
explicit Node(Node *node) : Node(*node) {}
|
||||
EXPR *create(std::string name) {
|
||||
name_ = name;
|
||||
return &expr_[0];
|
||||
}
|
||||
virtual std::vector<EXPR *> operator()(const std::vector<EXPR *> &inputs) {
|
||||
auto x = construct(inputs);
|
||||
return x;
|
||||
}
|
||||
virtual std::vector<EXPR *> operator()(const std::initializer_list<EXPR *> &&inputs) {
|
||||
std::vector<EXPR *> vec = inputs;
|
||||
auto x = construct(vec);
|
||||
return x;
|
||||
}
|
||||
virtual std::vector<EXPR *> operator()(const std::initializer_list<EXPR *> &inputs) {
|
||||
std::vector<EXPR *> vec = inputs;
|
||||
auto x = construct(vec);
|
||||
return x;
|
||||
}
|
||||
void set_primitive(schema::PrimitiveType primitive) {
|
||||
primitive_ = primitive;
|
||||
if (OpParam() != nullptr) opParam_->type_ = primitive_;
|
||||
}
|
||||
schema::PrimitiveType primitive() { return primitive_; }
|
||||
virtual std::vector<EXPR *> construct(const std::vector<EXPR *> &inputs);
|
||||
std::string name() { return name_; }
|
||||
void set_name(std::string name) { name_ = name; }
|
||||
virtual void update_name(std::string name) { set_name(name + "/" + name_); }
|
||||
size_t Load(std::string file_name, size_t offset = 0) { return offset; }
|
||||
OpParameter *OpParam() const { return opParam_.get(); }
|
||||
virtual void Add(Node *node) {}
|
||||
virtual std::vector<EXPR *> Clone(EXPR *grad, EXPR *weight) { return {}; }
|
||||
void SetOpParam(std::shared_ptr<OpParameter> opParam) { opParam_ = opParam; }
|
||||
void SetOpParam(void *opParam) { opParam_.reset(reinterpret_cast<OpParameter *>(opParam), free); }
|
||||
static std::string UniqueName(const std::string &name) { return name + "-" + std::to_string(name_id++); }
|
||||
static std::string UniqueName(std::string &&name) { return name + "-" + std::to_string(name_id++); }
|
||||
template <typename T>
|
||||
int CloneOpParam(std::shared_ptr<OpParameter> opParam) {
|
||||
auto t = reinterpret_cast<T *>(opParam.get());
|
||||
auto obj = new (std::nothrow) T(*t); // copy content
|
||||
if (obj == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate obj";
|
||||
return RET_ERROR;
|
||||
}
|
||||
opParam_.reset(reinterpret_cast<OpParameter *>(obj));
|
||||
return RET_OK;
|
||||
}
|
||||
template <typename T>
|
||||
int CloneOpParam(OpParameter *opParam) {
|
||||
auto t = reinterpret_cast<T *>(opParam);
|
||||
auto obj = new (std::nothrow) T(*t); // copy content
|
||||
if (obj == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate obj";
|
||||
return RET_ERROR;
|
||||
}
|
||||
opParam_.reset(reinterpret_cast<OpParameter *>(obj));
|
||||
return RET_OK;
|
||||
}
|
||||
virtual Param *weight() { return nullptr; }
|
||||
EXPR *expr(int i) { return &expr_[i]; }
|
||||
EXPR *expr() { return expr(0); }
|
||||
std::vector<EXPR *> inputs() { return expr()[0].params(); }
|
||||
size_t InputsNum() { return expr()[0].params().size(); }
|
||||
size_t OutputsNum() { return expr_.size(); }
|
||||
EXPR *input(int idx) { return expr()[0].params().at(idx); }
|
||||
EXPR *output(int idx) { return expr(idx); }
|
||||
EXPR *CreateWeights(std::vector<int> dims, TypeId data_type, int format, Param::Mode mode, std::string name);
|
||||
Node *CreateConstTensor(int index, std::vector<int> dims, TypeId data_type, int format, std::string name,
|
||||
const void *data);
|
||||
virtual std::vector<EXPR *> Grad(EXPR *expr);
|
||||
virtual Param *data() { return nullptr; }
|
||||
bool IsLearn(Node *node) { return learnable_.find(node) != learnable_.end(); }
|
||||
virtual void SetLearn() {}
|
||||
virtual std::set<Node *> trainable_params() { return learnable_; }
|
||||
std::vector<int> &dims() { return expr()->dims(); }
|
||||
std::vector<int> &dims(int i) { return expr(i)->dims(); }
|
||||
// export
|
||||
int MakeEntry(ExportSession *session);
|
||||
void PushOp(Node *n) { ops_.push_back(n); }
|
||||
virtual void AddNetOutput(std::vector<EXPR *> *output) {}
|
||||
int SetOutputs(int num);
|
||||
std::shared_ptr<OpParameter> opParam_;
|
||||
void set_impl(std::shared_ptr<NodeImpl> impl) { impl_ = impl; }
|
||||
|
||||
protected:
|
||||
std::vector<EXPR> expr_; // hold outputs
|
||||
std::vector<Node *> ops_; // all nodes or subnets
|
||||
int InferShape();
|
||||
void AddLearn(Node *node) { learnable_.insert(node); }
|
||||
void AssignLearn(std::set<Node *> &&learn) { learnable_ = learn; }
|
||||
|
||||
std::unique_ptr<schema::CNodeT> CreateCNode(std::vector<uint32_t> inputIndex, std::vector<uint32_t> outputIndex);
|
||||
virtual int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode);
|
||||
std::unique_ptr<schema::TensorT> CreateTensor(std::string name, int type, int data_type,
|
||||
const std::vector<int32_t> dims, int format,
|
||||
const std::vector<uint8_t> &data);
|
||||
|
||||
private:
|
||||
int CreateTensorFromExpr(const std::vector<EXPR *> &expr, std::vector<Tensor *> *tensors, bool is_input = false);
|
||||
void FreeAllTensors(std::vector<Tensor *> *tensors);
|
||||
static int name_id;
|
||||
std::set<Node *> learnable_; // set of nodes with learnable parameters
|
||||
std::string name_;
|
||||
schema::PrimitiveType primitive_;
|
||||
std::shared_ptr<NodeImpl> impl_;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_EXPRESSION_NODE_H_
|
|
@ -1,66 +0,0 @@
|
|||
|
||||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <numeric>
|
||||
#include <algorithm>
|
||||
#include "src/expression/ops.h"
|
||||
#include "src/expression/ops_utils.h"
|
||||
#include "src/expression/param.h"
|
||||
#include "include/api/cfg.h"
|
||||
#include "src/expression/sequential.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
void InputM::SetUp(const std::vector<int> &dims, TypeId data_type, int fmt) {
|
||||
expr()->SetSize(C0NUM);
|
||||
expr()->SetDims(dims);
|
||||
expr()->set_data_type(data_type);
|
||||
expr()->set_format(fmt);
|
||||
set_primitive(schema::PrimitiveType_NONE);
|
||||
}
|
||||
|
||||
InputM::InputM(const std::vector<int> &dims, TypeId data_type, int fmt) : Node() { SetUp(dims, data_type, fmt); }
|
||||
|
||||
InputM::InputM(const schema::Tensor *tensor) : Node() {
|
||||
std::vector<int> dims(tensor->dims()->size());
|
||||
(void)std::transform(tensor->dims()->begin(), tensor->dims()->end(), dims.begin(), [](int32_t x) { return x; });
|
||||
SetUp(dims, static_cast<TypeId>(tensor->dataType()), tensor->format());
|
||||
if (tensor->name()) set_name(tensor->name()->str());
|
||||
if (tensor->data() != nullptr) data_.Copy(tensor->data()->data(), tensor->data()->size());
|
||||
}
|
||||
|
||||
namespace NN {
|
||||
Node *Input(const std::vector<int> &dims, TypeId data_type, int fmt) {
|
||||
auto i = new (std::nothrow) InputM(dims, data_type, fmt);
|
||||
if (i == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate input expression ";
|
||||
return nullptr;
|
||||
}
|
||||
return i;
|
||||
}
|
||||
|
||||
Net *Sequential() {
|
||||
auto s = new (std::nothrow) mindspore::lite::Sequential();
|
||||
if (s == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate sequential expression";
|
||||
return nullptr;
|
||||
}
|
||||
return s;
|
||||
}
|
||||
}; // namespace NN
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -1,69 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_H_
|
||||
#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <set>
|
||||
#include "include/api/net.h"
|
||||
#include "src/expression/cfg.h"
|
||||
#include "src/expression/net.h"
|
||||
#include "inner/model_generated.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class InputM : public Node {
|
||||
public:
|
||||
explicit InputM(const schema::Tensor *tensor);
|
||||
explicit InputM(const std::vector<int> &dims, TypeId data_type = kNumberTypeFloat32, int fmt = NHWC);
|
||||
Param *data() override { return &data_; }
|
||||
|
||||
private:
|
||||
void SetUp(const std::vector<int> &dims, TypeId data_type, int fmt);
|
||||
Param data_;
|
||||
};
|
||||
namespace NN {
|
||||
Node *Conv2D(const ConvConfig &cfg);
|
||||
Node *Relu();
|
||||
Node *Dense(const DenseConfig &cfg);
|
||||
Node *Flatten();
|
||||
Node *Input(const std::vector<int> &dims, TypeId data_type = kNumberTypeFloat32, int fmt = NHWC);
|
||||
Node *Add();
|
||||
Node *Sub();
|
||||
Node *Div();
|
||||
Node *Mul();
|
||||
Node *Neg();
|
||||
Node *SoftmaxCrossEntropy();
|
||||
Net *Sequential();
|
||||
Node *Adam(std::set<Node *> &&learn, const AdamConfig &cfg);
|
||||
|
||||
Node *Softmax(int axis = -1);
|
||||
Node *BatchNorm2D(int outp, float momentum = 0.1, float epsilon = 1e-5f);
|
||||
Node *Sigmoid();
|
||||
Node *DropOut(float ration = 0.5);
|
||||
Node *ReLU6();
|
||||
Node *Reshape(const std::vector<int> &shape);
|
||||
Node *ReduceMean(bool keep_dims, const std::vector<int> &dims);
|
||||
Node *ReduceSum(bool keep_dims, const std::vector<int> &dims);
|
||||
Node *Tile(const std::vector<int> &multiples);
|
||||
Node *MaxPool2D(const PoolingConfig &cfg);
|
||||
Node *AvgPool2D(const PoolingConfig &cfg);
|
||||
} // namespace NN
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_H_
|
|
@ -1,133 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "src/expression/ops/activation.h"
|
||||
#include "nnacl/fp32/activation_fp32.h"
|
||||
#include "src/expression/import.h"
|
||||
#include "src/expression/ops.h"
|
||||
#include "src/litert/cxx_api/expression/node_impl.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
ActM::ActM(schema::ActivationType type) : Node() {
|
||||
auto op_param = malloc(sizeof(ActivationParameter));
|
||||
if (op_param == nullptr) {
|
||||
MS_LOG(ERROR) << " cannot allocate ActivationParameter";
|
||||
return;
|
||||
}
|
||||
SetOpParam(op_param);
|
||||
set_primitive(schema::PrimitiveType_Activation);
|
||||
ActivationParameter *act_param = reinterpret_cast<ActivationParameter *>(opParam_.get());
|
||||
act_param->type_ = type;
|
||||
act_param->alpha_ = 0.f;
|
||||
act_param->min_val_ = 0.f;
|
||||
act_param->max_val_ = 0.f;
|
||||
}
|
||||
|
||||
std::vector<EXPR *> ActM::Grad(EXPR *yt) {
|
||||
auto actGrad = new (std::nothrow) ActGradM(this);
|
||||
if (actGrad == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate activation grad node";
|
||||
return {};
|
||||
}
|
||||
PushOp(actGrad);
|
||||
auto param = reinterpret_cast<ActivationParameter *>(actGrad->OpParam());
|
||||
EXPR *ag = nullptr;
|
||||
actGrad->expr()->SetSize(C2NUM);
|
||||
if ((param->type_ == schema::ActivationType_SIGMOID) || (param->type_ == schema::ActivationType_TANH)) {
|
||||
ag = (*actGrad)({output(0), yt}).front();
|
||||
} else if ((param->type_ == schema::ActivationType_HSWISH) || (param->type_ == schema::ActivationType_HSIGMOID) ||
|
||||
(param->type_ == schema::ActivationType_RELU6)) {
|
||||
ag = (*actGrad)({yt, input(0)}).front();
|
||||
} else if (param->type_ == schema::ActivationType_GELU) {
|
||||
actGrad->expr()->SetSize(C3NUM);
|
||||
ag = (*actGrad)({yt, input(0), output(0)}).front();
|
||||
} else {
|
||||
ag = (*actGrad)({yt, output(0)}).front();
|
||||
}
|
||||
std::vector<EXPR *> res = {ag};
|
||||
return res;
|
||||
}
|
||||
|
||||
int ActM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
|
||||
auto act_param = reinterpret_cast<const ActivationParameter *>(OpParam());
|
||||
auto prim = new (std::nothrow) schema::ActivationT;
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate activation primitive";
|
||||
return RET_ERROR;
|
||||
}
|
||||
prim->activation_type = static_cast<decltype(prim->activation_type)>(act_param->type_);
|
||||
prim->alpha = act_param->alpha_;
|
||||
prim->min_val = act_param->min_val_;
|
||||
prim->max_val = act_param->max_val_;
|
||||
cnode->primitive->value.value = prim;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
static ImportReg reg(schema::PrimitiveType_Activation, ReturnNode<ActM>);
|
||||
|
||||
ActGradM::ActGradM(Node *node) {
|
||||
CloneOpParam<ActivationParameter>(node->OpParam());
|
||||
set_primitive(schema::PrimitiveType_ActivationGrad);
|
||||
set_name(node->name() + "/" + kGradName + "/actGrad");
|
||||
}
|
||||
|
||||
std::vector<EXPR *> ActGradM::Grad(EXPR *yt) { return {}; }
|
||||
|
||||
int ActGradM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
|
||||
auto act_param = reinterpret_cast<const ActivationParameter *>(OpParam());
|
||||
auto prim = new (std::nothrow) schema::ActivationGradT;
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate activation grad primitive";
|
||||
return RET_ERROR;
|
||||
}
|
||||
prim->activation_type = static_cast<decltype(prim->activation_type)>(act_param->type_);
|
||||
prim->alpha = act_param->alpha_;
|
||||
cnode->primitive->value.value = prim;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
static ImportReg regGrad(schema::PrimitiveType_ActivationGrad, ReturnNode<ActGradM>);
|
||||
namespace NN {
|
||||
Node *ReLU6() {
|
||||
auto r = new (std::nothrow) ActM(schema::ActivationType_RELU6);
|
||||
if (r == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate relu6";
|
||||
return nullptr;
|
||||
}
|
||||
r->set_name(Node::UniqueName("ReLU6"));
|
||||
return r;
|
||||
}
|
||||
Node *Sigmoid() {
|
||||
auto s = new (std::nothrow) ActM(schema::ActivationType_SIGMOID);
|
||||
if (s == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate sigmoid";
|
||||
return nullptr;
|
||||
}
|
||||
s->set_name(Node::UniqueName("Sigmoid"));
|
||||
return s;
|
||||
}
|
||||
Node *Relu() {
|
||||
auto r = new (std::nothrow) ActM(schema::ActivationType_RELU);
|
||||
if (r == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate relu";
|
||||
return nullptr;
|
||||
}
|
||||
r->set_name(r->UniqueName("Relu"));
|
||||
return r;
|
||||
}
|
||||
} // namespace NN
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -1,44 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_ACTIVATION_H_
|
||||
#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_ACTIVATION_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "src/expression/net.h"
|
||||
#include "inner/model_generated.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class ActM : public Node {
|
||||
public:
|
||||
ActM() = default;
|
||||
explicit ActM(schema::ActivationType type);
|
||||
std::vector<EXPR *> Grad(EXPR *expr) override;
|
||||
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
|
||||
};
|
||||
|
||||
class ActGradM : public Node {
|
||||
public:
|
||||
ActGradM() : Node() {} // for Import
|
||||
explicit ActGradM(Node *act); // for Grad
|
||||
std::vector<EXPR *> Grad(EXPR *expr) override;
|
||||
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_ACTIVATION_H_
|
|
@ -1,142 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/expression/ops/adam.h"
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <utility>
|
||||
#include "src/expression/ops.h"
|
||||
#include "src/expression/ops/assign.h"
|
||||
#include "src/expression/ops/arithmetic.h"
|
||||
#include "nnacl/fp32_grad/optimizer.h"
|
||||
#include "include/api/net.h"
|
||||
#include "src/litert/cxx_api/expression/node_impl.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace NN {
|
||||
Node *Adam(std::shared_ptr<NodeSet> learn, const AdamConfig &cfg) {
|
||||
auto lite_node = lite::NN::Adam(std::move(learn->set_), cfg);
|
||||
return NodeImpl::Connect(lite_node);
|
||||
}
|
||||
} // namespace NN
|
||||
|
||||
namespace lite {
|
||||
std::vector<EXPR *> AdamM::Clone(EXPR *grad, EXPR *weight) {
|
||||
auto adam = new (std::nothrow) AdamM();
|
||||
if (adam == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate adam";
|
||||
return {};
|
||||
}
|
||||
adam->set_name("optimizer-Adam");
|
||||
adam->CloneOpParam<AdamParameter>(OpParam());
|
||||
adam->update_name(weight->node()->name());
|
||||
adam->set_primitive(primitive());
|
||||
adam->expr()->SetSize(C10NUM);
|
||||
// setup weight and momentum
|
||||
adam->expr()->set_params(C0NUM, weight);
|
||||
auto dims = grad->dims();
|
||||
auto m = adam->CreateWeights(dims, kNumberTypeFloat32, KHWC, Param::ZEROS, "m");
|
||||
adam->expr()->set_params(C1NUM, m);
|
||||
auto v = adam->CreateWeights(dims, kNumberTypeFloat32, KHWC, Param::ZEROS, "v");
|
||||
adam->expr()->set_params(C2NUM, v);
|
||||
// copy parameters
|
||||
for (int i = C3NUM; i < C9NUM; i++) {
|
||||
adam->expr()->set_params(i, this->input(i));
|
||||
}
|
||||
adam->expr()->set_params(C9NUM, grad);
|
||||
return (*adam)(adam->inputs());
|
||||
}
|
||||
|
||||
AdamM::AdamM(std::set<Node *> &&learn, const AdamConfig &cfg) {
|
||||
auto op_param = reinterpret_cast<AdamParameter *>(malloc(sizeof(AdamParameter)));
|
||||
if (op_param == nullptr) {
|
||||
MS_LOG(ERROR) << " cannot allocate ActivationParameter";
|
||||
return;
|
||||
}
|
||||
AssignLearn(std::move(learn));
|
||||
memset(op_param, 0, sizeof(AdamParameter));
|
||||
op_param->use_nesterov_ = cfg.use_nesterov_;
|
||||
SetOpParam(op_param);
|
||||
set_primitive(schema::PrimitiveType_Adam);
|
||||
set_name("optimizer-Adam");
|
||||
// Adam Network
|
||||
expr()->SetSize(C10NUM);
|
||||
auto assign1 = new (std::nothrow) AssignM(0);
|
||||
if (assign1 == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate assign";
|
||||
return;
|
||||
}
|
||||
PushOp(assign1);
|
||||
auto assign2 = new (std::nothrow) AssignM(0);
|
||||
if (assign2 == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate assign";
|
||||
return;
|
||||
}
|
||||
PushOp(assign2);
|
||||
auto mul1 = NN::Mul();
|
||||
if (mul1 == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate mul";
|
||||
return;
|
||||
}
|
||||
PushOp(mul1);
|
||||
auto mul2 = NN::Mul();
|
||||
if (mul2 == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate mul";
|
||||
return;
|
||||
}
|
||||
PushOp(mul2);
|
||||
auto tmp = 1.0f;
|
||||
mul1->CreateConstTensor(C0NUM, {1}, kNumberTypeFloat32, KHWC, "beta1-power", &tmp);
|
||||
mul1->CreateConstTensor(C1NUM, {1}, kNumberTypeFloat32, KHWC, "beta1-data", &cfg.beta1_);
|
||||
auto o1 = (*mul1)({});
|
||||
assign1_ = (*assign1)({mul1->input(0), o1.front()}).front();
|
||||
mul2->CreateConstTensor(C0NUM, {1}, kNumberTypeFloat32, KHWC, "beta2-power", &tmp);
|
||||
mul2->CreateConstTensor(C1NUM, {1}, kNumberTypeFloat32, KHWC, "beta2-data", &cfg.beta2_);
|
||||
auto o2 = (*mul2)({});
|
||||
assign2_ = (*assign2)({mul2->input(0), o2.front()}).front();
|
||||
expr()->set_params(C3NUM, o1.front());
|
||||
expr()->set_params(C4NUM, o2.front());
|
||||
CreateConstTensor(C5NUM, {1}, kNumberTypeFloat32, KHWC, "learning-rate", &cfg.learning_rate_);
|
||||
CreateConstTensor(C6NUM, {1}, kNumberTypeFloat32, KHWC, "beta1", &cfg.beta1_);
|
||||
CreateConstTensor(C7NUM, {1}, kNumberTypeFloat32, KHWC, "beta2", &cfg.beta2_);
|
||||
CreateConstTensor(C8NUM, {1}, kNumberTypeFloat32, KHWC, "epsilon", &cfg.eps_);
|
||||
}
|
||||
|
||||
int AdamM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
|
||||
auto param = reinterpret_cast<const AdamParameter *>(OpParam());
|
||||
auto prim = new (std::nothrow) schema::AdamT;
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate " << cnode->name;
|
||||
return RET_ERROR;
|
||||
}
|
||||
prim->use_nesterov = param->use_nesterov_;
|
||||
prim->use_locking = false;
|
||||
cnode->primitive->value.value = prim;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
namespace NN {
|
||||
Node *Adam(std::set<Node *> &&learn, const AdamConfig &cfg) {
|
||||
auto a = new (std::nothrow) AdamM(std::move(learn), cfg);
|
||||
if (a == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate adam";
|
||||
return nullptr;
|
||||
}
|
||||
return a;
|
||||
}
|
||||
} // namespace NN
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -1,46 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_ADAM_H_
|
||||
#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_ADAM_H_
|
||||
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include <memory>
|
||||
#include "include/api/net.h"
|
||||
#include "src/expression/net.h"
|
||||
#include "inner/model_generated.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class AdamM : public Node {
|
||||
public:
|
||||
AdamM() = default;
|
||||
AdamM(std::set<Node *> &&learn, const AdamConfig &cfg);
|
||||
std::vector<EXPR *> Clone(EXPR *grad, EXPR *weight) override;
|
||||
void AddNetOutput(std::vector<EXPR *> *output) override {
|
||||
output->push_back(assign1_);
|
||||
output->push_back(assign2_);
|
||||
}
|
||||
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
|
||||
|
||||
private:
|
||||
EXPR *assign1_{nullptr};
|
||||
EXPR *assign2_{nullptr};
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_ADAM_H_
|
|
@ -1,42 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/expression/ops/addn.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
AddN::AddN(int dummy) {
|
||||
auto op_param = calloc(1, sizeof(OpParameter));
|
||||
if (op_param == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate parameter ";
|
||||
return;
|
||||
}
|
||||
set_name(UniqueName("addN"));
|
||||
SetOpParam(op_param);
|
||||
set_primitive(schema::PrimitiveType_AddN);
|
||||
}
|
||||
|
||||
int AddN::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
|
||||
auto prim = new (std::nothrow) schema::AddNT;
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate primitive";
|
||||
return RET_ERROR;
|
||||
}
|
||||
cnode->primitive->value.value = prim;
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -1,34 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_ADDN_H_
|
||||
#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_ADDN_H_
|
||||
|
||||
#include <memory>
|
||||
#include "src/expression/node.h"
|
||||
#include "inner/model_generated.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class AddN : public Node {
|
||||
public:
|
||||
explicit AddN(int dummy);
|
||||
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_ADDN_H_
|
|
@ -1,223 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/expression/ops/arithmetic.h"
|
||||
#include <memory>
|
||||
#include "src/expression/ops/reduce.h"
|
||||
#include "src/expression/ops/reshape.h"
|
||||
#include "src/expression/ops_utils.h"
|
||||
#include "src/expression/ops/arithmetic_self.h"
|
||||
#include "src/expression/ops.h"
|
||||
#include "nnacl/arithmetic_parameter.h"
|
||||
#include "src/expression/import.h"
|
||||
#include "src/litert/cxx_api/expression/node_impl.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
// Common Arithmetic Functionality
|
||||
ArithmeticM::ArithmeticM(schema::PrimitiveType type) : Node() {
|
||||
auto op_param = malloc(sizeof(ArithmeticParameter));
|
||||
if (op_param == nullptr) {
|
||||
MS_LOG(ERROR) << " cannot allocate ActivationParameter";
|
||||
return;
|
||||
}
|
||||
SetOpParam(op_param);
|
||||
expr()->SetSize(C2NUM);
|
||||
set_primitive(type);
|
||||
}
|
||||
|
||||
std::vector<EXPR *> ArithmeticM::binop_grad_common(EXPR *x, EXPR *y, EXPR *dx, EXPR *dy) {
|
||||
auto shape_of_x = x->dims();
|
||||
auto shape_of_y = y->dims();
|
||||
auto reduce_dx = dx;
|
||||
auto reduce_dy = dy;
|
||||
auto rx = (BroadcastGradientArgs(shape_of_x, shape_of_y))();
|
||||
if (rx[0].size()) {
|
||||
auto reduce_sum = NN::ReduceSum(false, rx[0]);
|
||||
PushOp(reduce_sum);
|
||||
reduce_dx = (*reduce_sum)({reduce_dx}).front();
|
||||
auto reshape = NN::Reshape(shape_of_x);
|
||||
PushOp(reshape);
|
||||
reduce_dx = (*reshape)({reduce_dx}).front();
|
||||
}
|
||||
if (rx[1].size()) {
|
||||
auto reduce_sum = NN::ReduceSum(false, rx[1]);
|
||||
PushOp(reduce_sum);
|
||||
reduce_dy = (*reduce_sum)({reduce_dy}).front();
|
||||
auto reshape = NN::Reshape(shape_of_y);
|
||||
PushOp(reshape);
|
||||
reduce_dy = (*reshape)({reduce_dy}).front();
|
||||
}
|
||||
std::vector<EXPR *> out = {reduce_dx, reduce_dy};
|
||||
return out;
|
||||
}
|
||||
|
||||
// Add Op
|
||||
AddM::AddM(int dummy) : ArithmeticM(schema::PrimitiveType_AddFusion) { set_name(UniqueName("Add")); }
|
||||
|
||||
std::vector<EXPR *> AddM::Grad(EXPR *yt) { return binop_grad_common(input(0), input(1), yt, yt); }
|
||||
|
||||
int AddM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
|
||||
auto prim = new (std::nothrow) schema::AddFusionT;
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate prim";
|
||||
return RET_ERROR;
|
||||
}
|
||||
prim->activation_type = schema::ActivationType_NO_ACTIVATION;
|
||||
cnode->primitive->value.value = prim;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
static ImportReg AddReg(schema::PrimitiveType_AddFusion, ReturnNode<AddM>);
|
||||
|
||||
// Div op
|
||||
DivM::DivM(int dummy) : ArithmeticM(schema::PrimitiveType_RealDiv) { set_name(UniqueName("RealDiv")); }
|
||||
std::vector<EXPR *> DivM::Grad(EXPR *yt) {
|
||||
auto x = input(0);
|
||||
auto y = input(1);
|
||||
auto o = output(0);
|
||||
auto div_op = NN::Div();
|
||||
if (div_op == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate div_op";
|
||||
return {};
|
||||
}
|
||||
PushOp(div_op);
|
||||
auto neg_op = NN::Neg();
|
||||
if (neg_op == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate neg_op";
|
||||
return {};
|
||||
}
|
||||
PushOp(neg_op);
|
||||
auto mul_op = NN::Mul();
|
||||
if (mul_op == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate mul_op";
|
||||
return {};
|
||||
}
|
||||
PushOp(mul_op);
|
||||
auto bc_x = (*div_op)({yt, y}).front();
|
||||
auto bc_y = (*neg_op)((*mul_op)({bc_x, o})).front();
|
||||
return binop_grad_common(x, y, bc_x, bc_y);
|
||||
}
|
||||
int DivM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
|
||||
auto prim = new (std::nothrow) schema::RealDivT;
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate primitive";
|
||||
return RET_ERROR;
|
||||
}
|
||||
cnode->primitive->value.value = prim;
|
||||
return RET_OK;
|
||||
}
|
||||
static ImportReg DivReg(schema::PrimitiveType_DivFusion, ReturnNode<DivM>);
|
||||
|
||||
// Mul op
|
||||
MulM::MulM(int dummy) : ArithmeticM(schema::PrimitiveType_MulFusion) { set_name(UniqueName("Mul")); }
|
||||
|
||||
std::vector<EXPR *> MulM::Grad(EXPR *yt) {
|
||||
auto mul_dx = NN::Mul();
|
||||
if (mul_dx == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate mul dx";
|
||||
return {};
|
||||
}
|
||||
PushOp(mul_dx);
|
||||
auto mul_dy = NN::Mul();
|
||||
if (mul_dy == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate mul_dy";
|
||||
return {};
|
||||
}
|
||||
PushOp(mul_dy);
|
||||
auto x = input(0);
|
||||
auto y = input(1);
|
||||
auto bc_dx = (*mul_dx)({y, yt}).front();
|
||||
auto bc_dy = (*mul_dy)({x, yt}).front();
|
||||
return binop_grad_common(x, y, bc_dx, bc_dy);
|
||||
}
|
||||
|
||||
int MulM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
|
||||
auto prim = new (std::nothrow) schema::MulFusionT;
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate prim";
|
||||
return RET_ERROR;
|
||||
}
|
||||
prim->activation_type = schema::ActivationType_NO_ACTIVATION;
|
||||
cnode->primitive->value.value = prim;
|
||||
return RET_OK;
|
||||
}
|
||||
static ImportReg MulReg(schema::PrimitiveType_MulFusion, ReturnNode<MulM>);
|
||||
|
||||
// Sub op
|
||||
SubM::SubM(int dummy) : ArithmeticM(schema::PrimitiveType_SubFusion) { set_name(UniqueName("Sub")); }
|
||||
|
||||
std::vector<EXPR *> SubM::Grad(EXPR *yt) {
|
||||
auto x = input(0);
|
||||
auto y = input(1);
|
||||
auto neg = NN::Neg();
|
||||
if (neg == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate neg";
|
||||
return {};
|
||||
}
|
||||
PushOp(neg);
|
||||
auto neg_grad = (*neg)({yt}).front();
|
||||
return binop_grad_common(x, y, yt, neg_grad);
|
||||
}
|
||||
int SubM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
|
||||
auto prim = new (std::nothrow) schema::SubFusionT;
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate prim";
|
||||
return RET_ERROR;
|
||||
}
|
||||
prim->activation_type = schema::ActivationType_NO_ACTIVATION;
|
||||
cnode->primitive->value.value = prim;
|
||||
return RET_OK;
|
||||
}
|
||||
static ImportReg SubReg(schema::PrimitiveType_SubFusion, ReturnNode<SubM>);
|
||||
|
||||
namespace NN {
|
||||
Node *Add() {
|
||||
auto a = new (std::nothrow) AddM(0);
|
||||
if (a == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate a";
|
||||
return nullptr;
|
||||
}
|
||||
return a;
|
||||
}
|
||||
Node *Sub() {
|
||||
auto a = new (std::nothrow) SubM(0);
|
||||
if (a == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate a";
|
||||
return nullptr;
|
||||
}
|
||||
return a;
|
||||
}
|
||||
|
||||
Node *Mul() {
|
||||
auto a = new (std::nothrow) MulM(0);
|
||||
if (a == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate a";
|
||||
return nullptr;
|
||||
}
|
||||
return a;
|
||||
}
|
||||
Node *Div() {
|
||||
auto a = new (std::nothrow) DivM(0);
|
||||
if (a == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate a";
|
||||
return nullptr;
|
||||
}
|
||||
return a;
|
||||
}
|
||||
} // namespace NN
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -1,75 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_ARITHMETIC_H_
|
||||
#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_ARITHMETIC_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "src/expression/net.h"
|
||||
#include "inner/model_generated.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class ArithmeticM : public Node {
|
||||
public:
|
||||
ArithmeticM() = default;
|
||||
explicit ArithmeticM(schema::PrimitiveType type);
|
||||
|
||||
protected:
|
||||
std::vector<EXPR *> binop_grad_common(EXPR *x, EXPR *y, EXPR *dx, EXPR *dy);
|
||||
};
|
||||
|
||||
class AddM : public ArithmeticM {
|
||||
public:
|
||||
AddM() = default;
|
||||
explicit AddM(int dummy);
|
||||
std::vector<EXPR *> Grad(EXPR *expr) override;
|
||||
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
|
||||
};
|
||||
|
||||
class DivM : public ArithmeticM {
|
||||
public:
|
||||
DivM() = default;
|
||||
explicit DivM(int dummy);
|
||||
std::vector<EXPR *> Grad(EXPR *expr) override;
|
||||
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
|
||||
};
|
||||
|
||||
class MulM : public ArithmeticM {
|
||||
public:
|
||||
MulM() = default;
|
||||
explicit MulM(int dummy);
|
||||
std::vector<EXPR *> Grad(EXPR *expr) override;
|
||||
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
|
||||
};
|
||||
|
||||
class SubM : public ArithmeticM {
|
||||
public:
|
||||
SubM() = default;
|
||||
explicit SubM(int dummy);
|
||||
std::vector<EXPR *> Grad(EXPR *expr) override;
|
||||
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
|
||||
};
|
||||
namespace NN {
|
||||
Node *Add();
|
||||
Node *Sub();
|
||||
Node *Mul();
|
||||
Node *Div();
|
||||
} // namespace NN
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_ARITHMETIC_H_
|
|
@ -1,72 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/expression/ops/arithmetic_self.h"
|
||||
#include <memory>
|
||||
#include "src/expression/ops_utils.h"
|
||||
#include "src/expression/ops.h"
|
||||
#include "nnacl/arithmetic_self_parameter.h"
|
||||
#include "src/expression/import.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
// Common Arithmetic Self Functionality
|
||||
ArithmeticSelfM::ArithmeticSelfM(schema::PrimitiveType type) : Node() {
|
||||
auto op_param = malloc(sizeof(ArithmeticSelfParameter));
|
||||
if (op_param == nullptr) {
|
||||
MS_LOG(ERROR) << " cannot allocate ArithmeticSelfParameter";
|
||||
return;
|
||||
}
|
||||
memset(op_param, 0, sizeof(ArithmeticSelfParameter));
|
||||
SetOpParam(op_param);
|
||||
expr()->SetSize(C1NUM);
|
||||
set_primitive(type);
|
||||
}
|
||||
|
||||
// NEG OP
|
||||
NegM::NegM(int dummy) : ArithmeticSelfM(schema::PrimitiveType_NegGrad) { set_name(UniqueName("Neg")); }
|
||||
|
||||
std::vector<EXPR *> NegM::Grad(EXPR *yt) {
|
||||
auto grad_neg = new (std::nothrow) NegM(0);
|
||||
if (grad_neg == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate neg gradient";
|
||||
return {};
|
||||
}
|
||||
return (*grad_neg)({yt});
|
||||
}
|
||||
|
||||
int NegM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
|
||||
auto prim = new (std::nothrow) schema::NegT;
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate primitive";
|
||||
return RET_ERROR;
|
||||
}
|
||||
cnode->primitive->value.value = prim;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
namespace NN {
|
||||
Node *Neg() {
|
||||
auto a = new (std::nothrow) NegM(0);
|
||||
if (a == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate neg node";
|
||||
return nullptr;
|
||||
}
|
||||
return a;
|
||||
}
|
||||
} // namespace NN
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -1,46 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_ARITHMETIC_SELF_H_
|
||||
#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_ARITHMETIC_SELF_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "src/expression/net.h"
|
||||
#include "inner/model_generated.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class ArithmeticSelfM : public Node {
|
||||
public:
|
||||
explicit ArithmeticSelfM(schema::PrimitiveType type);
|
||||
|
||||
protected:
|
||||
};
|
||||
|
||||
class NegM : public ArithmeticSelfM {
|
||||
public:
|
||||
explicit NegM(int dummy);
|
||||
std::vector<EXPR *> Grad(EXPR *expr) override;
|
||||
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
|
||||
};
|
||||
|
||||
namespace NN {
|
||||
Node *Neg();
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_ARITHMETIC_SELF_H_
|
|
@ -1,60 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/expression/ops/assign.h"
|
||||
#include <memory>
|
||||
#include "nnacl/reshape_parameter.h"
|
||||
#include "src/expression/import.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
AssignM::AssignM(int dummy) {
|
||||
auto op_param = calloc(1, sizeof(OpParameter));
|
||||
if (op_param == nullptr) {
|
||||
MS_LOG(ERROR) << " cannot allocate ReshapeParameter";
|
||||
return;
|
||||
}
|
||||
expr()->SetSize(C2NUM);
|
||||
SetOpParam(op_param);
|
||||
set_primitive(schema::PrimitiveType_Assign);
|
||||
set_name(UniqueName("Assign"));
|
||||
}
|
||||
|
||||
int AssignM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
|
||||
auto prim = new (std::nothrow) schema::AssignT;
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate primitive";
|
||||
return RET_ERROR;
|
||||
}
|
||||
cnode->primitive->value.value = prim;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
static ImportReg reg(schema::PrimitiveType_Reshape, ReturnNode<AssignM>);
|
||||
|
||||
namespace NN {
|
||||
Node *Assign() {
|
||||
auto node = new (std::nothrow) AssignM(0);
|
||||
if (node == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate node";
|
||||
return nullptr;
|
||||
}
|
||||
node->set_name(Node::UniqueName("Assign"));
|
||||
return node;
|
||||
}
|
||||
} // namespace NN
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -1,35 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_ASSIGN_H_
|
||||
#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_ASSIGN_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "src/expression/net.h"
|
||||
#include "inner/model_generated.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class AssignM : public Node {
|
||||
public:
|
||||
AssignM() = default;
|
||||
explicit AssignM(int dummy);
|
||||
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_ASSIGN_H_
|
|
@ -1,135 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/expression/ops/batchnorm.h"
|
||||
#include <memory>
|
||||
#include "nnacl/batchnorm_parameter.h"
|
||||
#include "nnacl/fp32_grad/batch_norm_grad.h"
|
||||
#include "src/expression/import.h"
|
||||
#include "src/expression/ops.h"
|
||||
#include "src/litert/cxx_api/expression/node_impl.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
BatchNorm2dM::BatchNorm2dM(int outp, float momentum, float epsilon) {
|
||||
constexpr int bn_inputs = 5;
|
||||
constexpr int bn_outputs = 5;
|
||||
|
||||
auto op_param = calloc(1, sizeof(BatchNormParameter));
|
||||
if (op_param == nullptr) {
|
||||
MS_LOG(ERROR) << " cannot allocate BatchNormParameter";
|
||||
return;
|
||||
}
|
||||
expr()->SetSize(bn_inputs);
|
||||
set_name(UniqueName("BatchNorm2D"));
|
||||
auto bn_param = reinterpret_cast<BatchNormParameter *>(op_param);
|
||||
channel_ = outp;
|
||||
bn_param->momentum_ = momentum;
|
||||
bn_param->epsilon_ = epsilon;
|
||||
SetOpParam(op_param);
|
||||
set_primitive(schema::PrimitiveType_FusedBatchNorm);
|
||||
std::vector<int> dims = {outp};
|
||||
auto scale = Node::CreateWeights(dims, kNumberTypeFloat32, KHWC, Param::Mode::ONES, "scale");
|
||||
expr()->set_params(C1NUM, scale);
|
||||
auto offset = Node::CreateWeights(dims, kNumberTypeFloat32, KHWC, Param::Mode::ZEROS, "offset");
|
||||
expr()->set_params(C2NUM, offset);
|
||||
auto mean = Node::CreateWeights(dims, kNumberTypeFloat32, KHWC, Param::Mode::ZEROS, "mean");
|
||||
expr()->set_params(C3NUM, mean);
|
||||
auto var = Node::CreateWeights(dims, kNumberTypeFloat32, KHWC, Param::Mode::ONES, "var");
|
||||
expr()->set_params(C4NUM, var);
|
||||
SetOutputs(bn_outputs);
|
||||
SetLearn();
|
||||
}
|
||||
|
||||
BatchNorm2dGradM::BatchNorm2dGradM(BatchNorm2dM *bn_node) : Node() {
|
||||
auto op_param = calloc(1, sizeof(BNGradParameter));
|
||||
if (op_param == nullptr) {
|
||||
MS_LOG(ERROR) << " cannot allocate BNGradParameter";
|
||||
return;
|
||||
}
|
||||
expr()->SetSize(C6NUM);
|
||||
set_name(bn_node->name() + "/" + kGradName + "/bnGrad");
|
||||
auto bn_grad_param = reinterpret_cast<BNGradParameter *>(op_param);
|
||||
auto bn_param = reinterpret_cast<BatchNormParameter *>(bn_node->OpParam());
|
||||
bn_param->is_training_ = true;
|
||||
bn_grad_param->epsilon_ = bn_param->epsilon_;
|
||||
bn_grad_param->is_training_ = true;
|
||||
SetOpParam(op_param);
|
||||
set_primitive(schema::PrimitiveType_BatchNormGrad);
|
||||
EXPR e(this);
|
||||
e.SetSize(0);
|
||||
// Dgamma
|
||||
expr_.emplace_back(e);
|
||||
// Doffset
|
||||
expr_.emplace_back(e);
|
||||
}
|
||||
|
||||
std::vector<EXPR *> BatchNorm2dM::Grad(EXPR *yt) {
|
||||
auto bn_grad_node = new (std::nothrow) BatchNorm2dGradM(this);
|
||||
if (bn_grad_node == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate batchnorm grad";
|
||||
return {};
|
||||
}
|
||||
PushOp(bn_grad_node);
|
||||
auto bn_grad = (*bn_grad_node)({yt, input(0), output(1), output(3), output(4), output(2)});
|
||||
return bn_grad;
|
||||
}
|
||||
|
||||
int BatchNorm2dM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
|
||||
auto bn_param = reinterpret_cast<const BatchNormParameter *>(OpParam());
|
||||
auto prim = new (std::nothrow) schema::FusedBatchNormT;
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate primitive";
|
||||
return RET_ERROR;
|
||||
}
|
||||
prim->epsilon = bn_param->epsilon_;
|
||||
prim->momentum = bn_param->momentum_;
|
||||
prim->mode = (bn_param->is_training_ == false) ? 0 : 1;
|
||||
cnode->primitive->value.value = prim;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
void BatchNorm2dM::SetLearn() {
|
||||
AddLearn(input(C1NUM)->node());
|
||||
AddLearn(input(C2NUM)->node());
|
||||
}
|
||||
|
||||
int BatchNorm2dGradM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
|
||||
auto param = reinterpret_cast<const BNGradParameter *>(OpParam());
|
||||
auto prim = new (std::nothrow) schema::BatchNormGradT;
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate primitive";
|
||||
return RET_ERROR;
|
||||
}
|
||||
prim->epsilon = param->epsilon_;
|
||||
prim->is_training = param->is_training_;
|
||||
cnode->primitive->value.value = prim;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
static ImportReg reg(schema::PrimitiveType_FusedBatchNorm, ReturnNode<BatchNorm2dM>);
|
||||
namespace NN {
|
||||
Node *BatchNorm2D(int outp, float momentum, float epsilon) {
|
||||
auto node = new (std::nothrow) BatchNorm2dM(outp, momentum, epsilon);
|
||||
if (node == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate node";
|
||||
return nullptr;
|
||||
}
|
||||
return node;
|
||||
}
|
||||
} // namespace NN
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -1,44 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_BATCHNORM_H_
|
||||
#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_BATCHNORM_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "src/expression/net.h"
|
||||
#include "inner/model_generated.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class BatchNorm2dM : public Node {
|
||||
public:
|
||||
BatchNorm2dM() = default;
|
||||
BatchNorm2dM(int outp, float momentum, float epsilon);
|
||||
std::vector<EXPR *> Grad(EXPR *expr) override;
|
||||
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
|
||||
void SetLearn() override;
|
||||
int channel_;
|
||||
};
|
||||
|
||||
class BatchNorm2dGradM : public Node {
|
||||
public:
|
||||
explicit BatchNorm2dGradM(BatchNorm2dM *bn_node);
|
||||
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_BATCHNORM_H_
|
|
@ -1,93 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/expression/ops/biasadd.h"
|
||||
#include "src/expression/ops/transpose.h"
|
||||
#include "nnacl/arithmetic_parameter.h"
|
||||
#include "src/expression/import.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
BiasAddM::BiasAddM(Format data_format) {
|
||||
auto op_param = calloc(1, sizeof(ArithmeticParameter));
|
||||
if (op_param == nullptr) {
|
||||
MS_LOG(ERROR) << " cannot allocate ConvParameter";
|
||||
return;
|
||||
}
|
||||
auto bias_param = reinterpret_cast<ArithmeticParameter *>(op_param);
|
||||
SetOpParam(bias_param);
|
||||
set_primitive(schema::PrimitiveType_BiasAdd);
|
||||
}
|
||||
|
||||
std::vector<EXPR *> BiasAddM::construct(const std::vector<EXPR *> &inputs) {
|
||||
auto x = Node::construct(inputs);
|
||||
AddLearn(inputs.at(C1NUM)->node());
|
||||
return x;
|
||||
}
|
||||
|
||||
void BiasAddM::SetLearn() { AddLearn(input(C1NUM)->node()); }
|
||||
|
||||
std::vector<EXPR *> BiasAddM::Grad(EXPR *yt) {
|
||||
auto in = yt;
|
||||
if (yt->format() != NHWC && yt->dims().size() == C4NUM) {
|
||||
in = TransposeM::TransposeCHW2HWC(yt);
|
||||
in->node()->set_name(kGradName + "/" + name() + "/" + in->node()->name());
|
||||
PushOp(in->node());
|
||||
}
|
||||
auto grad_node = new (std::nothrow) BiasAddGradM(*this);
|
||||
if (grad_node == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannon allocate Bias Grad";
|
||||
return {};
|
||||
}
|
||||
PushOp(grad_node);
|
||||
auto bias_grad = (*grad_node)({in});
|
||||
return {in, bias_grad.front()};
|
||||
}
|
||||
int BiasAddM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
|
||||
auto prim = new (std::nothrow) schema::BiasAddT;
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(ERROR) << "cannot allocate prim";
|
||||
return RET_ERROR;
|
||||
}
|
||||
prim->format = static_cast<schema::Format>(KHWC);
|
||||
cnode->primitive->value.value = prim;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
BiasAddGradM::BiasAddGradM(const BiasAddM &bias) {
|
||||
auto op_param = calloc(1, sizeof(OpParameter));
|
||||
if (op_param == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate op_param";
|
||||
}
|
||||
SetOpParam(op_param);
|
||||
set_primitive(schema::PrimitiveType_BiasAddGrad);
|
||||
}
|
||||
|
||||
int BiasAddGradM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
|
||||
auto prim = new (std::nothrow) schema::BiasAddGradT;
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate primitive";
|
||||
return RET_ERROR;
|
||||
}
|
||||
cnode->primitive->value.value = prim;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
static ImportReg reg(schema::PrimitiveType_BiasAdd, ReturnNode<BiasAddM>);
|
||||
|
||||
namespace NN {}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -1,44 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_BIASADD_H_
|
||||
#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_BIASADD_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "src/expression/net.h"
|
||||
#include "inner/model_generated.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class BiasAddM : public Node {
|
||||
public:
|
||||
BiasAddM() = default;
|
||||
explicit BiasAddM(Format data_format);
|
||||
std::vector<EXPR *> construct(const std::vector<EXPR *> &inputs) override;
|
||||
std::vector<EXPR *> Grad(EXPR *yt) override;
|
||||
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
|
||||
void SetLearn() override;
|
||||
};
|
||||
|
||||
class BiasAddGradM : public Node {
|
||||
public:
|
||||
explicit BiasAddGradM(const BiasAddM &bias);
|
||||
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_BIASADD_H_
|
|
@ -1,241 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/expression/ops/conv.h"
|
||||
#include <memory>
|
||||
#include "src/expression/ops/biasadd.h"
|
||||
#include "src/expression/ops/depend.h"
|
||||
#include "src/expression/ops/transpose.h"
|
||||
#include "nnacl/conv_parameter.h"
|
||||
#include "inner/model_generated.h"
|
||||
#include "src/expression/import.h"
|
||||
#include "src/expression/ops.h"
|
||||
#include "src/litert/cxx_api/expression/node_impl.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
ConvM::ConvM(const ConvConfig &cfg) : Node() {
|
||||
auto op_param = calloc(1, sizeof(ConvParameter));
|
||||
if (op_param == nullptr) {
|
||||
MS_LOG(ERROR) << " cannot allocate ConvParameter";
|
||||
return;
|
||||
}
|
||||
SetOpParam(op_param);
|
||||
ConvParameter *conv_param = reinterpret_cast<ConvParameter *>(OpParam());
|
||||
conv_param->input_channel_ = cfg.in_channel_;
|
||||
conv_param->output_channel_ = cfg.out_channel_;
|
||||
conv_param->kernel_h_ = cfg.kernel_size_[0];
|
||||
conv_param->kernel_w_ = cfg.kernel_size_[1];
|
||||
conv_param->stride_h_ = cfg.stride_[0];
|
||||
conv_param->stride_w_ = cfg.stride_[1];
|
||||
auto pad_mode = GetMode(cfg.pad_mode_);
|
||||
if (pad_mode == -1) {
|
||||
MS_LOG(ERROR) << "bad pad mode";
|
||||
return;
|
||||
}
|
||||
conv_param->pad_mode_ = static_cast<PadType>(pad_mode);
|
||||
conv_param->pad_u_ = cfg.padding_[C0NUM];
|
||||
conv_param->pad_d_ = cfg.padding_[C1NUM];
|
||||
conv_param->pad_l_ = cfg.padding_[C2NUM];
|
||||
conv_param->pad_r_ = cfg.padding_[C3NUM];
|
||||
conv_param->dilation_h_ = cfg.dilation_[C0NUM];
|
||||
conv_param->dilation_w_ = cfg.dilation_[C1NUM];
|
||||
conv_param->group_ = cfg.group_;
|
||||
conv_param->out_format_ = NHWC;
|
||||
conv_param->act_type_ = ActType_No;
|
||||
expr()->SetSize(C2NUM);
|
||||
set_primitive(schema::PrimitiveType_Conv2DFusion);
|
||||
set_name(UniqueName("Conv"));
|
||||
Param::Mode mode = Param::String2Enum(cfg.weight_init_);
|
||||
std::vector<int> dims = {conv_param->output_channel_, conv_param->kernel_h_, conv_param->kernel_w_,
|
||||
conv_param->input_channel_ / conv_param->group_};
|
||||
auto w = CreateWeights(dims, kNumberTypeFloat32, KHWC, mode, "weights");
|
||||
expr()->set_params(C1NUM, w);
|
||||
if (cfg.has_bias) {
|
||||
bias_ = new (std::nothrow) BiasAddM(KHWC);
|
||||
if (bias_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate bias";
|
||||
return;
|
||||
}
|
||||
bias_->update_name(name());
|
||||
std::vector<int> dim_bias = {conv_param->output_channel_};
|
||||
wbias_ = CreateWeights(dim_bias, kNumberTypeFloat32, KHWC, Param::Mode::ZEROS, "weights");
|
||||
AddLearn(wbias_->node());
|
||||
PushOp(bias_);
|
||||
}
|
||||
SetLearn();
|
||||
}
|
||||
|
||||
std::vector<EXPR *> ConvM::construct(const std::vector<EXPR *> &inputs) {
|
||||
auto in = inputs;
|
||||
auto x = in.front();
|
||||
if (x->format() != NHWC && x->dims().size() == C4NUM) {
|
||||
x = TransposeM::TransposeCHW2HWC(x);
|
||||
x->node()->set_name(name() + "/" + x->node()->name());
|
||||
PushOp(x->node());
|
||||
in.at(0) = x;
|
||||
}
|
||||
auto y = Node::construct(in);
|
||||
if (bias_ != nullptr) {
|
||||
y = (*bias_)({y.front(), wbias_});
|
||||
}
|
||||
return y;
|
||||
}
|
||||
|
||||
void ConvM::SetLearn() { AddLearn(input(C1NUM)->node()); }
|
||||
|
||||
int ConvM::GetMode(std::string mode) {
|
||||
const std::vector<std::string> list = {"pad", "same", "valid"};
|
||||
auto itr = std::find(list.begin(), list.end(), mode);
|
||||
if (itr == list.end()) {
|
||||
MS_LOG(ERROR) << "illegal mode" << mode;
|
||||
return -1;
|
||||
}
|
||||
return std::distance(list.begin(), itr);
|
||||
}
|
||||
|
||||
std::vector<EXPR *> ConvM::Grad(EXPR *yt) {
|
||||
// Generate Input Grad
|
||||
EXPR *in = yt;
|
||||
if (yt->format() != NHWC && yt->dims().size() == C4NUM) {
|
||||
in = TransposeM::TransposeCHW2HWC(yt);
|
||||
in->node()->set_name(kGradName + "/" + name() + "/" + in->node()->name());
|
||||
PushOp(in->node());
|
||||
}
|
||||
auto inGrad = new (std::nothrow) ConvInputGradM(this);
|
||||
if (inGrad == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate convolution input grad";
|
||||
return {};
|
||||
}
|
||||
PushOp(inGrad);
|
||||
auto ig = (*inGrad)({in, input(1), inGrad->input(2)});
|
||||
// Execution Control Flow !
|
||||
auto depend = NN::Depend();
|
||||
if (depend == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate depend";
|
||||
return {};
|
||||
}
|
||||
PushOp(depend);
|
||||
depend->update_name(name());
|
||||
auto de = (*depend)({inGrad->expr()});
|
||||
// Generate Filter Grad
|
||||
auto filterGrad = new (std::nothrow) ConvFilterGradM(this);
|
||||
if (filterGrad == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate convolution filter grad";
|
||||
return {};
|
||||
}
|
||||
PushOp(filterGrad);
|
||||
filterGrad->update_name(name());
|
||||
auto fg = (*filterGrad)({in, input(0), filterGrad->input(2), de[0]});
|
||||
std::vector<EXPR *> res = {ig[0], fg[0]};
|
||||
return res;
|
||||
}
|
||||
|
||||
int ConvM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
|
||||
auto conv_param = reinterpret_cast<const ConvParameter *>(OpParam());
|
||||
auto prim = new (std::nothrow) schema::Conv2DFusionT;
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate primitive";
|
||||
return RET_ERROR;
|
||||
}
|
||||
prim->activation_type = static_cast<schema::ActivationType>(conv_param->act_type_);
|
||||
prim->format = static_cast<schema::Format>(conv_param->out_format_);
|
||||
prim->stride = {conv_param->stride_h_, conv_param->stride_w_};
|
||||
prim->kernel_size = {conv_param->kernel_h_, conv_param->kernel_w_};
|
||||
prim->dilation = {conv_param->dilation_h_, conv_param->dilation_w_};
|
||||
prim->out_channel = conv_param->output_channel_;
|
||||
prim->in_channel = conv_param->input_channel_;
|
||||
prim->group = conv_param->group_;
|
||||
prim->pad_mode = static_cast<schema::PadMode>(conv_param->pad_mode_);
|
||||
prim->pad_list = {conv_param->pad_u_, conv_param->pad_d_, conv_param->pad_l_, conv_param->pad_r_};
|
||||
prim->mode = 1;
|
||||
cnode->primitive->value.value = prim;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
ConvInputGradM::ConvInputGradM(ConvM *conv_node) : Node() {
|
||||
CloneOpParam<ConvParameter>(conv_node->OpParam());
|
||||
set_primitive(schema::PrimitiveType_Conv2DBackpropInputFusion);
|
||||
set_name(kGradName + "/conv2DBackpropInput");
|
||||
expr()->SetSize(C3NUM);
|
||||
auto const x = conv_node->input(0);
|
||||
CreateConstTensor(C2NUM, {static_cast<int32_t>(x->dims().size())}, kNumberTypeInt32, KHWC, "shape", x->dims().data());
|
||||
}
|
||||
|
||||
int ConvInputGradM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
|
||||
auto conv_param = reinterpret_cast<const ConvParameter *>(OpParam());
|
||||
auto prim = new (std::nothrow) schema::Conv2DBackpropInputFusionT;
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate primitive";
|
||||
return RET_ERROR;
|
||||
}
|
||||
prim->activation_type = static_cast<schema::ActivationType>(conv_param->act_type_);
|
||||
prim->format = static_cast<schema::Format>(conv_param->out_format_);
|
||||
prim->stride = {conv_param->stride_h_, conv_param->stride_w_};
|
||||
prim->kernel_size = {conv_param->kernel_h_, conv_param->kernel_w_};
|
||||
prim->dilation = {conv_param->dilation_h_, conv_param->dilation_w_};
|
||||
prim->out_channel = conv_param->output_channel_;
|
||||
prim->in_channel = conv_param->input_channel_;
|
||||
prim->group = conv_param->group_;
|
||||
prim->pad_mode = static_cast<schema::PadMode>(conv_param->pad_mode_);
|
||||
prim->pad_list = {conv_param->pad_u_, conv_param->pad_d_, conv_param->pad_l_, conv_param->pad_r_};
|
||||
cnode->primitive->value.value = prim;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
ConvFilterGradM::ConvFilterGradM(ConvM *conv_node) : Node() {
|
||||
CloneOpParam<ConvParameter>(conv_node->OpParam());
|
||||
set_primitive(schema::PrimitiveType_Conv2DBackpropFilterFusion);
|
||||
set_name(kGradName + "/conv2DBackpropFilter");
|
||||
expr()->SetSize(C4NUM);
|
||||
auto w = conv_node->input(1);
|
||||
CreateConstTensor(C2NUM, {static_cast<int32_t>(w->dims().size())}, kNumberTypeInt32, KHWC, "shape", w->dims().data());
|
||||
}
|
||||
int ConvFilterGradM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
|
||||
auto conv_param = reinterpret_cast<const ConvParameter *>(OpParam());
|
||||
auto prim = new (std::nothrow) schema::Conv2DBackpropFilterFusionT;
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate primitive";
|
||||
return RET_ERROR;
|
||||
}
|
||||
prim->activation_type = static_cast<schema::ActivationType>(conv_param->act_type_);
|
||||
prim->format = static_cast<schema::Format>(conv_param->out_format_);
|
||||
prim->stride = {conv_param->stride_h_, conv_param->stride_w_};
|
||||
prim->kernel_size = {conv_param->kernel_h_, conv_param->kernel_w_};
|
||||
prim->dilation = {conv_param->dilation_h_, conv_param->dilation_w_};
|
||||
prim->out_channel = conv_param->output_channel_;
|
||||
prim->in_channel = conv_param->input_channel_;
|
||||
prim->group = conv_param->group_;
|
||||
prim->pad_mode = static_cast<schema::PadMode>(conv_param->pad_mode_);
|
||||
prim->pad_list = {conv_param->pad_u_, conv_param->pad_d_, conv_param->pad_l_, conv_param->pad_r_};
|
||||
cnode->primitive->value.value = prim;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
static ImportReg reg(schema::PrimitiveType_Conv2DFusion, ReturnNode<ConvM>);
|
||||
|
||||
namespace NN {
|
||||
Node *Conv2D(const ConvConfig &cfg) {
|
||||
auto c = new (std::nothrow) ConvM(cfg);
|
||||
if (c == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate Convolution object";
|
||||
return nullptr;
|
||||
}
|
||||
return c;
|
||||
}
|
||||
} // namespace NN
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -1,58 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_CONV_H_
|
||||
#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_CONV_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "src/expression/cfg.h"
|
||||
#include "src/expression/node.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class ConvM : public Node {
|
||||
public:
|
||||
ConvM() = default;
|
||||
explicit ConvM(const ConvConfig &cfg);
|
||||
std::vector<EXPR *> construct(const std::vector<EXPR *> &inputs) override;
|
||||
Param *weight() override { return input(1)->node()->data(); }
|
||||
std::vector<EXPR *> Grad(EXPR *expr) override;
|
||||
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
|
||||
void SetLearn() override;
|
||||
|
||||
private:
|
||||
int GetMode(std::string mode);
|
||||
Node *bias_ = nullptr;
|
||||
EXPR *wbias_ = nullptr;
|
||||
};
|
||||
|
||||
class ConvInputGradM : public Node {
|
||||
public:
|
||||
explicit ConvInputGradM(ConvM *conv_node);
|
||||
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
|
||||
};
|
||||
|
||||
class ConvFilterGradM : public Node {
|
||||
public:
|
||||
explicit ConvFilterGradM(ConvM *conv_node);
|
||||
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_CONV_H_
|
|
@ -1,151 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/expression/ops/dense.h"
|
||||
#include <memory>
|
||||
#include "include/api/cfg.h"
|
||||
#include "src/expression/ops/biasadd.h"
|
||||
#include "src/expression/ops/depend.h"
|
||||
#include "src/expression/ops.h"
|
||||
#include "nnacl/matmul_parameter.h"
|
||||
#include "src/expression/import.h"
|
||||
#include "inner/model_generated.h"
|
||||
#include "src/litert/cxx_api/expression/node_impl.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
DenseM::DenseM(const DenseConfig &cfg) : Node() {
|
||||
auto op_param = calloc(1, sizeof(MatMulParameter));
|
||||
if (op_param == nullptr) {
|
||||
MS_LOG(ERROR) << " cannot allocate MatMulParameter";
|
||||
return;
|
||||
}
|
||||
set_name(UniqueName("Dense"));
|
||||
SetOpParam(op_param);
|
||||
expr()->SetSize(C2NUM);
|
||||
set_primitive(schema::PrimitiveType_MatMulFusion);
|
||||
auto param = reinterpret_cast<MatMulParameter *>(opParam_.get());
|
||||
param->row_ = cfg.out_channels_;
|
||||
param->col_ = cfg.in_channels_;
|
||||
param->a_transpose_ = false;
|
||||
param->b_transpose_ = true;
|
||||
std::vector<int> dims = {param->row_, param->col_};
|
||||
auto w = Node::CreateWeights(dims, kNumberTypeFloat32, KHWC, Param::Mode::NORMAL, "weights");
|
||||
expr()->set_params(C1NUM, w);
|
||||
if (cfg.has_bias_) {
|
||||
wbias_ = CreateWeights({cfg.out_channels_}, kNumberTypeFloat32, KHWC, Param::Mode::ZEROS, "bias_weights");
|
||||
bias_ = new (std::nothrow) BiasAddM(KHWC);
|
||||
if (bias_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate bias";
|
||||
return;
|
||||
}
|
||||
bias_->update_name(name());
|
||||
AddLearn(wbias_->node());
|
||||
PushOp(bias_);
|
||||
}
|
||||
SetLearn();
|
||||
}
|
||||
|
||||
std::vector<EXPR *> DenseM::construct(const std::vector<EXPR *> &inputs) {
|
||||
auto x = Node::construct(inputs);
|
||||
if (bias_ != nullptr) {
|
||||
x = (*bias_)({x.front(), wbias_});
|
||||
}
|
||||
return x;
|
||||
}
|
||||
|
||||
std::vector<EXPR *> DenseM::Grad(EXPR *yt) {
|
||||
auto src_param = reinterpret_cast<MatMulParameter *>(opParam_.get());
|
||||
bool ta = src_param->a_transpose_;
|
||||
bool tb = src_param->b_transpose_;
|
||||
|
||||
// dx grad op
|
||||
auto dxGrad = new (std::nothrow) DenseM();
|
||||
if (dxGrad == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate dxGrad ";
|
||||
return {};
|
||||
}
|
||||
PushOp(dxGrad);
|
||||
dxGrad->CloneOpParam<MatMulParameter>(opParam_);
|
||||
dxGrad->set_primitive(schema::PrimitiveType_MatMulFusion);
|
||||
auto dxGradParam = reinterpret_cast<MatMulParameter *>(dxGrad->OpParam());
|
||||
dxGradParam->a_transpose_ = (ta && tb);
|
||||
dxGradParam->b_transpose_ = (ta || !tb);
|
||||
dxGrad->set_name(name() + kGradName + "/dxGrad");
|
||||
EXPR *dx = nullptr;
|
||||
if (ta) {
|
||||
dx = (*dxGrad)({input(1), yt}).front();
|
||||
} else {
|
||||
dx = (*dxGrad)({yt, input(1)}).front();
|
||||
}
|
||||
// Control execution flow
|
||||
auto depend = NN::Depend();
|
||||
if (depend == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate depend ";
|
||||
return {};
|
||||
}
|
||||
PushOp(depend);
|
||||
auto de = (*depend)({dxGrad->expr()}).front();
|
||||
|
||||
// dw grad op
|
||||
auto dwGrad = new (std::nothrow) DenseM();
|
||||
if (dwGrad == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate dwGrad ";
|
||||
return {};
|
||||
}
|
||||
PushOp(dwGrad);
|
||||
dwGrad->CloneOpParam<MatMulParameter>(opParam_);
|
||||
dwGrad->set_primitive(schema::PrimitiveType_MatMulFusion);
|
||||
auto dwGradParam = reinterpret_cast<MatMulParameter *>(dwGrad->OpParam());
|
||||
dwGradParam->a_transpose_ = (!ta || tb);
|
||||
dwGradParam->b_transpose_ = ta && tb;
|
||||
dwGrad->set_name(name() + kGradName + "/dwGrad");
|
||||
EXPR *dw = nullptr;
|
||||
if (tb) {
|
||||
dw = (*dwGrad)({yt, input(0), de}).front();
|
||||
} else {
|
||||
dw = (*dwGrad)({input(0), yt, de}).front();
|
||||
}
|
||||
return {dx, dw};
|
||||
}
|
||||
int DenseM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
|
||||
auto dense_param = reinterpret_cast<const MatMulParameter *>(OpParam());
|
||||
auto prim = new (std::nothrow) schema::MatMulFusionT;
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate primitive";
|
||||
return RET_ERROR;
|
||||
}
|
||||
prim->transpose_a = dense_param->a_transpose_;
|
||||
prim->transpose_b = dense_param->b_transpose_;
|
||||
cnode->primitive->value.value = prim;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
void DenseM::SetLearn() { AddLearn(input(C1NUM)->node()); }
|
||||
|
||||
static ImportReg reg(schema::PrimitiveType_MatMulFusion, ReturnNode<DenseM>);
|
||||
|
||||
namespace NN {
|
||||
Node *Dense(const DenseConfig &cfg) {
|
||||
auto l = new (std::nothrow) DenseM(cfg);
|
||||
if (l == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate Dense object";
|
||||
}
|
||||
return l;
|
||||
}
|
||||
} // namespace NN
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -1,44 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_DENSE_H_
|
||||
#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_DENSE_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "src/expression/node.h"
|
||||
#include "src/expression/cfg.h"
|
||||
#include "inner/model_generated.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class DenseM : public Node {
|
||||
public:
|
||||
DenseM() = default;
|
||||
explicit DenseM(const DenseConfig &cfg);
|
||||
std::vector<EXPR *> construct(const std::vector<EXPR *> &inputs) override;
|
||||
std::vector<EXPR *> Grad(EXPR *expr) override;
|
||||
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
|
||||
void SetLearn() override;
|
||||
|
||||
private:
|
||||
Param *weight() override { return input(1)->node()->data(); }
|
||||
Node *bias_{nullptr};
|
||||
EXPR *wbias_{nullptr};
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_DENSE_H_
|
|
@ -1,43 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/expression/ops/depend.h"
|
||||
#include "inner/model_generated.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
DependM::DependM() : Node() {
|
||||
auto param = calloc(1, sizeof(OpParameter));
|
||||
if (param == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate parameter";
|
||||
return;
|
||||
}
|
||||
SetOpParam(param);
|
||||
set_primitive(schema::PrimitiveType_Depend);
|
||||
set_name(UniqueName("Depend"));
|
||||
}
|
||||
namespace NN {
|
||||
Node *Depend() {
|
||||
auto d = new (std::nothrow) DependM();
|
||||
if (d == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate depend object";
|
||||
return nullptr;
|
||||
}
|
||||
return d;
|
||||
}
|
||||
} // namespace NN
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -1,37 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_DEPEND_H_
|
||||
#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_DEPEND_H_
|
||||
|
||||
#include <vector>
|
||||
#include "src/expression/node.h"
|
||||
#include "src/expression/ops.h"
|
||||
#include "inner/model_generated.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class DependM : public Node {
|
||||
public:
|
||||
DependM();
|
||||
};
|
||||
|
||||
namespace NN {
|
||||
Node *Depend();
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_DEPEND_H_
|
|
@ -1,91 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/expression/ops/dropout.h"
|
||||
#include <vector>
|
||||
#include "nnacl/fp32_grad/dropout_parameter.h"
|
||||
#include "inner/model_generated.h"
|
||||
#include "src/expression/import.h"
|
||||
#include "src/expression/ops.h"
|
||||
#include "src/litert/cxx_api/expression/node_impl.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
DropOutM::DropOutM(float ratio) {
|
||||
auto param = reinterpret_cast<DropoutParameter *>(calloc(1, sizeof(DropoutParameter)));
|
||||
if (param == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate parameter";
|
||||
return;
|
||||
}
|
||||
param->ratio_ = ratio;
|
||||
SetOpParam(param);
|
||||
set_primitive(schema::PrimitiveType_Dropout);
|
||||
set_name(UniqueName("DropOut"));
|
||||
}
|
||||
|
||||
std::vector<EXPR *> DropOutM::Grad(EXPR *yt) {
|
||||
auto inGrad = new (std::nothrow) DropOutGradM(this);
|
||||
if (inGrad == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate drop grad";
|
||||
return {};
|
||||
}
|
||||
return (*inGrad)({yt, expr()});
|
||||
}
|
||||
|
||||
int DropOutM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
|
||||
auto param = reinterpret_cast<const DropoutParameter *>(OpParam());
|
||||
auto prim = new (std::nothrow) schema::DropoutT;
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate primitive";
|
||||
return RET_ERROR;
|
||||
}
|
||||
prim->keep_prob = param->ratio_;
|
||||
cnode->primitive->value.value = prim;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
DropOutGradM::DropOutGradM(DropOutM *node) {
|
||||
CloneOpParam<DropoutParameter>(node->OpParam());
|
||||
set_primitive(schema::PrimitiveType_DropoutGrad);
|
||||
set_name(kGradName + "/DropOutGrad");
|
||||
}
|
||||
|
||||
int DropOutGradM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
|
||||
auto param = reinterpret_cast<const DropoutParameter *>(OpParam());
|
||||
auto prim = new (std::nothrow) schema::DropoutGradT;
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate primitive";
|
||||
return RET_ERROR;
|
||||
}
|
||||
prim->keep_prob = param->ratio_;
|
||||
cnode->primitive->value.value = prim;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
static ImportReg reg(schema::PrimitiveType_Dropout, ReturnNode<DropOutM>);
|
||||
|
||||
namespace NN {
|
||||
Node *DropOut(float ratio) {
|
||||
auto node = new (std::nothrow) DropOutM(ratio);
|
||||
if (node == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate dropout node";
|
||||
return nullptr;
|
||||
}
|
||||
return node;
|
||||
}
|
||||
} // namespace NN
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -1,42 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_DROPOUT_H_
|
||||
#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_DROPOUT_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "src/expression/node.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class DropOutM : public Node {
|
||||
public:
|
||||
DropOutM() = default;
|
||||
explicit DropOutM(float ratio);
|
||||
std::vector<EXPR *> Grad(EXPR *expr) override;
|
||||
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
|
||||
};
|
||||
|
||||
class DropOutGradM : public Node {
|
||||
public:
|
||||
explicit DropOutGradM(DropOutM *node);
|
||||
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_DROPOUT_H_
|
|
@ -1,71 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/expression/ops/flatten.h"
|
||||
#include <vector>
|
||||
#include "inner/model_generated.h"
|
||||
#include "src/expression/import.h"
|
||||
#include "src/expression/ops.h"
|
||||
|
||||
#include "src/litert/cxx_api/expression/node_impl.h"
|
||||
#include "nnacl/op_base.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
FlattenM::FlattenM(int dummy) {
|
||||
auto param = reinterpret_cast<OpParameter *>(calloc(C1NUM, sizeof(OpParameter)));
|
||||
if (param == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate parameter";
|
||||
return;
|
||||
}
|
||||
SetOpParam(param);
|
||||
set_primitive(schema::PrimitiveType_Flatten);
|
||||
set_name(UniqueName("Flatten"));
|
||||
}
|
||||
|
||||
std::vector<EXPR *> FlattenM::construct(const std::vector<EXPR *> &inputs) {
|
||||
auto in = inputs;
|
||||
auto y = Node::construct(in);
|
||||
return y;
|
||||
}
|
||||
|
||||
std::vector<EXPR *> FlattenM::Grad(EXPR *yt) {
|
||||
auto shape_of_x = input(0)->dims();
|
||||
auto reshape = NN::Reshape(shape_of_x);
|
||||
PushOp(reshape);
|
||||
return (*reshape)({yt});
|
||||
}
|
||||
|
||||
int FlattenM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
|
||||
auto prim = new (std::nothrow) schema::DropoutT;
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate primitive";
|
||||
return RET_ERROR;
|
||||
}
|
||||
cnode->primitive->value.value = prim;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
static ImportReg reg(schema::PrimitiveType_Flatten, ReturnNode<FlattenM>);
|
||||
|
||||
namespace NN {
|
||||
Node *Flatten() {
|
||||
auto node = new (std::nothrow) FlattenM(0);
|
||||
return node;
|
||||
}
|
||||
} // namespace NN
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -1,36 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_FLATTEN_H_
|
||||
#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_FLATTEN_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "src/expression/node.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class FlattenM : public Node {
|
||||
public:
|
||||
FlattenM() = default;
|
||||
explicit FlattenM(int dummy);
|
||||
std::vector<EXPR *> Grad(EXPR *expr) override;
|
||||
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
|
||||
std::vector<EXPR *> construct(const std::vector<EXPR *> &inputs) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_FLATTEN_H_
|
|
@ -1,215 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/expression/ops/pooling.h"
|
||||
#include "src/expression/ops.h"
|
||||
#include "src/expression/import.h"
|
||||
#include "src/litert/cxx_api/expression/node_impl.h"
|
||||
#include "src/expression/ops/transpose.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
PoolingM::PoolingM(const PoolingConfig &cfg) : Node() {
|
||||
auto op_param = calloc(1, sizeof(PoolingParameter));
|
||||
if (op_param == nullptr) {
|
||||
MS_LOG(ERROR) << " cannot allocate PoolingParameter";
|
||||
return;
|
||||
}
|
||||
SetOpParam(op_param);
|
||||
PoolingParameter *pool_param = reinterpret_cast<PoolingParameter *>(OpParam());
|
||||
|
||||
pool_param->window_h_ = cfg.kernel_size_[0];
|
||||
pool_param->window_w_ = cfg.kernel_size_[1];
|
||||
pool_param->stride_h_ = cfg.stride_[0];
|
||||
pool_param->stride_w_ = cfg.stride_[1];
|
||||
auto pad_mode = GetMode(cfg.pad_mode_);
|
||||
if (pad_mode == -1) {
|
||||
MS_LOG(ERROR) << "bad pad mode";
|
||||
return;
|
||||
}
|
||||
pool_param->pad_mode_ = static_cast<PadType>(pad_mode + Pad_pad);
|
||||
pool_param->round_type_ = RoundType_Floor;
|
||||
pool_param->act_type_ = ActType_No;
|
||||
}
|
||||
|
||||
std::vector<EXPR *> PoolingM::construct(const std::vector<EXPR *> &inputs) {
|
||||
auto in = inputs;
|
||||
auto x = in.front();
|
||||
if (x->format() != NHWC && x->dims().size() == C4NUM) {
|
||||
x = TransposeM::TransposeCHW2HWC(x);
|
||||
x->node()->set_name(name() + "/" + x->node()->name());
|
||||
PushOp(x->node());
|
||||
in.at(0) = x;
|
||||
}
|
||||
auto y = Node::construct(in);
|
||||
return y;
|
||||
}
|
||||
|
||||
int PoolingM::GetMode(std::string mode) {
|
||||
const std::vector<std::string> list = {"same", "valid"};
|
||||
auto itr = std::find(list.begin(), list.end(), mode);
|
||||
if (itr == list.end()) {
|
||||
MS_LOG(ERROR) << "illegal mode" << mode;
|
||||
return -1;
|
||||
}
|
||||
return std::distance(list.begin(), itr);
|
||||
}
|
||||
|
||||
void PoolingM::UpdateRoundMode(const PoolingParameter *param, schema::RoundMode *round_mode) {
|
||||
switch (param->round_type_) {
|
||||
case RoundType_Floor:
|
||||
*round_mode = schema::RoundMode_FLOOR;
|
||||
break;
|
||||
case RoundType_Ceil:
|
||||
*round_mode = schema::RoundMode_CEIL;
|
||||
break;
|
||||
default:
|
||||
*round_mode = schema::RoundMode_FLOOR;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
int PoolingM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
|
||||
auto param = reinterpret_cast<const PoolingParameter *>(OpParam());
|
||||
auto prim = new (std::nothrow) T;
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate primitive";
|
||||
return RET_ERROR;
|
||||
}
|
||||
prim->kernel_size = {param->window_h_, param->window_w_};
|
||||
prim->strides = {param->stride_h_, param->stride_w_};
|
||||
prim->pad = {param->pad_u_, param->pad_d_, param->pad_l_, param->pad_r_};
|
||||
prim->pad_mode = static_cast<schema::PadMode>(param->pad_mode_);
|
||||
UpdateRoundMode(param, &prim->round_mode);
|
||||
prim->global = param->global_;
|
||||
prim->activation_type = schema::ActivationType_NO_ACTIVATION;
|
||||
cnode->primitive->value.value = prim;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
int PoolingM::UnPopulateGrad(const std::unique_ptr<schema::CNodeT> &cnode) {
|
||||
auto param = reinterpret_cast<const PoolingParameter *>(OpParam());
|
||||
auto prim = new (std::nothrow) T;
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate primitive";
|
||||
return RET_ERROR;
|
||||
}
|
||||
prim->kernel_size = {param->window_h_, param->window_w_};
|
||||
prim->strides = {param->stride_h_, param->stride_w_};
|
||||
prim->pad_mode = static_cast<schema::PadMode>(param->pad_mode_);
|
||||
cnode->primitive->value.value = prim;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
// Max pooling Definition
|
||||
MaxPoolM::MaxPoolM(const PoolingConfig &cfg) : PoolingM(cfg) {
|
||||
auto param = reinterpret_cast<PoolingParameter *>(OpParam());
|
||||
param->pool_mode_ = PoolMode_MaxPool;
|
||||
set_primitive(schema::PrimitiveType_MaxPoolFusion);
|
||||
set_name(UniqueName("MaxPool"));
|
||||
}
|
||||
|
||||
int MaxPoolM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
|
||||
return PoolingM::UnPopulate<schema::MaxPoolFusionT>(cnode);
|
||||
}
|
||||
|
||||
std::vector<EXPR *> MaxPoolM::Grad(EXPR *yt) {
|
||||
auto in = yt;
|
||||
if (yt->format() != NHWC && yt->dims().size() == C4NUM) {
|
||||
in = TransposeM::TransposeCHW2HWC(yt);
|
||||
in->node()->set_name(kGradName + "/" + name() + "/" + in->node()->name());
|
||||
PushOp(in->node());
|
||||
}
|
||||
auto pool_grad = new (std::nothrow) MaxPoolGradM(this);
|
||||
PushOp(pool_grad);
|
||||
return (*pool_grad)({input(0), output(0), in});
|
||||
}
|
||||
|
||||
static ImportReg maxPoolReg(schema::PrimitiveType_MaxPoolFusion, ReturnNode<MaxPoolM>);
|
||||
|
||||
// Avg pooling Definition
|
||||
AvgPoolM::AvgPoolM(const PoolingConfig &cfg) : PoolingM(cfg) {
|
||||
auto param = reinterpret_cast<PoolingParameter *>(OpParam());
|
||||
param->pool_mode_ = PoolMode_AvgPool;
|
||||
set_primitive(schema::PrimitiveType_AvgPoolFusion);
|
||||
set_name(UniqueName("AvgPool"));
|
||||
}
|
||||
|
||||
int AvgPoolM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
|
||||
return PoolingM::UnPopulate<schema::AvgPoolFusionT>(cnode);
|
||||
}
|
||||
|
||||
std::vector<EXPR *> AvgPoolM::Grad(EXPR *yt) {
|
||||
auto in = yt;
|
||||
if (yt->format() != NHWC && yt->dims().size() == C4NUM) {
|
||||
in = TransposeM::TransposeCHW2HWC(yt);
|
||||
in->node()->set_name(kGradName + "/" + name() + "/" + in->node()->name());
|
||||
PushOp(in->node());
|
||||
}
|
||||
auto pool_grad = new (std::nothrow) AvgPoolGradM(this);
|
||||
PushOp(pool_grad);
|
||||
return (*pool_grad)({input(0), output(0), in});
|
||||
}
|
||||
|
||||
static ImportReg avgPoolReg(schema::PrimitiveType_AvgPoolFusion, ReturnNode<AvgPoolM>);
|
||||
|
||||
// Max Pool Grad Definition
|
||||
MaxPoolGradM::MaxPoolGradM(MaxPoolM *node) {
|
||||
Node();
|
||||
CloneOpParam<PoolingParameter>(node->OpParam());
|
||||
set_primitive(schema::PrimitiveType_MaxPoolGrad);
|
||||
set_name(kGradName + "/" + node->name() + "/MaxPoolGrad");
|
||||
}
|
||||
|
||||
int MaxPoolGradM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
|
||||
return PoolingM::UnPopulateGrad<schema::MaxPoolGradT>(cnode);
|
||||
}
|
||||
|
||||
// Avg Pool Grad Definition
|
||||
AvgPoolGradM::AvgPoolGradM(AvgPoolM *node) {
|
||||
Node();
|
||||
CloneOpParam<PoolingParameter>(node->OpParam());
|
||||
set_primitive(schema::PrimitiveType_AvgPoolGrad);
|
||||
set_name(kGradName + "/" + node->name() + "/AvgPoolGrad");
|
||||
}
|
||||
|
||||
int AvgPoolGradM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
|
||||
return PoolingM::UnPopulateGrad<schema::AvgPoolGradT>(cnode);
|
||||
}
|
||||
|
||||
namespace NN {
|
||||
Node *MaxPool2D(const PoolingConfig &cfg) {
|
||||
auto c = new (std::nothrow) MaxPoolM(cfg);
|
||||
if (c == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate max pool object";
|
||||
return nullptr;
|
||||
}
|
||||
return c;
|
||||
}
|
||||
|
||||
Node *AvgPool2D(const PoolingConfig &cfg) {
|
||||
auto c = new (std::nothrow) AvgPoolM(cfg);
|
||||
if (c == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate average pool object";
|
||||
return nullptr;
|
||||
}
|
||||
return c;
|
||||
}
|
||||
} // namespace NN
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -1,74 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_POOLING_H_
|
||||
#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_POOLING_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "src/expression/node.h"
|
||||
#include "inner/model_generated.h"
|
||||
#include "src/expression/cfg.h"
|
||||
#include "nnacl/pooling_parameter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class PoolingM : public Node {
|
||||
public:
|
||||
PoolingM() = default;
|
||||
explicit PoolingM(const PoolingConfig &cfg);
|
||||
template <typename T>
|
||||
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode);
|
||||
template <typename T>
|
||||
int UnPopulateGrad(const std::unique_ptr<schema::CNodeT> &cnode);
|
||||
std::vector<EXPR *> construct(const std::vector<EXPR *> &inputs);
|
||||
|
||||
private:
|
||||
void UpdateRoundMode(const PoolingParameter *param, enum schema::RoundMode *round_mode);
|
||||
int GetMode(std::string mode);
|
||||
};
|
||||
|
||||
class MaxPoolM : public PoolingM {
|
||||
public:
|
||||
MaxPoolM() = default;
|
||||
explicit MaxPoolM(const PoolingConfig &cfg);
|
||||
std::vector<EXPR *> Grad(EXPR *expr) override;
|
||||
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
|
||||
};
|
||||
|
||||
class AvgPoolM : public PoolingM {
|
||||
public:
|
||||
AvgPoolM() = default;
|
||||
explicit AvgPoolM(const PoolingConfig &cfg);
|
||||
std::vector<EXPR *> Grad(EXPR *expr) override;
|
||||
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
|
||||
};
|
||||
|
||||
class MaxPoolGradM : public PoolingM {
|
||||
public:
|
||||
explicit MaxPoolGradM(MaxPoolM *node);
|
||||
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
|
||||
};
|
||||
|
||||
class AvgPoolGradM : public PoolingM {
|
||||
public:
|
||||
explicit AvgPoolGradM(AvgPoolM *node);
|
||||
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_POOLING_H_
|
|
@ -1,126 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/expression/ops/reduce.h"
|
||||
#include <functional>
|
||||
#include "src/expression/ops/tile.h"
|
||||
#include "src/expression/ops/reshape.h"
|
||||
#include "src/expression/ops/arithmetic.h"
|
||||
#include "src/expression/ops.h"
|
||||
#include "src/expression/ops_utils.h"
|
||||
#include "src/expression/import.h"
|
||||
#include "nnacl/reduce_parameter.h"
|
||||
#include "src/litert/cxx_api/expression/node_impl.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
ReduceM::ReduceM(schema::ReduceMode mode, bool keep_dims, const std::vector<int> &axis) : Node() {
|
||||
expr()->SetSize(C2NUM);
|
||||
ReduceParameter *param = reinterpret_cast<ReduceParameter *>(calloc(1, sizeof(ReduceParameter)));
|
||||
if (param == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate parameter";
|
||||
return;
|
||||
}
|
||||
param->mode_ = mode;
|
||||
param->keep_dims_ = keep_dims;
|
||||
param->reduce_to_end_ = false;
|
||||
param->coeff = 1.f;
|
||||
SetOpParam(param);
|
||||
set_name(UniqueName("Reduce"));
|
||||
set_primitive(schema::PrimitiveType_ReduceFusion);
|
||||
Node::CreateConstTensor(C1NUM, {static_cast<int32_t>(axis.size())}, kNumberTypeInt32, KHWC, "axis", axis.data());
|
||||
}
|
||||
|
||||
int ReduceM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
|
||||
auto reduce_param = reinterpret_cast<const ReduceParameter *>(OpParam());
|
||||
auto prim = new (std::nothrow) schema::ReduceFusionT;
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate primitive";
|
||||
return RET_ERROR;
|
||||
}
|
||||
prim->keep_dims = reduce_param->keep_dims_;
|
||||
prim->mode = static_cast<schema::ReduceMode>(reduce_param->mode_);
|
||||
prim->coeff = reduce_param->coeff;
|
||||
prim->reduce_to_end = reduce_param->reduce_to_end_;
|
||||
cnode->primitive->value.value = prim;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
std::vector<EXPR *> ReduceM::Grad(EXPR *yt) {
|
||||
auto shape_of_x = input(0)->dims();
|
||||
std::vector<int> shape_of_axis;
|
||||
|
||||
auto data = input(1)->node()->data()->data().data();
|
||||
int size = input(1)->dims().at(0);
|
||||
auto int_data = reinterpret_cast<int *>(data);
|
||||
for (int i = 0; i < size; i++) {
|
||||
shape_of_axis.push_back(int_data[i]);
|
||||
}
|
||||
|
||||
// assume no dynamic shape
|
||||
ShapeReduce reduce_shape;
|
||||
auto output_shape_kept_dims = ShapeReduce()(shape_of_x, shape_of_axis);
|
||||
auto tile_scaling = VectorDiv()(shape_of_x, output_shape_kept_dims);
|
||||
auto reshape = NN::Reshape(output_shape_kept_dims);
|
||||
PushOp(reshape);
|
||||
reshape->set_name(name() + "/reshape");
|
||||
auto g = (*reshape)({yt}).front();
|
||||
auto tile = NN::Tile(tile_scaling);
|
||||
PushOp(tile);
|
||||
tile->set_name(name() + "/tile");
|
||||
auto sum_grad = (*tile)({g}).front();
|
||||
auto reduce_param = reinterpret_cast<const ReduceParameter *>(OpParam());
|
||||
if (reduce_param->mode_ == schema::ReduceMode_ReduceSum) {
|
||||
return {sum_grad};
|
||||
} else if (reduce_param->mode_ == schema::ReduceMode_ReduceMean) {
|
||||
auto shape_of_y = output(0)->dims();
|
||||
auto shape_x_mul = std::accumulate(shape_of_x.begin(), shape_of_x.end(), 1, std::multiplies<int>());
|
||||
auto shape_y_mul = std::accumulate(shape_of_y.begin(), shape_of_y.end(), 1, std::multiplies<int>());
|
||||
auto div_shape = static_cast<float>(shape_x_mul) / static_cast<float>(shape_y_mul);
|
||||
auto div_op = NN::Div();
|
||||
PushOp(div_op);
|
||||
auto d = div_op->CreateConstTensor(C1NUM, {1}, kNumberTypeFloat32, KHWC, "div_shape", &div_shape);
|
||||
auto dx = (*div_op)({sum_grad, d->expr()});
|
||||
return dx;
|
||||
} else {
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
static ImportReg reg(schema::PrimitiveType_ReduceFusion, ReturnNode<ReduceM>);
|
||||
|
||||
namespace NN {
|
||||
Node *ReduceSum(bool keep_dims, const std::vector<int> &axis) {
|
||||
auto node = new (std::nothrow) ReduceM(schema::ReduceMode_ReduceSum, keep_dims, axis);
|
||||
if (node == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate reduce sum node";
|
||||
return nullptr;
|
||||
}
|
||||
node->set_name(Node::UniqueName("ReduceSum"));
|
||||
return node;
|
||||
}
|
||||
Node *ReduceMean(bool keep_dims, const std::vector<int> &axis) {
|
||||
auto node = new (std::nothrow) ReduceM(schema::ReduceMode_ReduceMean, keep_dims, axis);
|
||||
if (node == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate reduce mean node";
|
||||
return nullptr;
|
||||
}
|
||||
node->set_name(Node::UniqueName("ReduceMean"));
|
||||
return node;
|
||||
}
|
||||
} // namespace NN
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -1,42 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_REDUCE_H_
|
||||
#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_REDUCE_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "src/expression/node.h"
|
||||
#include "inner/model_generated.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class ReduceM : public Node {
|
||||
public:
|
||||
ReduceM() = default;
|
||||
ReduceM(schema::ReduceMode mode, bool keep_dims, const std::vector<int> &axis);
|
||||
Param *weight() override { return input(1)->node()->data(); }
|
||||
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
|
||||
std::vector<EXPR *> Grad(EXPR *expr) override;
|
||||
};
|
||||
|
||||
namespace NN {
|
||||
Node *ReduceMean(bool keep_dims, const std::vector<int> &axis);
|
||||
Node *ReduceSum(bool keep_dims, const std::vector<int> &axis);
|
||||
} // namespace NN
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_REDUCE_H_
|
|
@ -1,74 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/expression/ops/reshape.h"
|
||||
#include "src/expression/ops.h"
|
||||
#include "nnacl/reshape_parameter.h"
|
||||
#include "src/expression/import.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
ReshapeM::ReshapeM(const std::vector<int> &shape) : Node() {
|
||||
auto op_param = calloc(1, sizeof(ReshapeParameter));
|
||||
if (op_param == nullptr) {
|
||||
MS_LOG(ERROR) << " cannot allocate ReshapeParameter";
|
||||
return;
|
||||
}
|
||||
set_name(UniqueName("Reshape"));
|
||||
expr()->SetSize(C2NUM);
|
||||
SetOpParam(op_param);
|
||||
set_primitive(schema::PrimitiveType_Reshape);
|
||||
|
||||
ReshapeParameter *reshape_param = reinterpret_cast<ReshapeParameter *>(opParam_.get());
|
||||
reshape_param->shape_dim_ = shape.size();
|
||||
for (int i = 0; i < reshape_param->shape_dim_; i++) {
|
||||
reshape_param->shape_[i] = shape.at(i);
|
||||
}
|
||||
Node::CreateConstTensor(C1NUM, {static_cast<int32_t>(shape.size())}, kNumberTypeInt32, KHWC, "shape", shape.data());
|
||||
}
|
||||
|
||||
int ReshapeM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
|
||||
auto prim = new (std::nothrow) schema::ReshapeT;
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate primitive";
|
||||
return RET_ERROR;
|
||||
}
|
||||
cnode->primitive->value.value = prim;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
std::vector<EXPR *> ReshapeM::Grad(EXPR *yt) {
|
||||
auto shape_of_x = input(0)->dims();
|
||||
auto reshape = NN::Reshape(shape_of_x);
|
||||
PushOp(reshape);
|
||||
return (*reshape)({yt});
|
||||
}
|
||||
|
||||
static ImportReg reg(schema::PrimitiveType_Reshape, ReturnNode<ReshapeM>);
|
||||
|
||||
namespace NN {
|
||||
Node *Reshape(const std::vector<int> &shape) {
|
||||
auto node = new (std::nothrow) ReshapeM(shape);
|
||||
if (node == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate reshape node";
|
||||
return nullptr;
|
||||
}
|
||||
node->set_name(Node::UniqueName("Reshape"));
|
||||
return node;
|
||||
}
|
||||
} // namespace NN
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -1,37 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_RESHAPE_H_
|
||||
#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_RESHAPE_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include "src/expression/node.h"
|
||||
#include "inner/model_generated.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class ReshapeM : public Node {
|
||||
public:
|
||||
ReshapeM() = default;
|
||||
explicit ReshapeM(const std::vector<int> &shape);
|
||||
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
|
||||
std::vector<EXPR *> Grad(EXPR *expr) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_RESHAPE_H_
|
|
@ -1,119 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/expression/ops/softmax.h"
|
||||
#include "nnacl/softmax_parameter.h"
|
||||
#include "inner/model_generated.h"
|
||||
#include "src/expression/import.h"
|
||||
#include "src/expression/ops/reshape.h"
|
||||
#include "src/expression/ops/reduce.h"
|
||||
#include "src/expression/ops/arithmetic.h"
|
||||
#include "src/expression/ops/transpose.h"
|
||||
#include "src/litert/cxx_api/expression/node_impl.h"
|
||||
#include "src/expression/ops.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
SoftmaxM::SoftmaxM(int axis) {
|
||||
auto param = reinterpret_cast<SoftmaxParameter *>(calloc(1, sizeof(SoftmaxParameter)));
|
||||
if (param == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate parameter";
|
||||
return;
|
||||
}
|
||||
param->axis_ = axis;
|
||||
SetOpParam(param);
|
||||
set_primitive(schema::PrimitiveType_Softmax);
|
||||
set_name(UniqueName("Softmax"));
|
||||
}
|
||||
|
||||
std::vector<int> SoftmaxM::getTransposeAxis(const std::vector<int> &shape, int axis) {
|
||||
int rank = shape.size();
|
||||
if (axis < 0) {
|
||||
axis += rank;
|
||||
}
|
||||
std::vector<int> reverse_axis(rank);
|
||||
std::iota(reverse_axis.begin(), reverse_axis.end(), 0);
|
||||
reverse_axis.at(axis) = rank - 1;
|
||||
reverse_axis.at(rank - 1) = axis;
|
||||
return reverse_axis;
|
||||
}
|
||||
|
||||
std::vector<EXPR *> SoftmaxM::Grad(EXPR *yt) {
|
||||
auto x = input(0);
|
||||
auto out = output(0);
|
||||
auto shape_of_x = x->dims();
|
||||
auto param = reinterpret_cast<const SoftmaxParameter *>(OpParam());
|
||||
auto reverse_axis = getTransposeAxis(shape_of_x, param->axis_);
|
||||
|
||||
auto transpose_out = NN::Transpose(reverse_axis);
|
||||
transpose_out->set_name(kGradName + "/" + name() + "/" + transpose_out->name() + "/out/");
|
||||
PushOp(transpose_out);
|
||||
auto y_trn = (*transpose_out)({out}).front();
|
||||
|
||||
auto transpose_dout = NN::Transpose(reverse_axis);
|
||||
transpose_dout->set_name(kGradName + "/" + name() + "/" + transpose_dout->name() + "/dout/");
|
||||
PushOp(transpose_dout);
|
||||
auto yt_trn = (*transpose_dout)({yt}).front();
|
||||
|
||||
auto mul0 = NN::Mul();
|
||||
mul0->set_name(kGradName + "/" + name() + "/" + mul0->name() + "0");
|
||||
PushOp(mul0);
|
||||
auto tmp0 = (*mul0)({y_trn, yt_trn}).front();
|
||||
|
||||
auto sum_func = NN::ReduceSum(true, {-1});
|
||||
sum_func->set_name(kGradName + "/" + name() + "/" + sum_func->name());
|
||||
PushOp(sum_func);
|
||||
auto tmp1 = (*sum_func)({tmp0}).front();
|
||||
|
||||
auto sub = NN::Sub();
|
||||
sub->set_name(kGradName + "/" + name() + "/" + sub->name());
|
||||
PushOp(sub);
|
||||
auto tmp2 = (*sub)({yt_trn, tmp1}).front();
|
||||
|
||||
auto mul1 = NN::Mul();
|
||||
mul1->set_name(kGradName + "/" + name() + "/" + mul1->name() + "1");
|
||||
PushOp(mul1);
|
||||
auto tmp3 = (*mul1)({y_trn, tmp2});
|
||||
|
||||
auto transpose_dx = NN::Transpose(reverse_axis);
|
||||
transpose_dx->set_name(kGradName + "/" + name() + "/" + transpose_dx->name() + "/dx");
|
||||
PushOp(transpose_dx);
|
||||
auto dx = (*transpose_dx)({tmp3});
|
||||
return dx;
|
||||
}
|
||||
|
||||
int SoftmaxM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
|
||||
auto param = reinterpret_cast<const SoftmaxParameter *>(OpParam());
|
||||
auto prim = new (std::nothrow) schema::SoftmaxT;
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate primitive";
|
||||
return RET_ERROR;
|
||||
}
|
||||
prim->axis.push_back(param->axis_);
|
||||
cnode->primitive->value.value = prim;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
static ImportReg reg(schema::PrimitiveType_Softmax, ReturnNode<SoftmaxM>);
|
||||
|
||||
namespace NN {
|
||||
Node *Softmax(int axis) {
|
||||
auto node = new (std::nothrow) SoftmaxM(axis);
|
||||
return node;
|
||||
}
|
||||
} // namespace NN
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -1,39 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_SOFTMAX_H_
|
||||
#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_SOFTMAX_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "src/expression/node.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class SoftmaxM : public Node {
|
||||
public:
|
||||
SoftmaxM() = default;
|
||||
explicit SoftmaxM(int axis);
|
||||
std::vector<EXPR *> Grad(EXPR *expr) override;
|
||||
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
|
||||
|
||||
private:
|
||||
std::vector<int> getTransposeAxis(const std::vector<int> &shape, int axis);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_SOFTMAX_H_
|
|
@ -1,93 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/expression/ops/softmaxCE.h"
|
||||
#include "include/api/net.h"
|
||||
#include "src/litert/cxx_api/expression/node_impl.h"
|
||||
#include "src/expression/ops/reduce.h"
|
||||
namespace mindspore {
|
||||
namespace NN {
|
||||
Node *SoftmaxCrossEntropy(const SoftMaxCrossEntropyCfg &cfg) {
|
||||
auto lite_node = lite::NN::SoftmaxCrossEntropy(cfg);
|
||||
return NodeImpl::Connect(lite_node);
|
||||
}
|
||||
} // namespace NN
|
||||
|
||||
namespace lite {
|
||||
SoftmaxCrossEntropyM::SoftmaxCrossEntropyM() {
|
||||
auto param = calloc(1, sizeof(OpParameter));
|
||||
if (param == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate parameter";
|
||||
return;
|
||||
}
|
||||
expr()->SetSize(C2NUM);
|
||||
SetOpParam(param);
|
||||
set_name("SoftmaxCrossEntropy");
|
||||
set_primitive(schema::PrimitiveType_SoftmaxCrossEntropyWithLogits);
|
||||
EXPR e(this);
|
||||
e.SetSize(0);
|
||||
expr_.emplace_back(e);
|
||||
}
|
||||
|
||||
Node *SoftmaxCrossEntropyM::GetReductionNode(const std::string &mode, const std::vector<int> &axis) {
|
||||
if (mode == "mean") {
|
||||
return NN::ReduceMean(false, axis);
|
||||
} else if (mode == "sum") {
|
||||
return NN::ReduceSum(false, axis);
|
||||
} else {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
SoftmaxCrossEntropyM::SoftmaxCrossEntropyM(const SoftMaxCrossEntropyCfg &cfg) : SoftmaxCrossEntropyM() {
|
||||
std::vector<int> axis = {0};
|
||||
reduce_ = GetReductionNode(cfg.reduction, axis);
|
||||
if (reduce_ != nullptr) {
|
||||
PushOp(reduce_);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<EXPR *> SoftmaxCrossEntropyM::construct(const std::vector<EXPR *> &inputs) {
|
||||
auto y = Node::construct(inputs);
|
||||
if (reduce_ != nullptr) {
|
||||
y = (*reduce_)({y.front()});
|
||||
}
|
||||
return y;
|
||||
}
|
||||
|
||||
std::vector<EXPR *> SoftmaxCrossEntropyM::Grad(EXPR *expr) { return {this->expr(1)}; }
|
||||
|
||||
int SoftmaxCrossEntropyM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
|
||||
auto prim = new (std::nothrow) schema::SoftmaxCrossEntropyWithLogitsT;
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate primitive";
|
||||
return RET_ERROR;
|
||||
}
|
||||
cnode->primitive->value.value = prim;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
namespace NN {
|
||||
Node *SoftmaxCrossEntropy(const SoftMaxCrossEntropyCfg &cfg) {
|
||||
auto s = new (std::nothrow) SoftmaxCrossEntropyM(cfg);
|
||||
if (s == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate softmax node";
|
||||
}
|
||||
return s;
|
||||
}
|
||||
} // namespace NN
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -1,47 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_SOFTMAXCE_H_
|
||||
#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_SOFTMAXCE_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "src/expression/node.h"
|
||||
#include "inner/model_generated.h"
|
||||
#include "include/api/net.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class SoftmaxCrossEntropyM : public Node {
|
||||
public:
|
||||
SoftmaxCrossEntropyM();
|
||||
explicit SoftmaxCrossEntropyM(const SoftMaxCrossEntropyCfg &cfg);
|
||||
std::vector<EXPR *> Grad(EXPR *expr) override;
|
||||
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
|
||||
std::vector<EXPR *> construct(const std::vector<EXPR *> &inputs) override;
|
||||
|
||||
private:
|
||||
Node *GetReductionNode(const std::string &mode, const std::vector<int> &axis);
|
||||
Node *reduce_ = nullptr;
|
||||
};
|
||||
|
||||
namespace NN {
|
||||
Node *SoftmaxCrossEntropy(const SoftMaxCrossEntropyCfg &cfg);
|
||||
} // namespace NN
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_SOFTMAXCE_H_
|
|
@ -1,62 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/expression/ops/tile.h"
|
||||
#include <memory>
|
||||
#include "src/expression/ops.h"
|
||||
#include "nnacl/base/tile_base.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
TileM::TileM(const std::vector<int> &multiples) : Node() {
|
||||
expr()->SetSize(C2NUM);
|
||||
TileParameter *param = reinterpret_cast<TileParameter *>(calloc(1, sizeof(TileParameter)));
|
||||
if (param == nullptr) {
|
||||
MS_LOG(ERROR) << " cannot allocate ConvParameter";
|
||||
return;
|
||||
}
|
||||
SetOpParam(param);
|
||||
set_name(UniqueName("Tile"));
|
||||
set_primitive(schema::PrimitiveType_TileFusion);
|
||||
Node::CreateConstTensor(C1NUM, {static_cast<int32_t>(multiples.size())}, kNumberTypeInt32, KHWC, "axis",
|
||||
multiples.data());
|
||||
}
|
||||
|
||||
int TileM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
|
||||
auto tile_param = reinterpret_cast<const TileParameter *>(OpParam());
|
||||
auto prim = new (std::nothrow) schema::TileFusionT;
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate primitive";
|
||||
return RET_ERROR;
|
||||
}
|
||||
for (size_t i = 0; i < tile_param->dims_size_; i++) {
|
||||
prim->dims.push_back(tile_param->dims_[i]);
|
||||
}
|
||||
cnode->primitive->value.value = prim;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
namespace NN {
|
||||
Node *Tile(const std::vector<int> &multiples) {
|
||||
auto node = new (std::nothrow) TileM(multiples);
|
||||
if (node == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate tile node";
|
||||
}
|
||||
return node;
|
||||
}
|
||||
} // namespace NN
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -1,40 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_TILE_H_
|
||||
#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_TILE_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "src/expression/node.h"
|
||||
#include "inner/model_generated.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class TileM : public Node {
|
||||
public:
|
||||
TileM() = default;
|
||||
explicit TileM(const std::vector<int> &multiples);
|
||||
Param *weight() override { return input(1)->node()->data(); }
|
||||
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
|
||||
};
|
||||
|
||||
namespace NN {
|
||||
Node *Tile(const std::vector<int> &multiples);
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_TILE_H_
|
|
@ -1,88 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/expression/ops/transpose.h"
|
||||
#include <memory>
|
||||
#include "nnacl/transpose_parameter.h"
|
||||
#include "inner/model_generated.h"
|
||||
#include "src/expression/import.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
TransposeM::TransposeM(const std::vector<int> &vector) {
|
||||
auto param = calloc(1, sizeof(TransposeParameter));
|
||||
if (param == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate transpose parameter";
|
||||
return;
|
||||
}
|
||||
SetOpParam(param);
|
||||
expr()->SetSize(C2NUM);
|
||||
set_primitive(schema::PrimitiveType_Transpose);
|
||||
std::vector<int> dims = {static_cast<int>(vector.size())};
|
||||
set_name(UniqueName("Transpose"));
|
||||
CreateConstTensor(C1NUM, dims, kNumberTypeInt32, KHWC, "axis", vector.data());
|
||||
}
|
||||
|
||||
std::vector<int> TransposeM::Invert(const std::vector<int> &vector) {
|
||||
std::vector<int> res;
|
||||
for (size_t i = 0; i < vector.size(); i++) {
|
||||
int idx = static_cast<int>(i);
|
||||
auto val = std::find_if(vector.begin(), vector.end(), [idx](int x) { return (x == idx) ? true : false; });
|
||||
if (val == vector.end()) {
|
||||
MS_LOG(ERROR) << "Wrong index for " << idx;
|
||||
return {};
|
||||
}
|
||||
res.push_back(std::distance(vector.begin(), val));
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
std::vector<EXPR *> TransposeM::Grad(EXPR *yt) {
|
||||
auto tensor = input(1)->node();
|
||||
auto data = tensor->data();
|
||||
auto vec = data->Extract<int>();
|
||||
auto invert = Invert(vec);
|
||||
auto tran = new (std::nothrow) TransposeM(invert);
|
||||
if (tran == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate transpose grad";
|
||||
return {};
|
||||
}
|
||||
tran->set_name(kGradName + "/" + name() + "/" + tran->name());
|
||||
PushOp(tran);
|
||||
auto grad = (*tran)({yt});
|
||||
return grad;
|
||||
}
|
||||
|
||||
int TransposeM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) {
|
||||
auto prim = new (std::nothrow) schema::TransposeT;
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate primitive";
|
||||
return RET_ERROR;
|
||||
}
|
||||
cnode->primitive->value.value = prim;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
static ImportReg reg(schema::PrimitiveType_Transpose, ReturnNode<TransposeM>);
|
||||
|
||||
namespace NN {
|
||||
Node *Transpose(const std::vector<int> &permute) {
|
||||
auto node = new (std::nothrow) TransposeM(permute);
|
||||
return node;
|
||||
}
|
||||
} // namespace NN
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -1,59 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_TRANSPOSE_H_
|
||||
#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_TRANSPOSE_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "src/expression/node.h"
|
||||
#include "inner/model_generated.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class TransposeM : public Node {
|
||||
public:
|
||||
TransposeM() = default;
|
||||
explicit TransposeM(const std::vector<int> &vector);
|
||||
static EXPR *TransposeCHW2HWC(EXPR *in) {
|
||||
std::vector<int> res = {0, 2, 3, 1};
|
||||
auto trans = new (std::nothrow) TransposeM(res);
|
||||
if (trans == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
return (*trans)({in}).front();
|
||||
}
|
||||
static EXPR *TransposeHWC2CHW(EXPR *in) {
|
||||
std::vector<int> res = {0, 3, 1, 2};
|
||||
auto trans = new (std::nothrow) TransposeM(res);
|
||||
if (trans == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
return (*trans)({in}).front();
|
||||
}
|
||||
int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override;
|
||||
std::vector<EXPR *> Grad(EXPR *yt) override;
|
||||
|
||||
private:
|
||||
std::vector<int> Invert(const std::vector<int> &vec);
|
||||
};
|
||||
|
||||
namespace NN {
|
||||
Node *Transpose(const std::vector<int> &permute);
|
||||
} // namespace NN
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_TRANSPOSE_H_
|
|
@ -1,275 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/expression/ops_utils.h"
|
||||
#include <set>
|
||||
#include <algorithm>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
enum class State {
|
||||
SAME,
|
||||
X_ONE,
|
||||
Y_ONE,
|
||||
};
|
||||
|
||||
bool CompareShape(const std::vector<int> &x_shape, const std::vector<int> &y_shape) {
|
||||
if (x_shape.size() != y_shape.size()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < x_shape.size(); ++i) {
|
||||
if (x_shape.at(i) != y_shape.at(i)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void ComputeReduceIndex(const std::vector<int> &reverse_x, const std::vector<int> &reverse_y,
|
||||
std::vector<int> *grad_x_reduce_idx, std::vector<int> *grad_y_reduce_idy) {
|
||||
MS_ASSERT(grad_x_reduce_idx != nullptr);
|
||||
MS_ASSERT(grad_y_reduce_idy != nullptr);
|
||||
const size_t n = reverse_x.size();
|
||||
if (reverse_y.size() < n) {
|
||||
MS_LOG_ERROR << "The size of reverse_y is less than the size of reverse_x.";
|
||||
}
|
||||
for (size_t i = 0; i < n; ++i) {
|
||||
State curr = State::SAME;
|
||||
const int x_i = reverse_x[i];
|
||||
const int y_i = reverse_y[i];
|
||||
const int reduce_idx = (n - 1 - i);
|
||||
if (x_i == y_i) {
|
||||
curr = State::SAME;
|
||||
} else if (x_i == 1) {
|
||||
grad_x_reduce_idx->push_back(reduce_idx);
|
||||
curr = State::X_ONE;
|
||||
} else if (y_i == 1) {
|
||||
grad_y_reduce_idy->push_back(reduce_idx);
|
||||
curr = State::Y_ONE;
|
||||
} else {
|
||||
MS_LOG_ERROR << "not compatible shape input for BroadcastGradientArgs";
|
||||
}
|
||||
if (curr == State::SAME && x_i == 1) {
|
||||
grad_x_reduce_idx->push_back(reduce_idx);
|
||||
grad_y_reduce_idy->push_back(reduce_idx);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
std::reverse(grad_x_reduce_idx->begin(), grad_x_reduce_idx->end());
|
||||
std::reverse(grad_y_reduce_idy->begin(), grad_y_reduce_idy->end());
|
||||
}
|
||||
|
||||
std::vector<std::vector<int>> BroadcastGradientArgs::operator()() {
|
||||
std::vector<std::vector<int>> input_dim(kInNum);
|
||||
input_dim[0] = dim0_;
|
||||
input_dim[1] = dim1_;
|
||||
auto same_shape = CompareShape(dim0_, dim1_);
|
||||
if (same_shape) {
|
||||
return {{}, {}};
|
||||
}
|
||||
|
||||
std::vector<int> reverse_x;
|
||||
std::vector<int> reverse_y;
|
||||
|
||||
(void)std::transform(dim0_.rbegin(), dim0_.rend(), std::back_inserter(reverse_x), [](const int &v) { return v; });
|
||||
(void)std::transform(dim1_.rbegin(), dim1_.rend(), std::back_inserter(reverse_y), [](const int &v) { return v; });
|
||||
|
||||
if (reverse_x.size() > reverse_y.size()) {
|
||||
reverse_y.resize(reverse_x.size(), 1);
|
||||
} else {
|
||||
reverse_x.resize(reverse_y.size(), 1);
|
||||
}
|
||||
|
||||
std::vector<int> grad_x_reduce_idx;
|
||||
std::vector<int> grad_y_reduce_idy;
|
||||
ComputeReduceIndex(reverse_x, reverse_y, &grad_x_reduce_idx, &grad_y_reduce_idy);
|
||||
return {grad_x_reduce_idx, grad_y_reduce_idy};
|
||||
}
|
||||
|
||||
void DynamicBroadcastGradientArgs::AddElementToGradReduceIdx(std::vector<std::vector<int>> *grad_reduce_idx,
|
||||
std::vector<bool> current_is_one, bool none_is_one,
|
||||
const size_t largest_rank, size_t j) {
|
||||
for (size_t i = 0; i < kInNum; ++i) {
|
||||
if (current_is_one[i] && !none_is_one) {
|
||||
(void)(*grad_reduce_idx)[i].emplace_back(largest_rank - 1 - j);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void DynamicBroadcastGradientArgs::UpdatePreIsOne(std::vector<bool> *prev_is_one, std::vector<bool> current_is_one) {
|
||||
for (size_t i = 0; i < kInNum; ++i) {
|
||||
(*prev_is_one)[i] = current_is_one[i];
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::vector<int>> DynamicBroadcastGradientArgs::GetGradientIndices(
|
||||
const std::vector<std::vector<int>> &reverse_shape, const size_t largest_rank) {
|
||||
std::vector<std::vector<int>> grad_reduce_idx(kInNum);
|
||||
// indices of j-th component of each input.
|
||||
std::vector<bool> prev_is_one(kInNum);
|
||||
std::vector<bool> current_is_one(kInNum);
|
||||
for (size_t i = 0; i < kInNum; ++i) {
|
||||
prev_is_one[i] = false;
|
||||
current_is_one[i] = false;
|
||||
}
|
||||
|
||||
bool set_one = false;
|
||||
for (size_t j = 0; j < largest_rank; ++j) {
|
||||
int output_dim = -1;
|
||||
bool output_dim_set = false;
|
||||
bool none_is_one = true;
|
||||
// Find which indices are 1.
|
||||
for (size_t i = 0; i < kInNum; ++i) {
|
||||
if (reverse_shape[i][j] == 1) {
|
||||
current_is_one[i] = true;
|
||||
none_is_one = false;
|
||||
} else {
|
||||
current_is_one[i] = false;
|
||||
if (!output_dim_set || reverse_shape[i][j] == static_cast<int>(output_dim)) {
|
||||
output_dim = reverse_shape[i][j];
|
||||
output_dim_set = true;
|
||||
} else {
|
||||
std::cout << "Input[0] and input[1] Cannot broadcast!";
|
||||
}
|
||||
}
|
||||
}
|
||||
// All dimensions are 1.
|
||||
if (!output_dim_set) {
|
||||
for (size_t i = 0; i < kInNum; ++i) {
|
||||
(void)grad_reduce_idx[i].emplace_back(largest_rank - 1 - j);
|
||||
}
|
||||
continue;
|
||||
} else if (std::equal(current_is_one.begin(), current_is_one.end(), prev_is_one.begin()) && set_one) {
|
||||
AddElementToGradReduceIdx(&grad_reduce_idx, current_is_one, none_is_one, largest_rank, j);
|
||||
} else {
|
||||
AddElementToGradReduceIdx(&grad_reduce_idx, current_is_one, none_is_one, largest_rank, j);
|
||||
}
|
||||
set_one = true;
|
||||
UpdatePreIsOne(&prev_is_one, current_is_one);
|
||||
}
|
||||
return grad_reduce_idx;
|
||||
}
|
||||
|
||||
std::vector<std::vector<int>> DynamicBroadcastGradientArgs::CalculateOutput(const std::vector<std::vector<int>> &x) {
|
||||
std::vector<std::vector<int>> grad_reduce_idx(kInNum);
|
||||
bool all_equal = true;
|
||||
size_t largest_rank = 0;
|
||||
for (size_t i = 0; i < kInNum; ++i) {
|
||||
if (x[i] != x[0]) {
|
||||
all_equal = false;
|
||||
}
|
||||
if (x[i].size() > largest_rank) {
|
||||
largest_rank = x[i].size();
|
||||
}
|
||||
}
|
||||
if (all_equal) {
|
||||
return grad_reduce_idx;
|
||||
}
|
||||
|
||||
// Reverse input the shapes
|
||||
std::vector<std::vector<int>> reverse_shape(kInNum);
|
||||
for (size_t i = 0; i < kInNum; ++i) {
|
||||
reverse_shape[i] = x[i];
|
||||
std::reverse(reverse_shape[i].begin(), reverse_shape[i].end());
|
||||
}
|
||||
|
||||
// 1-extend and align all vectors.
|
||||
for (size_t i = 0; i < kInNum; ++i) {
|
||||
if (reverse_shape[i].size() < largest_rank) {
|
||||
reverse_shape[i].resize(largest_rank, 1);
|
||||
}
|
||||
}
|
||||
grad_reduce_idx = GetGradientIndices(reverse_shape, largest_rank);
|
||||
return grad_reduce_idx;
|
||||
}
|
||||
|
||||
std::vector<std::vector<int>> DynamicBroadcastGradientArgs::SetOutputValue(
|
||||
const std::vector<std::vector<int>> &grad_reduce_idx, const std::vector<std::vector<int>> &input_dim) {
|
||||
std::vector<std::vector<int>> output(kInNum);
|
||||
for (size_t index = 0; index < kInNum; ++index) {
|
||||
auto idx_num = grad_reduce_idx[index].size();
|
||||
for (size_t k = 0; k < idx_num; ++k) {
|
||||
output[index].push_back(grad_reduce_idx[index][idx_num - 1 - k]);
|
||||
}
|
||||
if (idx_num == 0) {
|
||||
auto input_num = input_dim[index].size();
|
||||
for (size_t k = 0; k < input_num; ++k) {
|
||||
output[index].push_back(k);
|
||||
}
|
||||
}
|
||||
}
|
||||
return output;
|
||||
}
|
||||
|
||||
std::vector<std::vector<int>> DynamicBroadcastGradientArgs::operator()() {
|
||||
std::vector<std::vector<int>> input_dim(kInNum);
|
||||
input_dim[0] = dim0_;
|
||||
input_dim[1] = dim1_;
|
||||
auto grad_reduce_idx = CalculateOutput(input_dim);
|
||||
auto output = SetOutputValue(grad_reduce_idx, input_dim);
|
||||
return output;
|
||||
}
|
||||
|
||||
std::vector<int> VectorDiv::operator()(const std::vector<int> &x, const std::vector<int> &d) {
|
||||
if (d.size() != x.size()) {
|
||||
MS_LOG(ERROR) << "x and divider must have same size";
|
||||
return {};
|
||||
}
|
||||
std::vector<int> res;
|
||||
for (size_t i = 0; i < d.size(); i++) {
|
||||
auto x_value = x.at(i);
|
||||
auto d_value = d.at(i);
|
||||
if (d_value == 0) {
|
||||
MS_LOG(ERROR) << "Divisor is zero";
|
||||
return {};
|
||||
}
|
||||
if ((x_value % d_value) != 0) {
|
||||
MS_LOG(ERROR) << "x and d and not dividable";
|
||||
}
|
||||
auto r = x_value / d_value;
|
||||
res.push_back(r);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
std::vector<int> ShapeReduce::operator()(const std::vector<int> &x_shape, const std::vector<int> &axis) {
|
||||
int x_rank = x_shape.size();
|
||||
std::set<int> axis_set;
|
||||
|
||||
auto min = -x_rank;
|
||||
auto max = x_rank - 1;
|
||||
for (auto &elem : axis) {
|
||||
if (elem > max || elem < min) {
|
||||
MS_LOG(ERROR) << "illegal axis value";
|
||||
return {};
|
||||
}
|
||||
axis_set.insert(elem);
|
||||
}
|
||||
std::vector<int> res;
|
||||
for (int i = 0; i < x_rank; i++) {
|
||||
if (axis_set.count(i) || axis_set.count(i - x_rank)) {
|
||||
res.push_back(1);
|
||||
} else {
|
||||
res.push_back(x_shape.at(i));
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -1,69 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_UTILS_H_
|
||||
#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_UTILS_H_
|
||||
|
||||
#include "include/api/cfg.h"
|
||||
#include "src/expression/net.h"
|
||||
#include "vector"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class BroadcastGradientArgs {
|
||||
public:
|
||||
BroadcastGradientArgs(const std::vector<int> &dim0, const std::vector<int> &dim1) : dim0_(dim0), dim1_(dim1) {}
|
||||
std::vector<std::vector<int>> operator()();
|
||||
|
||||
private:
|
||||
static const int kInNum = 2;
|
||||
const std::vector<int> &dim0_;
|
||||
const std::vector<int> &dim1_;
|
||||
};
|
||||
|
||||
class DynamicBroadcastGradientArgs {
|
||||
public:
|
||||
DynamicBroadcastGradientArgs(const std::vector<int> &dim0, const std::vector<int> &dim1) : dim0_(dim0), dim1_(dim1) {}
|
||||
std::vector<std::vector<int>> operator()();
|
||||
|
||||
private:
|
||||
void AddElementToGradReduceIdx(std::vector<std::vector<int>> *grad_reduce_idx, std::vector<bool> current_is_one,
|
||||
bool none_is_one, const size_t largest_rank, size_t j);
|
||||
void UpdatePreIsOne(std::vector<bool> *prev_is_one, std::vector<bool> current_is_one);
|
||||
std::vector<std::vector<int>> GetGradientIndices(const std::vector<std::vector<int>> &reverse_shape,
|
||||
const size_t largest_rank);
|
||||
std::vector<std::vector<int>> CalculateOutput(const std::vector<std::vector<int>> &x);
|
||||
std::vector<std::vector<int>> SetOutputValue(const std::vector<std::vector<int>> &grad_reduce_idx,
|
||||
const std::vector<std::vector<int>> &input_dim);
|
||||
static const int kInNum = 2;
|
||||
const std::vector<int> &dim0_;
|
||||
const std::vector<int> &dim1_;
|
||||
};
|
||||
|
||||
class VectorDiv {
|
||||
public:
|
||||
VectorDiv() {}
|
||||
std::vector<int> operator()(const std::vector<int> &x, const std::vector<int> &d);
|
||||
};
|
||||
|
||||
class ShapeReduce {
|
||||
public:
|
||||
ShapeReduce() {}
|
||||
std::vector<int> operator()(const std::vector<int> &x_shape, const std::vector<int> &axis);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_UTILS_H_
|
|
@ -1,70 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/expression/param.h"
|
||||
#include <random>
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include "include/errorcode.h"
|
||||
|
||||
using mindspore::lite::RET_ERROR;
|
||||
using mindspore::lite::RET_OK;
|
||||
|
||||
constexpr float kZero = 0.0f;
|
||||
constexpr float kOne = 1.0f;
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
int Param::Fill(Mode mode) {
|
||||
std::default_random_engine engine{static_cast<unsigned int>(0)};
|
||||
std::vector<float> data(size_);
|
||||
switch (mode) {
|
||||
case NORMAL: {
|
||||
constexpr float scale = 0.01;
|
||||
std::normal_distribution<float> n{0, 1};
|
||||
std::generate_n(data.begin(), size_, [&]() { return n(engine); });
|
||||
(void)std::transform(data.begin(), data.end(), data.begin(), [=](float x) { return x * scale; });
|
||||
break;
|
||||
}
|
||||
case UNIFORM: {
|
||||
constexpr float scale = 0.07;
|
||||
std::uniform_real_distribution<float> u{-1.0, 1.0};
|
||||
std::generate_n(data.begin(), size_, [&]() { return u(engine) * scale; });
|
||||
break;
|
||||
}
|
||||
case ZEROS:
|
||||
std::fill_n(data.begin(), size_, kZero);
|
||||
break;
|
||||
case ONES:
|
||||
std::fill_n(data.begin(), size_, kOne);
|
||||
break;
|
||||
case NOT_SUPPORTED:
|
||||
return RET_ERROR;
|
||||
}
|
||||
Copy(data);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
Param::Mode Param::String2Enum(std::string mode) {
|
||||
(void)std::transform(mode.begin(), mode.end(), mode.begin(), ::tolower);
|
||||
if (mode == "normal") return NORMAL;
|
||||
if (mode == "uniform") return UNIFORM;
|
||||
if (mode == "ones") return ONES;
|
||||
if (mode == "zeors") return ZEROS;
|
||||
return NOT_SUPPORTED;
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -1,60 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_PARAM_H_
|
||||
#define MINDSPORE_LITE_SRC_EXPRESSION_PARAM_H_
|
||||
|
||||
#include <vector>
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <string>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class Param {
|
||||
public:
|
||||
enum Mode { NORMAL, UNIFORM, ONES, ZEROS, NOT_SUPPORTED };
|
||||
int Fill(Mode type);
|
||||
static Mode String2Enum(std::string);
|
||||
std::vector<uint8_t> &data() { return data_; }
|
||||
size_t Load(std::string file_name, size_t offset = 0) { return data_.size() * sizeof(float); }
|
||||
size_t Load(std::ifstream &s, int offset = 0) { return data_.size() * sizeof(float); }
|
||||
void SetSize(size_t size) { size_ = size; }
|
||||
template <typename T>
|
||||
void Copy(const T *data, size_t size) {
|
||||
auto cast_data = reinterpret_cast<const uint8_t *>(data);
|
||||
data_ = decltype(data_)(cast_data, cast_data + size * sizeof(T) / sizeof(uint8_t));
|
||||
}
|
||||
template <typename T>
|
||||
void Copy(const std::vector<T> data) {
|
||||
Copy<T>(data.data(), data.size());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<T> Extract() {
|
||||
T *num = reinterpret_cast<T *>(data_.data());
|
||||
std::vector<T> res(num, num + data_.size() / sizeof(T));
|
||||
return res;
|
||||
}
|
||||
|
||||
private:
|
||||
size_t size_;
|
||||
std::vector<uint8_t> data_;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_EXPRESSION_PARAM_H_
|
|
@ -1,30 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "src/expression/sequential.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
void Sequential::Add(Node *node) { PushOp(node); }
|
||||
|
||||
std::vector<EXPR *> Sequential::construct(const std::vector<EXPR *> &inputs) {
|
||||
auto x = inputs;
|
||||
for (auto &node : ops_) {
|
||||
x = (*node)({x.front()});
|
||||
}
|
||||
return x;
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -1,32 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_EXPRESSION_SEQUENTIAL_H_
|
||||
#define MINDSPORE_LITE_SRC_EXPRESSION_SEQUENTIAL_H_
|
||||
#include <vector>
|
||||
#include "src/expression/net.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class Sequential : public Net {
|
||||
public:
|
||||
std::vector<EXPR *> construct(const std::vector<EXPR *> &inputs) override;
|
||||
void Add(Node *node) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_EXPRESSION_SEQUENTIAL_H_
|
|
@ -211,12 +211,6 @@ Status Model::Build(GraphCell graph, const std::shared_ptr<Context> &model_conte
|
|||
return kLiteNotSupport;
|
||||
}
|
||||
|
||||
Status Model::Build(GraphCell graph, Node *optimizer, std::vector<Expr *> inputs,
|
||||
const std::shared_ptr<Context> &model_context, const std::shared_ptr<TrainCfg> &train_cfg) {
|
||||
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||
return kLiteNotSupport;
|
||||
}
|
||||
|
||||
Status BuildTransferLearning(GraphCell backbone, GraphCell head, const std::shared_ptr<Context> &context,
|
||||
const std::shared_ptr<TrainCfg> &train_cfg = nullptr) {
|
||||
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||
|
|
|
@ -1,145 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "include/api/net.h"
|
||||
#include "include/api/status.h"
|
||||
#include "src/litert/cxx_api/expression/node_impl.h"
|
||||
#include "src/litert/cxx_api/expression/net_impl.h"
|
||||
#include "src/expression/ops.h"
|
||||
#include "src/expression/cfg.h"
|
||||
|
||||
namespace mindspore {
|
||||
uint32_t Node::type() { return kNodeType; }
|
||||
|
||||
std::vector<Expr *> Node::operator()(const std::vector<Expr *> &inputs) {
|
||||
auto in = Expr::convert(inputs);
|
||||
if (impl_ == nullptr) {
|
||||
MS_LOG(ERROR) << "empty implementation";
|
||||
return {};
|
||||
}
|
||||
if (impl_->node() == nullptr) {
|
||||
MS_LOG(ERROR) << "expression node is not attached";
|
||||
return {};
|
||||
}
|
||||
auto out = impl_->node()->construct(in);
|
||||
return Expr::convert(out);
|
||||
}
|
||||
|
||||
Expr *Node::Create(std::string name) {
|
||||
auto expr = impl_->node()->create(name);
|
||||
return reinterpret_cast<Expr *>(expr);
|
||||
}
|
||||
|
||||
Node::Node() {
|
||||
auto impl = std::make_shared<NodeImpl>();
|
||||
impl_ = impl;
|
||||
impl_->set_pnode(this);
|
||||
}
|
||||
|
||||
Node::~Node() {
|
||||
impl_->set_pnode(nullptr);
|
||||
auto node = impl_->node();
|
||||
if (node != nullptr) {
|
||||
impl_->set_node(nullptr);
|
||||
delete node;
|
||||
}
|
||||
}
|
||||
|
||||
Net::Net(std::string name) {
|
||||
auto impl = std::make_shared<NetImpl>();
|
||||
if (impl == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate network implementation";
|
||||
return;
|
||||
}
|
||||
impl_ = impl;
|
||||
impl_->set_pnet(std::shared_ptr<Net>(this));
|
||||
auto netl = new (std::nothrow) lite::Net(name);
|
||||
if (netl == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate network lite";
|
||||
return;
|
||||
}
|
||||
netl->set_impl(impl);
|
||||
impl_->set_net(netl);
|
||||
}
|
||||
|
||||
Net::Net() : Net("") {}
|
||||
|
||||
Net::Net(const Graph &g) {
|
||||
auto net = NetImpl::GetNet(g);
|
||||
impl_ = net->impl_;
|
||||
}
|
||||
|
||||
void Net::Add(NetBase *element) { MS_LOG(WARNING) << "Only sequential can add element"; }
|
||||
|
||||
uint32_t Net::type() { return kNetType; }
|
||||
|
||||
std::vector<Expr *> Net::construct(const std::vector<Expr *> &inputs) {
|
||||
auto in = Expr::convert(inputs);
|
||||
auto out = impl_->net()->construct(in);
|
||||
return Expr::convert(out);
|
||||
}
|
||||
|
||||
std::vector<Expr *> Net::operator()(const std::vector<Expr *> &inputs) {
|
||||
auto in = Expr::convert(inputs);
|
||||
auto x = construct(inputs);
|
||||
impl_->net()->input_ = in;
|
||||
auto out = Expr::convert(x);
|
||||
impl_->net()->output_ = out;
|
||||
impl_->net()->real_output_ = out;
|
||||
return x;
|
||||
}
|
||||
void Net::Register(Net *net, std::string &&name) {
|
||||
if (net != nullptr) {
|
||||
auto net_lite = net->impl_->net();
|
||||
impl_->net()->Register(net_lite, std::move(name));
|
||||
}
|
||||
}
|
||||
|
||||
void Net::Register(Node *node, std::string &&name) {
|
||||
if (node != nullptr) {
|
||||
auto impl = NodeImpl::GetImpl(node);
|
||||
if (impl == nullptr) {
|
||||
MS_LOG(ERROR) << "missing implementation";
|
||||
return;
|
||||
}
|
||||
auto node_lite = impl->node();
|
||||
impl_->net()->Register(node_lite, std::move(name));
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<NodeSet> Net::trainable_params() {
|
||||
auto node_set = std::make_shared<NodeSet>();
|
||||
if (node_set == nullptr) {
|
||||
MS_LOG(ERROR) << "new NodeSet failed.";
|
||||
return nullptr;
|
||||
}
|
||||
node_set->set_ = impl_->net()->trainable_params();
|
||||
return node_set;
|
||||
}
|
||||
|
||||
const std::vector<int> Net::InputShape(int idx) { return impl_->InputShape(idx); }
|
||||
const std::vector<int> Net::OutputShape(int idx) { return impl_->OutputShape(idx); }
|
||||
|
||||
Net::~Net() {
|
||||
if (impl_ != nullptr) {
|
||||
if ((impl_->pnet() == nullptr) || (impl_->pnet() == this)) {
|
||||
impl_->set_pnet(nullptr);
|
||||
impl_->set_net(nullptr);
|
||||
impl_.reset();
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -1,220 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/litert/cxx_api/expression/net_impl.h"
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include "include/api/serialization.h"
|
||||
#include "src/expression/import.h"
|
||||
#include "src/expression/ops.h"
|
||||
#include "src/litert/cxx_api/model/model_impl.h"
|
||||
|
||||
namespace {
|
||||
constexpr size_t kFlatbuffersBuilderInitSize = 1024;
|
||||
};
|
||||
|
||||
namespace mindspore {
|
||||
Sequential::Sequential() {}
|
||||
|
||||
lite::Node *Sequential::GetNode(NetBase *element) {
|
||||
lite::Node *lite_node = nullptr;
|
||||
switch (element->type()) {
|
||||
case kNodeType: {
|
||||
Node *node = reinterpret_cast<Node *>(element);
|
||||
auto impl = NodeImpl::GetImpl(node);
|
||||
if (impl == nullptr) {
|
||||
MS_LOG(ERROR) << "cannot find node implement";
|
||||
return nullptr;
|
||||
}
|
||||
lite_node = impl->node();
|
||||
break;
|
||||
}
|
||||
case kNetType: {
|
||||
auto net = reinterpret_cast<Net *>(element);
|
||||
auto impl = NetImpl::GetImpl(net);
|
||||
if (impl == nullptr) {
|
||||
MS_LOG(ERROR) << "cannot find node implement";
|
||||
return nullptr;
|
||||
}
|
||||
lite_node = impl->net();
|
||||
break;
|
||||
}
|
||||
}
|
||||
return lite_node;
|
||||
}
|
||||
|
||||
void Sequential::Add(NetBase *element) {
|
||||
lite::Node *node = GetNode(element);
|
||||
auto impl = NetImpl::GetImpl(this);
|
||||
if (impl == nullptr) {
|
||||
MS_LOG(ERROR) << "No implementation";
|
||||
return;
|
||||
}
|
||||
impl->net()->Add(node);
|
||||
}
|
||||
|
||||
NetWithLoss::NetWithLoss(Net *net, Node *loss) : net_(net), loss_fn_(loss) {
|
||||
REG(net_);
|
||||
Register(loss_fn_, "_loss_fn");
|
||||
}
|
||||
|
||||
std::vector<Expr *> NetWithLoss::construct(const std::vector<Expr *> &inputs) {
|
||||
if (inputs.size() != C2NUM) {
|
||||
MS_LOG(ERROR) << "need 2 inputs for loss";
|
||||
return {};
|
||||
}
|
||||
auto input = inputs[FIRST_INPUT];
|
||||
auto label = inputs[SECOND_INPUT];
|
||||
auto x = (*net_)({input});
|
||||
x = (*loss_fn_)({x[FIRST_INPUT], label});
|
||||
return x;
|
||||
}
|
||||
|
||||
NetImpl::NetImpl(std::shared_ptr<Net> p) { pnet_ = p; }
|
||||
|
||||
NetImpl::NetImpl(Graph *g) { pnet_ = g->net_data_->net(); }
|
||||
|
||||
std::vector<lite::EXPR *> NetImpl::construct(const std::vector<lite::EXPR *> &inputs) {
|
||||
auto in = Expr::convert(inputs);
|
||||
auto out = pnet_->construct(in);
|
||||
return Expr::convert(out);
|
||||
}
|
||||
|
||||
Net *NetImpl::Connect(std::shared_ptr<Net> net, lite::Net *lnet) {
|
||||
auto impl = GetImpl(net.get());
|
||||
if (impl == nullptr) {
|
||||
MS_LOG(ERROR) << "missing implementation";
|
||||
return nullptr;
|
||||
}
|
||||
impl->set_pnet(net);
|
||||
lnet->set_impl(impl);
|
||||
impl->set_net(lnet);
|
||||
return net.get();
|
||||
}
|
||||
|
||||
Status NetImpl::Import(const char *model_buf, Graph *graph) {
|
||||
auto mg = schema::GetMetaGraph(model_buf);
|
||||
auto net = new (std::nothrow) Net();
|
||||
if (net == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate network";
|
||||
return kLiteMemoryFailed;
|
||||
}
|
||||
lite::Import import;
|
||||
auto lite_net = import.ImportMs(mg);
|
||||
if (lite_net == nullptr) {
|
||||
MS_LOG(ERROR) << "failed to import net";
|
||||
return kLiteMemoryFailed;
|
||||
}
|
||||
lite_net->SetRealOutput();
|
||||
Connect(net->shared_from_this(), lite_net);
|
||||
*graph = Graph(net);
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
Status NetImpl::TrainNet(Node *optimizer, const std::vector<Expr *> &inputs) {
|
||||
auto impl = NodeImpl::GetImpl(optimizer);
|
||||
if (impl == nullptr) {
|
||||
MS_LOG(ERROR) << "missing implementation ";
|
||||
return kLiteNullptr;
|
||||
}
|
||||
auto opt = impl->node();
|
||||
auto in = Expr::convert(inputs);
|
||||
auto ret_net = net()->TrainNet(opt, in);
|
||||
if (ret_net == nullptr) {
|
||||
MS_LOG(ERROR) << "failed to train network";
|
||||
return kLiteNullptr;
|
||||
}
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
std::unique_ptr<Graph> NetImpl::MakeMs() {
|
||||
auto mgraph = std::make_unique<Graph>(Graph::Type::kExecutableGraph);
|
||||
if (mgraph == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate graph";
|
||||
return nullptr;
|
||||
}
|
||||
auto trained_graph = net()->MakeMs();
|
||||
if (trained_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "cannot create flat buffer";
|
||||
return nullptr;
|
||||
}
|
||||
flatbuffers::FlatBufferBuilder builder(kFlatbuffersBuilderInitSize);
|
||||
auto offset = schema::MetaGraph::Pack(builder, trained_graph.get());
|
||||
builder.Finish(offset);
|
||||
schema::FinishMetaGraphBuffer(builder, offset);
|
||||
auto buffer = builder.GetBufferPointer();
|
||||
size_t size = builder.GetSize();
|
||||
auto status = Serialization::Load(buffer, size, mindspore::kMindIR, mgraph.get());
|
||||
if (status != kSuccess) {
|
||||
MS_LOG(ERROR) << "failed to load flatbuffer to graph";
|
||||
return nullptr;
|
||||
}
|
||||
return mgraph;
|
||||
}
|
||||
|
||||
const std::vector<int> NetImpl::InputShape(int idx) { return net_->InputShape(idx); }
|
||||
|
||||
const std::vector<int> NetImpl::OutputShape(int idx) { return net_->OutputShape(idx); }
|
||||
|
||||
void NetImpl::ReplaceNet(Graph *g, std::shared_ptr<Net> n) { g->net_data_->net().swap(n); }
|
||||
|
||||
ExpressionLoader expression_registrator = CreateExpressionLoader(NetImpl::Import);
|
||||
|
||||
namespace NN {
|
||||
Net *Sequential() {
|
||||
auto net = new (std::nothrow) mindspore::Sequential();
|
||||
if (net == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate ";
|
||||
return nullptr;
|
||||
}
|
||||
auto netl = lite::NN::Sequential();
|
||||
return NetImpl::Connect(net->shared_from_this(), netl);
|
||||
}
|
||||
|
||||
Net *NetWithLoss(Net *net, Node *loss) {
|
||||
auto loss_net = new (std::nothrow) mindspore::NetWithLoss(net, loss);
|
||||
if (net == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate loss net";
|
||||
return nullptr;
|
||||
}
|
||||
return loss_net;
|
||||
}
|
||||
|
||||
Graph *GraphWithLoss(Graph *graph, Node *loss) {
|
||||
auto net = NetImpl::GetNet(*graph);
|
||||
if (net == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate network";
|
||||
return nullptr;
|
||||
}
|
||||
auto loss_net = NetWithLoss(net.get(), loss);
|
||||
if (loss_net == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate network";
|
||||
return nullptr;
|
||||
}
|
||||
NetImpl::ReplaceNet(graph, loss_net->shared_from_this());
|
||||
return graph;
|
||||
}
|
||||
|
||||
Net *NetWithLoss(Graph *g, Node *loss) {
|
||||
auto net = new (std::nothrow) Net(*g);
|
||||
if (net == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate net";
|
||||
return nullptr;
|
||||
}
|
||||
return NetWithLoss(net, loss);
|
||||
}
|
||||
} // namespace NN
|
||||
} // namespace mindspore
|
|
@ -1,95 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_CXX_API_EXPRESSION_NET_IMPL_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_CXX_API_EXPRESSION_NET_IMPL_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <set>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include "include/api/cfg.h"
|
||||
#include "include/api/data_type.h"
|
||||
#include "include/api/graph.h"
|
||||
#include "include/api/status.h"
|
||||
#include "include/api/net.h"
|
||||
#include "src/litert/cxx_api/expression/node_impl.h"
|
||||
#include "src/litert/cxx_api/graph/net_data.h"
|
||||
#include "src/expression/net.h"
|
||||
#include "src/expression/ops.h"
|
||||
|
||||
namespace mindspore {
|
||||
constexpr uint32_t kNodeType = 1;
|
||||
constexpr uint32_t kNetType = 2;
|
||||
class Sequential : public Net {
|
||||
public:
|
||||
Sequential();
|
||||
void Add(NetBase *n) override;
|
||||
|
||||
private:
|
||||
std::vector<NetBase *> ops_;
|
||||
lite::Node *GetNode(NetBase *element);
|
||||
};
|
||||
|
||||
class NetWithLoss : public Net {
|
||||
public:
|
||||
NetWithLoss(Net *net, Node *loss);
|
||||
std::vector<Expr *> construct(const std::vector<Expr *> &inputs) override;
|
||||
|
||||
private:
|
||||
Net *net_{nullptr};
|
||||
Node *loss_fn_{nullptr};
|
||||
};
|
||||
|
||||
class MS_API NetImpl {
|
||||
public:
|
||||
virtual ~NetImpl() {}
|
||||
explicit NetImpl(std::shared_ptr<Net> p);
|
||||
explicit NetImpl(Graph *g);
|
||||
NetImpl() = default;
|
||||
void set_net(lite::Net *net) {
|
||||
if (net_ != nullptr) {
|
||||
net_->set_impl(nullptr);
|
||||
delete net_;
|
||||
}
|
||||
net_ = net;
|
||||
}
|
||||
void erase_net() { net_ = nullptr; }
|
||||
void set_pnet(std::shared_ptr<Net> net) { pnet_ = net; }
|
||||
Net *pnet() { return pnet_.get(); }
|
||||
lite::Net *net() { return net_; }
|
||||
|
||||
std::vector<lite::EXPR *> construct(const std::vector<lite::EXPR *> &inputs);
|
||||
static std::shared_ptr<mindspore::NetImpl> &GetImpl(Net *net) { return net->impl_; }
|
||||
static Net *Connect(std::shared_ptr<Net> net, lite::Net *lnet);
|
||||
static std::shared_ptr<Net> &GetNet(const Graph &g) { return g.net_data_->net(); }
|
||||
static void SetNet(Graph *g, std::shared_ptr<Net> n) { g->net_data_->set_net(n); }
|
||||
static void ReplaceNet(Graph *g, std::shared_ptr<Net> n);
|
||||
static Status Import(const char *model_buf, Graph *graph);
|
||||
Status TrainNet(Node *optimizer, const std::vector<Expr *> &inputs);
|
||||
const std::vector<int> InputShape(int idx);
|
||||
const std::vector<int> OutputShape(int idx);
|
||||
std::unique_ptr<Graph> MakeMs();
|
||||
void Release() { pnet_.reset(); }
|
||||
|
||||
private:
|
||||
std::shared_ptr<Net> pnet_;
|
||||
lite::Net *net_ = nullptr;
|
||||
};
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_CXX_API_EXPRESSION_NET_IMPL_H_
|
|
@ -1,50 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/litert/cxx_api/expression/node_impl.h"
|
||||
#include <vector>
|
||||
#include "include/api/net.h"
|
||||
#include "src/expression/ops.h"
|
||||
|
||||
namespace mindspore {
|
||||
Node *NodeImpl::Connect(lite::Node *lnode) {
|
||||
auto node = std::make_unique<Node>();
|
||||
if (node == nullptr) {
|
||||
MS_LOG(ERROR) << "Cannot allocate node";
|
||||
return nullptr;
|
||||
}
|
||||
if (lnode == nullptr) {
|
||||
MS_LOG(ERROR) << "lite node is null";
|
||||
return nullptr;
|
||||
}
|
||||
auto pnode = node.release();
|
||||
auto impl = GetImpl(pnode);
|
||||
if (impl == nullptr) {
|
||||
MS_LOG(ERROR) << "missing implementation";
|
||||
return nullptr;
|
||||
}
|
||||
impl->set_node(lnode);
|
||||
lnode->set_impl(impl);
|
||||
return pnode;
|
||||
}
|
||||
namespace NN {
|
||||
std::unique_ptr<Node> Input(std::vector<int> dims, DataType data_type, int fmt) {
|
||||
auto type = static_cast<TypeId>(data_type);
|
||||
auto lite_node = lite::NN::Input(dims, type, fmt);
|
||||
return std::unique_ptr<Node>(NodeImpl::Connect(lite_node));
|
||||
}
|
||||
} // namespace NN
|
||||
} // namespace mindspore
|
|
@ -1,71 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_CXX_API_EXPRESSION_NODE_IMPL_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_CXX_API_EXPRESSION_NODE_IMPL_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <set>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "include/api/net.h"
|
||||
#include "include/api/cfg.h"
|
||||
#include "include/api/data_type.h"
|
||||
#include "src/expression/node.h"
|
||||
#include "src/expression/expr.h"
|
||||
|
||||
namespace mindspore {
|
||||
using lite::EXPR;
|
||||
class NodeSet {
|
||||
public:
|
||||
std::set<lite::Node *> set_;
|
||||
};
|
||||
|
||||
class Expr : public EXPR {
|
||||
public:
|
||||
static std::vector<EXPR *> convert(const std::vector<Expr *> &input) {
|
||||
std::vector<EXPR *> vec(input.size());
|
||||
(void)std::transform(input.begin(), input.end(), vec.begin(), [](Expr *e) { return reinterpret_cast<EXPR *>(e); });
|
||||
return vec;
|
||||
}
|
||||
static std::vector<Expr *> convert(const std::vector<EXPR *> &input) {
|
||||
std::vector<Expr *> vec(input.size());
|
||||
(void)std::transform(input.begin(), input.end(), vec.begin(), [](EXPR *e) { return reinterpret_cast<Expr *>(e); });
|
||||
return vec;
|
||||
}
|
||||
};
|
||||
|
||||
class MS_API NodeImpl {
|
||||
public:
|
||||
std::vector<Expr *> operator()(const std::vector<Expr *> &inputs) {
|
||||
auto in = Expr::convert(inputs);
|
||||
auto out = (*node_)(in);
|
||||
return Expr::convert(out);
|
||||
}
|
||||
lite::Node *node() { return node_; }
|
||||
void set_node(lite::Node *node) { node_ = node; }
|
||||
void set_pnode(Node *node) { pnode_ = node; }
|
||||
Node *pnode() { return pnode_; }
|
||||
static Node *Connect(lite::Node *lnode);
|
||||
static std::shared_ptr<NodeImpl> &GetImpl(Node *node) { return node->impl_; }
|
||||
|
||||
private:
|
||||
Node *pnode_{nullptr};
|
||||
lite::Node *node_{nullptr};
|
||||
};
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_CXX_API_EXPRESSION_NODE_IMPL_H_
|
|
@ -16,9 +16,7 @@
|
|||
|
||||
#include "include/api/graph.h"
|
||||
#include "include/api/cell.h"
|
||||
#include "include/api/net.h"
|
||||
#include "src/litert/cxx_api/graph/graph_data.h"
|
||||
#include "src/litert/cxx_api/graph/net_data.h"
|
||||
|
||||
namespace mindspore {
|
||||
Graph::Graph() : graph_data_(nullptr) {}
|
||||
|
@ -27,15 +25,8 @@ Graph::Graph(const std::shared_ptr<GraphData> &graph_data) : graph_data_(graph_d
|
|||
|
||||
Graph::Graph(std::shared_ptr<GraphData> &&graph_data) : graph_data_(graph_data) {}
|
||||
|
||||
Graph::Graph(Graph::Type type) : graph_type_(type) {}
|
||||
|
||||
Graph::~Graph() {}
|
||||
|
||||
Graph::Graph(Net *net) : graph_type_(kExpressionGraph) {
|
||||
auto shared = std::make_shared<NetData>(net->shared_from_this());
|
||||
net_data_ = shared;
|
||||
}
|
||||
|
||||
Graph::Graph(std::nullptr_t) : graph_data_(nullptr) {}
|
||||
|
||||
bool Graph::operator==(std::nullptr_t) const { return graph_data_ == nullptr; }
|
||||
|
|
|
@ -1,21 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "src/litert/cxx_api/graph/net_data.h"
|
||||
#include "src/litert/cxx_api/expression/net_impl.h"
|
||||
|
||||
namespace mindspore {
|
||||
NetData::~NetData() { net_->impl_->Release(); }
|
||||
} // namespace mindspore
|
|
@ -1,35 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_CXX_API_GRAPH_NET_DATA_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_CXX_API_GRAPH_NET_DATA_H_
|
||||
|
||||
#include <memory>
|
||||
#include "include/api/net.h"
|
||||
|
||||
namespace mindspore {
|
||||
class NetData {
|
||||
public:
|
||||
explicit NetData(const std::shared_ptr<Net> &net) : net_(net) {}
|
||||
virtual ~NetData();
|
||||
void set_net(std::shared_ptr<Net> net) { net_ = net; }
|
||||
std::shared_ptr<Net> &net() { return net_; }
|
||||
|
||||
private:
|
||||
std::shared_ptr<Net> net_;
|
||||
};
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_CXX_API_GRAPH_NET_DATA_H_
|
|
@ -30,7 +30,6 @@
|
|||
#if defined(ENABLE_PRE_INFERENCE) && defined(__linux__) && !defined(Debug)
|
||||
#include "src/common/thread_utils.h"
|
||||
#endif
|
||||
#include "src/litert/cxx_api/expression/net_impl.h"
|
||||
#include "src/litert/cxx_api/callback/callback_adapter.h"
|
||||
#include "src/litert/cxx_api/callback/callback_impl.h"
|
||||
#include "src/litert/cxx_api/model/model_impl.h"
|
||||
|
|
|
@ -79,14 +79,6 @@ CreateTrainSessionProto *CreateTrainSessionCallbackHolder(CreateTrainSessionProt
|
|||
return proto_;
|
||||
}
|
||||
|
||||
ExpressionLoader CreateExpressionLoader(const ExpressionLoader &loader) {
|
||||
static ExpressionLoader loader_ = nullptr;
|
||||
if (loader != nullptr) {
|
||||
loader_ = loader;
|
||||
}
|
||||
return loader_;
|
||||
}
|
||||
|
||||
#if defined(ENABLE_PRE_INFERENCE) && defined(__linux__) && !defined(Debug)
|
||||
Status ModelImpl::BuildAndRun(const void *model_data, size_t data_size, ModelType model_type,
|
||||
const std::shared_ptr<Context> &model_context) {
|
||||
|
|
|
@ -49,9 +49,6 @@ typedef std::shared_ptr<lite::LiteSession>(CreateTrainSessionProto)(std::shared_
|
|||
const std::shared_ptr<lite::InnerContext> &context);
|
||||
MS_API CreateTrainSessionProto *CreateTrainSessionCallbackHolder(CreateTrainSessionProto *proto = nullptr);
|
||||
|
||||
using ExpressionLoader = std::function<Status(const char *, Graph *)>;
|
||||
MS_API ExpressionLoader CreateExpressionLoader(const ExpressionLoader &loader = nullptr);
|
||||
|
||||
namespace session {
|
||||
class Metrics;
|
||||
class TrainLoopCallBack;
|
||||
|
@ -106,7 +103,6 @@ class ModelImpl {
|
|||
|
||||
static bool CheckModelSupport(const std::string &device_type, ModelType model_type);
|
||||
bool IsTrainModel();
|
||||
std::unique_ptr<Graph> BuildTrain(Node *optimizer, std::vector<Expr *> inputs);
|
||||
Status SetupVirtualBatch(int virtual_batch_multiplier, float lr, float momentum);
|
||||
Status SetLearningRate(float learning_rate);
|
||||
float GetLearningRate();
|
||||
|
|
|
@ -20,7 +20,6 @@
|
|||
#include "include/api/graph.h"
|
||||
#include "include/api/types.h"
|
||||
#include "include/model.h"
|
||||
#include "src/litert/cxx_api/expression/net_impl.h"
|
||||
#include "src/litert/cxx_api/graph/graph_data.h"
|
||||
#include "src/litert/cxx_api/model/model_impl.h"
|
||||
#include "src/litert/cxx_api/converters.h"
|
||||
|
@ -28,8 +27,6 @@
|
|||
#include "src/litert/lite_session.h"
|
||||
|
||||
namespace mindspore {
|
||||
std::function<int(void *)> ExpressionCallback;
|
||||
|
||||
Key::Key(const char *dec_key, size_t key_len) {
|
||||
len = 0;
|
||||
if (key_len >= max_key_len) {
|
||||
|
@ -121,7 +118,6 @@ Status Serialization::Load(const std::vector<char> &file, ModelType model_type,
|
|||
MS_LOG(ERROR) << "Read model file failed";
|
||||
return kLiteNullptr;
|
||||
}
|
||||
if (graph->IsExecutable()) {
|
||||
auto model =
|
||||
std::shared_ptr<lite::Model>(lite::ImportFromBuffer(static_cast<const char *>(model_buf), model_size, true));
|
||||
if (model == nullptr) {
|
||||
|
@ -135,17 +131,6 @@ Status Serialization::Load(const std::vector<char> &file, ModelType model_type,
|
|||
}
|
||||
*graph = Graph(graph_data);
|
||||
return kSuccess;
|
||||
} else {
|
||||
auto loader = CreateExpressionLoader();
|
||||
if (loader == nullptr) {
|
||||
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||
delete[] model_buf;
|
||||
return kLiteError;
|
||||
}
|
||||
(void)loader(model_buf, graph);
|
||||
delete[] model_buf;
|
||||
return kSuccess;
|
||||
}
|
||||
}
|
||||
|
||||
Status Serialization::Load(const std::vector<std::vector<char>> &files, ModelType model_type,
|
||||
|
|
|
@ -15,7 +15,6 @@
|
|||
*/
|
||||
|
||||
#include "include/api/model.h"
|
||||
#include "include/api/net.h"
|
||||
#include "include/api/callback/callback.h"
|
||||
#include "include/api/dual_abi_helper.h"
|
||||
#include "src/litert/cxx_api/model/model_impl.h"
|
||||
|
|
|
@ -18,34 +18,6 @@
|
|||
#include "src/common/log_adapter.h"
|
||||
#include "src/litert/cxx_api/model/model_impl.h"
|
||||
namespace mindspore {
|
||||
Status Model::Build(GraphCell lossGraphCell, Node *optimizer, std::vector<Expr *> inputs,
|
||||
const std::shared_ptr<Context> &model_context, const std::shared_ptr<TrainCfg> &train_cfg) {
|
||||
std::stringstream err_msg;
|
||||
if (impl_ == nullptr) {
|
||||
impl_ = std::make_shared<ModelImpl>();
|
||||
if (impl_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Model implement is null.";
|
||||
return kLiteFileError;
|
||||
}
|
||||
}
|
||||
auto lossGraph = lossGraphCell.GetGraph();
|
||||
if (lossGraph == nullptr) {
|
||||
err_msg << "Invalid null graph";
|
||||
MS_LOG(ERROR) << err_msg.str();
|
||||
return Status(kLiteNullptr, err_msg.str());
|
||||
}
|
||||
impl_->SetContext(model_context);
|
||||
impl_->SetConfig(train_cfg);
|
||||
impl_->SetGraph(lossGraph);
|
||||
auto graph = impl_->BuildTrain(optimizer, inputs);
|
||||
auto status = Build(GraphCell(*graph), model_context, train_cfg);
|
||||
if (status != mindspore::kSuccess) {
|
||||
MS_LOG(ERROR) << "Error " << status << " during model build";
|
||||
return status;
|
||||
}
|
||||
return kSuccess; // status
|
||||
}
|
||||
|
||||
Status Model::BuildTransferLearning(GraphCell backbone, GraphCell head, const std::shared_ptr<Context> &context,
|
||||
const std::shared_ptr<TrainCfg> &train_cfg) {
|
||||
std::stringstream err_msg;
|
||||
|
|
|
@ -18,35 +18,7 @@
|
|||
#include "include/train/train_cfg.h"
|
||||
#include "src/litert/cxx_api/converters.h"
|
||||
#include "src/train/transfer_session.h"
|
||||
#include "src/litert/cxx_api/expression/node_impl.h"
|
||||
#include "src/litert/cxx_api/expression/net_impl.h"
|
||||
namespace mindspore {
|
||||
std::unique_ptr<Graph> ModelImpl::BuildTrain(Node *optimizer, std::vector<Expr *> inputs) {
|
||||
auto opt_impl = NodeImpl::GetImpl(optimizer);
|
||||
if (opt_impl == nullptr) {
|
||||
MS_LOG(ERROR) << "missing optimizer node implementation";
|
||||
return nullptr;
|
||||
}
|
||||
auto opt = opt_impl->node();
|
||||
auto in = Expr::convert(inputs);
|
||||
auto net_impl = NetImpl::GetImpl(graph_->net_data_->net().get());
|
||||
if (net_impl == nullptr) {
|
||||
MS_LOG(ERROR) << "missing net implementation";
|
||||
return nullptr;
|
||||
}
|
||||
auto trained_net = net_impl->net()->TrainNet(opt, in);
|
||||
if (trained_net == nullptr) {
|
||||
MS_LOG(ERROR) << "failed to train network";
|
||||
return nullptr;
|
||||
}
|
||||
auto mgraph = net_impl->MakeMs();
|
||||
if (mgraph == nullptr) {
|
||||
MS_LOG(ERROR) << "failed to create graph";
|
||||
return nullptr;
|
||||
}
|
||||
return mgraph;
|
||||
}
|
||||
|
||||
Status ModelImpl::BuildTransferLearning(const std::shared_ptr<Graph> &backbone, const std::shared_ptr<Graph> &head) {
|
||||
const auto b_graph_data = backbone->graph_data_;
|
||||
const auto h_graph_data = head->graph_data_;
|
||||
|
|
|
@ -20,7 +20,6 @@
|
|||
#include "include/api/serialization.h"
|
||||
#include "include/api/callback/callback.h"
|
||||
#include "include/api/metrics/metrics.h"
|
||||
#include "src/litert/cxx_api/expression/net_impl.h"
|
||||
#include "src/litert/cxx_api/converters.h"
|
||||
#include "src/litert/cxx_api/metrics/metrics_adapter.h"
|
||||
#include "src/litert/cxx_api/metrics/metrics_impl.h"
|
||||
|
|
|
@ -51,7 +51,5 @@ vae
|
|||
unified_api code_example
|
||||
train_lenet code_example
|
||||
train_lenet_java code_example
|
||||
lenet expression
|
||||
mobilenetv2 expression noarm32
|
||||
# LAST
|
||||
#test_resize inputShapes 16,10,10,1:16,10,10,1 0.5
|
||||
|
|
|
@ -14,13 +14,6 @@ set(TEST_SRC
|
|||
# add static securec link library
|
||||
include(${TOP_DIR}/cmake/dependency_securec.cmake)
|
||||
|
||||
if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "full")
|
||||
set(TEST_SRC
|
||||
${TEST_SRC}
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/net_runner.cc
|
||||
)
|
||||
endif()
|
||||
|
||||
add_executable(benchmark_train
|
||||
${TEST_SRC}
|
||||
${COMMON_SRC})
|
||||
|
|
|
@ -1,371 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/benchmark_train/net_runner.h"
|
||||
#include "tools/benchmark_train/net_train.h"
|
||||
#include <getopt.h>
|
||||
#include <malloc.h>
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <utility>
|
||||
#include <chrono>
|
||||
#include "include/api/types.h"
|
||||
#include "include/api/context.h"
|
||||
#include "include/api/serialization.h"
|
||||
#include "include/api/callback/loss_monitor.h"
|
||||
#include "include/api/metrics/accuracy.h"
|
||||
#include "include/api/callback/ckpt_saver.h"
|
||||
#include "include/api/callback/train_accuracy.h"
|
||||
#include "include/api/callback/lr_scheduler.h"
|
||||
#include "include/dataset/datasets.h"
|
||||
#include "include/dataset/vision_lite.h"
|
||||
#include "include/dataset/transforms.h"
|
||||
#include "include/api/cfg.h"
|
||||
#include "include/api/net.h"
|
||||
|
||||
using mindspore::AccuracyMetrics;
|
||||
using mindspore::Model;
|
||||
using mindspore::TrainAccuracy;
|
||||
using mindspore::TrainCallBack;
|
||||
using mindspore::TrainCallBackData;
|
||||
using mindspore::dataset::Dataset;
|
||||
using mindspore::dataset::Mnist;
|
||||
using mindspore::dataset::SequentialSampler;
|
||||
using mindspore::dataset::TensorOperation;
|
||||
using mindspore::dataset::transforms::TypeCast;
|
||||
using mindspore::dataset::vision::Normalize;
|
||||
using mindspore::dataset::vision::Resize;
|
||||
|
||||
constexpr int kNCHWCDim = 2;
|
||||
constexpr int kPrintTimes = 100;
|
||||
constexpr float kBetta1 = 0.9f;
|
||||
constexpr float kBetta2 = 0.999f;
|
||||
|
||||
class Rescaler : public mindspore::TrainCallBack {
|
||||
public:
|
||||
explicit Rescaler(float scale) : scale_(scale) {
|
||||
if (scale_ == 0) {
|
||||
scale_ = 1.0;
|
||||
}
|
||||
}
|
||||
~Rescaler() override = default;
|
||||
void StepBegin(const mindspore::TrainCallBackData &cb_data) override {
|
||||
auto inputs = cb_data.model_->GetInputs();
|
||||
auto *input_data = reinterpret_cast<float *>(inputs.at(0).MutableData());
|
||||
for (int k = 0; k < inputs.at(0).ElementNum(); k++) input_data[k] /= scale_;
|
||||
}
|
||||
|
||||
private:
|
||||
float scale_ = 1.0;
|
||||
};
|
||||
|
||||
/* This is an example of a user defined Callback to measure memory and latency of execution */
|
||||
class Measurement : public mindspore::TrainCallBack {
|
||||
public:
|
||||
explicit Measurement(unsigned int epochs)
|
||||
: time_avg_(std::chrono::duration<double, std::milli>(0)), epochs_(epochs) {}
|
||||
~Measurement() override = default;
|
||||
void EpochBegin(const mindspore::TrainCallBackData &cb_data) override {
|
||||
start_time_ = std::chrono::high_resolution_clock::now();
|
||||
}
|
||||
mindspore::CallbackRetValue EpochEnd(const mindspore::TrainCallBackData &cb_data) override {
|
||||
end_time_ = std::chrono::high_resolution_clock::now();
|
||||
auto time = std::chrono::duration<double, std::milli>(end_time_ - start_time_);
|
||||
time_avg_ += time;
|
||||
return mindspore::kContinue;
|
||||
}
|
||||
void End(const mindspore::TrainCallBackData &cb_data) override {
|
||||
if (epochs_ > 0) {
|
||||
std::cout << "AvgRunTime: " << time_avg_.count() / epochs_ << " ms" << std::endl;
|
||||
}
|
||||
|
||||
struct mallinfo info = mallinfo();
|
||||
std::cout << "Total allocation: " << info.arena + info.hblkhd << std::endl;
|
||||
}
|
||||
|
||||
private:
|
||||
std::chrono::time_point<std::chrono::high_resolution_clock> start_time_;
|
||||
std::chrono::time_point<std::chrono::high_resolution_clock> end_time_;
|
||||
std::chrono::duration<double, std::milli> time_avg_;
|
||||
unsigned int epochs_;
|
||||
};
|
||||
|
||||
NetRunner::~NetRunner() {
|
||||
if (model_ != nullptr) {
|
||||
delete model_;
|
||||
}
|
||||
if (graph_ != nullptr) {
|
||||
delete graph_;
|
||||
}
|
||||
}
|
||||
|
||||
mindspore::Status NetRunner::InitAndFigureInputs() {
|
||||
auto context = std::make_shared<mindspore::Context>();
|
||||
auto cpu_context = std::make_shared<mindspore::CPUDeviceInfo>();
|
||||
cpu_context->SetEnableFP16(enable_fp16_);
|
||||
context->MutableDeviceInfo().push_back(cpu_context);
|
||||
|
||||
graph_ = new (std::nothrow) mindspore::Graph(mindspore::Graph::Type::kExpressionGraph);
|
||||
if (graph_ == nullptr) {
|
||||
std::cout << "Cannot allocate graph" << std::endl;
|
||||
return mindspore::kLiteMemoryFailed;
|
||||
}
|
||||
auto status = mindspore::Serialization::Load(ms_file_, mindspore::kMindIR, graph_);
|
||||
if (status != mindspore::kSuccess) {
|
||||
std::cout << "Error " << status << " during serialization of graph " << ms_file_;
|
||||
return status;
|
||||
}
|
||||
auto net = std::make_unique<mindspore::Net>(*graph_);
|
||||
auto input_shape = net->InputShape(0);
|
||||
auto label_shape = net->OutputShape(0);
|
||||
auto inputM = mindspore::NN::Input(input_shape);
|
||||
auto labelM = mindspore::NN::Input(label_shape);
|
||||
auto label = labelM->Create("label");
|
||||
auto input = inputM->Create("input");
|
||||
|
||||
auto cfg = std::make_shared<mindspore::TrainCfg>();
|
||||
if (enable_fp16_) {
|
||||
cfg.get()->optimization_level_ = mindspore::kO2;
|
||||
}
|
||||
|
||||
model_ = new (std::nothrow) mindspore::Model();
|
||||
if (model_ == nullptr) {
|
||||
std::cout << "model allocation failed" << std::endl;
|
||||
return mindspore::kLiteMemoryFailed;
|
||||
}
|
||||
mindspore::SoftMaxCrossEntropyCfg softmax_ce_cfg;
|
||||
softmax_ce_cfg.reduction = "none";
|
||||
auto netWithLoss = mindspore::NN::GraphWithLoss(graph_, mindspore::NN::SoftmaxCrossEntropy(softmax_ce_cfg));
|
||||
mindspore::AdamConfig AdamCfg;
|
||||
AdamCfg.beta1_ = kBetta1;
|
||||
AdamCfg.beta2_ = kBetta2;
|
||||
AdamCfg.eps_ = 1e-8;
|
||||
AdamCfg.learning_rate_ = 1e-2;
|
||||
auto optimizer = mindspore::NN::Adam(net->trainable_params(), AdamCfg);
|
||||
status = model_->Build(mindspore::GraphCell(*netWithLoss), optimizer, {input, label}, context, cfg);
|
||||
if (status != mindspore::kSuccess) {
|
||||
std::cout << "Error " << status << " during build of model " << ms_file_ << std::endl;
|
||||
return status;
|
||||
}
|
||||
delete graph_;
|
||||
graph_ = nullptr;
|
||||
auto inputs = model_->GetInputs();
|
||||
if (inputs.size() < 1) {
|
||||
return mindspore::kLiteError;
|
||||
}
|
||||
auto nhwc_input_dims = inputs.at(0).Shape();
|
||||
batch_size_ = nhwc_input_dims.at(0);
|
||||
h_ = nhwc_input_dims.at(1);
|
||||
w_ = nhwc_input_dims.at(kNCHWCDim);
|
||||
return mindspore::kSuccess;
|
||||
}
|
||||
|
||||
int NetRunner::CompareOutput(const std::vector<mindspore::MSTensor> &outputs) {
|
||||
std::cout << "================ Comparing Forward Output data ================" << std::endl;
|
||||
float total_bias = 0;
|
||||
int total_size = 0;
|
||||
bool has_error = false;
|
||||
int i = 1;
|
||||
for (auto &tensor : outputs) {
|
||||
std::cout << "output is tensor " << tensor.Name() << "\n";
|
||||
auto output = tensor.Data();
|
||||
size_t size;
|
||||
std::string output_file = flags_->data_file_ + std::to_string(i) + ".bin";
|
||||
auto bin_buf = std::unique_ptr<float[]>(mindspore::lite::NetTrain::ReadFileBuf(output_file.c_str(), &size));
|
||||
if (bin_buf == nullptr) {
|
||||
MS_LOG(ERROR) << "ReadFile return nullptr";
|
||||
std::cout << "ReadFile return nullptr" << std::endl;
|
||||
return mindspore::kLiteNullptr;
|
||||
}
|
||||
if (size != tensor.DataSize()) {
|
||||
MS_LOG(ERROR) << "Output buffer and output file differ by size. Tensor size: " << tensor.DataSize()
|
||||
<< ", read size: " << size;
|
||||
std::cout << "Output buffer and output file differ by size. Tensor size: " << tensor.DataSize()
|
||||
<< ", read size: " << size << std::endl;
|
||||
return mindspore::kLiteError;
|
||||
}
|
||||
float bias = mindspore::lite::NetTrain::CompareData<float>(bin_buf.get(), tensor.ElementNum(),
|
||||
reinterpret_cast<const float *>(output.get()));
|
||||
if (bias >= 0) {
|
||||
total_bias += bias;
|
||||
total_size++;
|
||||
} else {
|
||||
has_error = true;
|
||||
break;
|
||||
}
|
||||
i++;
|
||||
}
|
||||
|
||||
if (!has_error) {
|
||||
float mean_bias;
|
||||
if (total_size != 0) {
|
||||
mean_bias = total_bias / total_size * kPrintTimes;
|
||||
} else {
|
||||
mean_bias = 0;
|
||||
}
|
||||
|
||||
std::cout << "Mean bias of all nodes/tensors: " << mean_bias << "%"
|
||||
<< " threshold is:" << this->flags_->accuracy_threshold_ << std::endl;
|
||||
std::cout << "=======================================================" << std::endl << std::endl;
|
||||
|
||||
if (mean_bias > this->flags_->accuracy_threshold_) {
|
||||
MS_LOG(INFO) << "Mean bias of all nodes/tensors is too big: " << mean_bias << "%";
|
||||
std::cout << "Mean bias of all nodes/tensors is too big: " << mean_bias << "%" << std::endl;
|
||||
return mindspore::kLiteError;
|
||||
} else {
|
||||
return mindspore::kSuccess;
|
||||
}
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Error in CompareData";
|
||||
std::cout << "Error in CompareData" << std::endl;
|
||||
std::cout << "=======================================================" << std::endl << std::endl;
|
||||
return mindspore::kSuccess;
|
||||
}
|
||||
}
|
||||
|
||||
void NetRunner::CheckSum(const mindspore::MSTensor &tensor, std::string node_type, int id, std::string in_out) {
|
||||
constexpr int kPrintLen = 4;
|
||||
int tensor_size = tensor.ElementNum();
|
||||
const void *data = tensor.Data().get();
|
||||
const float *fdata = reinterpret_cast<const float *>(data);
|
||||
mindspore::DataType type = tensor.DataType();
|
||||
std::cout << node_type << " " << in_out << id << std::endl;
|
||||
std::cout << "tensor name: " << tensor.Name() << std::endl;
|
||||
if ((tensor_size) == 0 || (data == nullptr)) {
|
||||
std::cout << "Empty tensor" << std::endl;
|
||||
return;
|
||||
}
|
||||
switch (type) {
|
||||
case mindspore::DataType::kNumberTypeFloat32:
|
||||
std::cout << "sum=" << mindspore::lite::TensorSum<float>(data, tensor_size) << std::endl;
|
||||
std::cout << "data: ";
|
||||
for (int i = 0; i <= kPrintLen && i < tensor_size; i++) {
|
||||
std::cout << static_cast<float>(fdata[i]) << ", ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
break;
|
||||
case mindspore::DataType::kNumberTypeInt32:
|
||||
std::cout << "sum=" << mindspore::lite::TensorSum<int>(data, tensor_size) << std::endl;
|
||||
break;
|
||||
default:
|
||||
std::cout << "unsupported type:" << static_cast<int>(type) << std::endl;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
int NetRunner::InitCallbackParameter() {
|
||||
// after callback
|
||||
after_call_back_ = [&](const std::vector<mindspore::MSTensor> &after_inputs,
|
||||
const std::vector<mindspore::MSTensor> &after_outputs,
|
||||
const mindspore::MSCallBackParam &call_param) {
|
||||
if (after_inputs.empty()) {
|
||||
MS_LOG(INFO) << "The num of after inputs is empty";
|
||||
}
|
||||
if (after_outputs.empty()) {
|
||||
MS_LOG(INFO) << "The num of after outputs is empty";
|
||||
}
|
||||
if (flags_->layer_checksum_) {
|
||||
for (size_t i = 0; i < after_inputs.size(); i++) {
|
||||
CheckSum(after_inputs.at(i), call_param.node_type, i, "in");
|
||||
}
|
||||
for (size_t i = 0; i < after_outputs.size(); i++) {
|
||||
CheckSum(after_outputs.at(i), call_param.node_type, i, "out");
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
return true;
|
||||
};
|
||||
return false;
|
||||
}
|
||||
|
||||
int NetRunner::RunOnce() {
|
||||
auto inputs = model_->GetInputs();
|
||||
std::vector<mindspore::MSTensor> output;
|
||||
auto status = LoadInput(&inputs);
|
||||
if (status != mindspore::kSuccess) {
|
||||
std::cout << "cannot load data";
|
||||
return status;
|
||||
}
|
||||
model_->SetTrainMode(true);
|
||||
model_->RunStep(nullptr, nullptr);
|
||||
model_->SetTrainMode(false);
|
||||
model_->Predict(inputs, &output, nullptr, nullptr);
|
||||
return CompareOutput(output);
|
||||
}
|
||||
|
||||
int NetRunner::LoadInput(std::vector<mindspore::MSTensor> *ms_inputs) {
|
||||
auto status = ReadInputFile(ms_inputs);
|
||||
if (status != mindspore::kSuccess) {
|
||||
std::cout << "Read Input File error, " << status << std::endl;
|
||||
MS_LOG(ERROR) << "Read Input File error, " << status;
|
||||
return status;
|
||||
}
|
||||
return mindspore::kSuccess;
|
||||
}
|
||||
|
||||
int NetRunner::ReadInputFile(std::vector<mindspore::MSTensor> *ms_inputs) {
|
||||
if (ms_inputs->empty()) {
|
||||
std::cout << "no inputs to input" << std::endl;
|
||||
return mindspore::kLiteError;
|
||||
}
|
||||
for (size_t i = 0; i < ms_inputs->size(); i++) {
|
||||
auto cur_tensor = ms_inputs->at(i);
|
||||
if (cur_tensor == nullptr) {
|
||||
std::cout << "empty tensor " << i << std::endl;
|
||||
MS_LOG(ERROR) << "empty tensor " << i;
|
||||
}
|
||||
size_t size;
|
||||
std::string file_name = flags_->in_data_file_ + std::to_string(i + 1) + ".bin";
|
||||
auto bin_buf = std::unique_ptr<float[]>(mindspore::lite::NetTrain::ReadFileBuf(file_name.c_str(), &size));
|
||||
if (bin_buf == nullptr) {
|
||||
MS_LOG(ERROR) << "ReadFile return nullptr";
|
||||
std::cout << "ReadFile return nullptr" << std::endl;
|
||||
return mindspore::kLiteNullptr;
|
||||
}
|
||||
auto tensor_data_size = cur_tensor.DataSize();
|
||||
if (size != tensor_data_size) {
|
||||
std::cout << "Input binary file size error, required: " << tensor_data_size << ", in fact: " << size
|
||||
<< " ,file_name: " << file_name.c_str() << std::endl;
|
||||
MS_LOG(ERROR) << "Input binary file size error, required: " << tensor_data_size << ", in fact: " << size
|
||||
<< " ,file_name: " << file_name.c_str();
|
||||
return mindspore::kLiteError;
|
||||
}
|
||||
auto input_data = cur_tensor.MutableData();
|
||||
memcpy(input_data, bin_buf.get(), tensor_data_size);
|
||||
}
|
||||
return mindspore::kSuccess;
|
||||
}
|
||||
|
||||
int NetRunner::Main() {
|
||||
ms_file_ = flags_->model_file_;
|
||||
InitCallbackParameter();
|
||||
auto status = InitAndFigureInputs();
|
||||
if (status != mindspore::kSuccess) {
|
||||
std::cout << "failed to initialize network" << std::endl;
|
||||
return status.StatusCode();
|
||||
}
|
||||
return RunOnce();
|
||||
}
|
||||
|
||||
int CallBack(mindspore::lite::NetTrainFlags *flags) {
|
||||
NetRunner nr(flags);
|
||||
return nr.Main();
|
||||
}
|
||||
|
||||
int init = mindspore::lite::NetTrain::SetNr(CallBack);
|
|
@ -1,81 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_BENCHMARK_TRAIN_NET_RUNNER_H_
|
||||
#define MINDSPORE_LITE_TOOLS_BENCHMARK_TRAIN_NET_RUNNER_H_
|
||||
|
||||
#include <tuple>
|
||||
#include <iomanip>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "include/api/model.h"
|
||||
#include "include/api/graph.h"
|
||||
#include "include/api/types.h"
|
||||
#include "include/api/status.h"
|
||||
#include "include/api/metrics/accuracy.h"
|
||||
#include "include/dataset/datasets.h"
|
||||
|
||||
using mindspore::AccuracyMetrics;
|
||||
using mindspore::dataset::Dataset;
|
||||
|
||||
namespace mindspore::lite {
|
||||
class NetTrainFlags;
|
||||
}
|
||||
|
||||
class NetRunner {
|
||||
public:
|
||||
int Main();
|
||||
explicit NetRunner(mindspore::lite::NetTrainFlags *flags) : flags_(flags) {}
|
||||
bool ReadArgs(int argc, int8_t *argv[]);
|
||||
virtual ~NetRunner();
|
||||
|
||||
private:
|
||||
void Usage();
|
||||
mindspore::Status InitAndFigureInputs();
|
||||
void CheckSum(const mindspore::MSTensor &tensor, std::string node_type, int id, std::string in_out);
|
||||
int InitCallbackParameter();
|
||||
int TrainLoop();
|
||||
float CalculateAccuracy(int max_tests = 0);
|
||||
float GetLoss() const;
|
||||
int RunOnce();
|
||||
int CompareOutput(const std::vector<mindspore::MSTensor> &outputs);
|
||||
int LoadInput(std::vector<mindspore::MSTensor> *ms_inputs);
|
||||
int ReadInputFile(std::vector<mindspore::MSTensor> *ms_inputs);
|
||||
|
||||
mindspore::Model *model_ = nullptr;
|
||||
mindspore::Graph *graph_ = nullptr;
|
||||
|
||||
std::shared_ptr<Dataset> train_ds_;
|
||||
std::shared_ptr<Dataset> test_ds_;
|
||||
std::shared_ptr<AccuracyMetrics> acc_metrics_;
|
||||
|
||||
std::string ms_file_ = "";
|
||||
std::string data_dir_ = "";
|
||||
unsigned int epochs_ = 10;
|
||||
bool verbose_ = false;
|
||||
bool enable_fp16_ = false;
|
||||
int virtual_batch_ = -1;
|
||||
int save_checkpoint_ = 0;
|
||||
int batch_size_ = 32;
|
||||
int h_ = 32;
|
||||
int w_ = 32;
|
||||
mindspore::lite::NetTrainFlags *flags_{nullptr};
|
||||
mindspore::MSKernelCallBack after_call_back_;
|
||||
};
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_BENCHMARK_TRAIN_NET_RUNNER_H_
|
|
@ -24,7 +24,6 @@
|
|||
#ifdef ENABLE_NEON
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
#include "tools/benchmark_train/net_runner.h"
|
||||
#include "src/common/common.h"
|
||||
#include "include/api/serialization.h"
|
||||
#include "securec/include/securec.h"
|
||||
|
|
Loading…
Reference in New Issue