From 35ac798bfa4805c90e5a16b4c40d0c455aad60f0 Mon Sep 17 00:00:00 2001 From: yoni Date: Thu, 14 Apr 2022 15:22:36 +0300 Subject: [PATCH] lite expression --- include/api/cfg.h | 7 +- include/api/graph.h | 15 + include/api/model.h | 14 + include/api/net.h | 141 +++++++ .../nnacl/infer/conv2d_grad_filter_infer.c | 26 +- .../nnacl/infer/conv2d_grad_input_infer.c | 16 +- .../cpu/kernel/nnacl/infer/conv2d_infer.c | 2 +- .../plugin/device/cpu/kernel/nnacl/op_base.h | 3 + mindspore/core/load_mindir/load_model.cc | 2 +- mindspore/lite/src/CMakeLists.txt | 25 ++ mindspore/lite/src/cxx_api/expression/net.cc | 145 +++++++ .../lite/src/cxx_api/expression/net_impl.cc | 222 +++++++++++ .../lite/src/cxx_api/expression/net_impl.h | 95 +++++ .../lite/src/cxx_api/expression/node_impl.cc | 50 +++ .../lite/src/cxx_api/expression/node_impl.h | 71 ++++ mindspore/lite/src/cxx_api/graph/graph.cc | 9 + mindspore/lite/src/cxx_api/graph/net_data.cc | 21 + mindspore/lite/src/cxx_api/graph/net_data.h | 35 ++ mindspore/lite/src/cxx_api/model/model.cc | 4 + .../lite/src/cxx_api/model/model_impl.cc | 8 + mindspore/lite/src/cxx_api/model/model_impl.h | 4 + mindspore/lite/src/cxx_api/serialization.cc | 39 +- mindspore/lite/src/cxx_api/train/model.cc | 32 +- .../lite/src/cxx_api/train/model_impl.cc | 29 ++ mindspore/lite/src/expression/cfg.h | 68 ++++ mindspore/lite/src/expression/export.cc | 76 ++++ mindspore/lite/src/expression/export.h | 52 +++ mindspore/lite/src/expression/expr.cc | 98 +++++ mindspore/lite/src/expression/expr.h | 70 ++++ mindspore/lite/src/expression/import.cc | 180 +++++++++ mindspore/lite/src/expression/import.h | 61 +++ mindspore/lite/src/expression/net.cc | 268 +++++++++++++ mindspore/lite/src/expression/net.h | 114 ++++++ mindspore/lite/src/expression/node.cc | 271 +++++++++++++ mindspore/lite/src/expression/node.h | 156 ++++++++ mindspore/lite/src/expression/ops.cc | 66 ++++ mindspore/lite/src/expression/ops.h | 69 ++++ .../lite/src/expression/ops/activation.cc | 133 
+++++++ .../lite/src/expression/ops/activation.h | 44 +++ mindspore/lite/src/expression/ops/adam.cc | 142 +++++++ mindspore/lite/src/expression/ops/adam.h | 48 +++ mindspore/lite/src/expression/ops/addn.cc | 42 ++ mindspore/lite/src/expression/ops/addn.h | 34 ++ .../lite/src/expression/ops/arithmetic.cc | 223 +++++++++++ .../lite/src/expression/ops/arithmetic.h | 75 ++++ .../src/expression/ops/arithmetic_self.cc | 72 ++++ .../lite/src/expression/ops/arithmetic_self.h | 46 +++ mindspore/lite/src/expression/ops/assign.cc | 60 +++ mindspore/lite/src/expression/ops/assign.h | 35 ++ .../lite/src/expression/ops/batchnorm.cc | 135 +++++++ mindspore/lite/src/expression/ops/batchnorm.h | 43 ++ mindspore/lite/src/expression/ops/biasadd.cc | 93 +++++ mindspore/lite/src/expression/ops/biasadd.h | 44 +++ mindspore/lite/src/expression/ops/conv.cc | 241 ++++++++++++ mindspore/lite/src/expression/ops/conv.h | 58 +++ mindspore/lite/src/expression/ops/dense.cc | 151 +++++++ mindspore/lite/src/expression/ops/dense.h | 44 +++ mindspore/lite/src/expression/ops/depend.cc | 43 ++ mindspore/lite/src/expression/ops/depend.h | 37 ++ mindspore/lite/src/expression/ops/dropout.cc | 91 +++++ mindspore/lite/src/expression/ops/dropout.h | 42 ++ mindspore/lite/src/expression/ops/flatten.cc | 71 ++++ mindspore/lite/src/expression/ops/flatten.h | 36 ++ mindspore/lite/src/expression/ops/pooling.cc | 215 ++++++++++ mindspore/lite/src/expression/ops/pooling.h | 74 ++++ mindspore/lite/src/expression/ops/reduce.cc | 126 ++++++ mindspore/lite/src/expression/ops/reduce.h | 42 ++ mindspore/lite/src/expression/ops/reshape.cc | 74 ++++ mindspore/lite/src/expression/ops/reshape.h | 37 ++ mindspore/lite/src/expression/ops/softmax.cc | 119 ++++++ mindspore/lite/src/expression/ops/softmax.h | 39 ++ .../lite/src/expression/ops/softmaxCE.cc | 93 +++++ mindspore/lite/src/expression/ops/softmaxCE.h | 47 +++ mindspore/lite/src/expression/ops/tile.cc | 62 +++ mindspore/lite/src/expression/ops/tile.h | 40 ++ 
.../lite/src/expression/ops/transpose.cc | 88 +++++ mindspore/lite/src/expression/ops/transpose.h | 59 +++ mindspore/lite/src/expression/ops_utils.cc | 275 +++++++++++++ mindspore/lite/src/expression/ops_utils.h | 69 ++++ mindspore/lite/src/expression/param.cc | 70 ++++ mindspore/lite/src/expression/param.h | 60 +++ mindspore/lite/src/expression/sequential.cc | 30 ++ mindspore/lite/src/expression/sequential.h | 32 ++ .../kernel/cpu/fp32/convolution_1x1_fp32.cc | 3 +- .../cpu/fp32/convolution_depthwise_fp32.cc | 8 +- .../lite/test/config_level0/cropped_size.cfg | 2 +- .../test/config_level0/models_ms_train.cfg | 2 + .../lite/test/config_level1/cropped_size.cfg | 2 +- .../lite/test/st/scripts/run_net_train.sh | 60 ++- .../lite/tools/benchmark_train/CMakeLists.txt | 17 +- .../lite/tools/benchmark_train/net_runner.cc | 371 ++++++++++++++++++ .../lite/tools/benchmark_train/net_runner.h | 81 ++++ .../lite/tools/benchmark_train/net_train.cc | 20 +- .../lite/tools/benchmark_train/net_train.h | 97 +++-- .../lite/tools/converter/converter_flags.cc | 22 +- .../lite/tools/converter/converter_flags.h | 9 +- 96 files changed, 6882 insertions(+), 110 deletions(-) create mode 100644 include/api/net.h create mode 100644 mindspore/lite/src/cxx_api/expression/net.cc create mode 100644 mindspore/lite/src/cxx_api/expression/net_impl.cc create mode 100644 mindspore/lite/src/cxx_api/expression/net_impl.h create mode 100644 mindspore/lite/src/cxx_api/expression/node_impl.cc create mode 100644 mindspore/lite/src/cxx_api/expression/node_impl.h create mode 100644 mindspore/lite/src/cxx_api/graph/net_data.cc create mode 100644 mindspore/lite/src/cxx_api/graph/net_data.h create mode 100644 mindspore/lite/src/expression/cfg.h create mode 100644 mindspore/lite/src/expression/export.cc create mode 100644 mindspore/lite/src/expression/export.h create mode 100644 mindspore/lite/src/expression/expr.cc create mode 100644 mindspore/lite/src/expression/expr.h create mode 100644 
mindspore/lite/src/expression/import.cc create mode 100644 mindspore/lite/src/expression/import.h create mode 100644 mindspore/lite/src/expression/net.cc create mode 100644 mindspore/lite/src/expression/net.h create mode 100644 mindspore/lite/src/expression/node.cc create mode 100644 mindspore/lite/src/expression/node.h create mode 100644 mindspore/lite/src/expression/ops.cc create mode 100644 mindspore/lite/src/expression/ops.h create mode 100644 mindspore/lite/src/expression/ops/activation.cc create mode 100644 mindspore/lite/src/expression/ops/activation.h create mode 100644 mindspore/lite/src/expression/ops/adam.cc create mode 100644 mindspore/lite/src/expression/ops/adam.h create mode 100644 mindspore/lite/src/expression/ops/addn.cc create mode 100644 mindspore/lite/src/expression/ops/addn.h create mode 100644 mindspore/lite/src/expression/ops/arithmetic.cc create mode 100644 mindspore/lite/src/expression/ops/arithmetic.h create mode 100644 mindspore/lite/src/expression/ops/arithmetic_self.cc create mode 100644 mindspore/lite/src/expression/ops/arithmetic_self.h create mode 100644 mindspore/lite/src/expression/ops/assign.cc create mode 100644 mindspore/lite/src/expression/ops/assign.h create mode 100644 mindspore/lite/src/expression/ops/batchnorm.cc create mode 100644 mindspore/lite/src/expression/ops/batchnorm.h create mode 100644 mindspore/lite/src/expression/ops/biasadd.cc create mode 100644 mindspore/lite/src/expression/ops/biasadd.h create mode 100644 mindspore/lite/src/expression/ops/conv.cc create mode 100644 mindspore/lite/src/expression/ops/conv.h create mode 100644 mindspore/lite/src/expression/ops/dense.cc create mode 100644 mindspore/lite/src/expression/ops/dense.h create mode 100644 mindspore/lite/src/expression/ops/depend.cc create mode 100644 mindspore/lite/src/expression/ops/depend.h create mode 100644 mindspore/lite/src/expression/ops/dropout.cc create mode 100644 mindspore/lite/src/expression/ops/dropout.h create mode 100644 
mindspore/lite/src/expression/ops/flatten.cc create mode 100644 mindspore/lite/src/expression/ops/flatten.h create mode 100644 mindspore/lite/src/expression/ops/pooling.cc create mode 100644 mindspore/lite/src/expression/ops/pooling.h create mode 100644 mindspore/lite/src/expression/ops/reduce.cc create mode 100644 mindspore/lite/src/expression/ops/reduce.h create mode 100644 mindspore/lite/src/expression/ops/reshape.cc create mode 100644 mindspore/lite/src/expression/ops/reshape.h create mode 100644 mindspore/lite/src/expression/ops/softmax.cc create mode 100644 mindspore/lite/src/expression/ops/softmax.h create mode 100644 mindspore/lite/src/expression/ops/softmaxCE.cc create mode 100644 mindspore/lite/src/expression/ops/softmaxCE.h create mode 100644 mindspore/lite/src/expression/ops/tile.cc create mode 100644 mindspore/lite/src/expression/ops/tile.h create mode 100644 mindspore/lite/src/expression/ops/transpose.cc create mode 100644 mindspore/lite/src/expression/ops/transpose.h create mode 100644 mindspore/lite/src/expression/ops_utils.cc create mode 100644 mindspore/lite/src/expression/ops_utils.h create mode 100644 mindspore/lite/src/expression/param.cc create mode 100644 mindspore/lite/src/expression/param.h create mode 100644 mindspore/lite/src/expression/sequential.cc create mode 100644 mindspore/lite/src/expression/sequential.h create mode 100644 mindspore/lite/tools/benchmark_train/net_runner.cc create mode 100644 mindspore/lite/tools/benchmark_train/net_runner.h diff --git a/include/api/cfg.h b/include/api/cfg.h index d2bc7b424d6..22b640a9803 100644 --- a/include/api/cfg.h +++ b/include/api/cfg.h @@ -1,5 +1,5 @@ /** - * Copyright 2021 Huawei Technologies Co., Ltd + * Copyright 2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -25,13 +25,13 @@ #include "include/api/types.h" namespace mindspore { - +constexpr int iter_th = 1000; class MixPrecisionCfg { public: MixPrecisionCfg() { this->dynamic_loss_scale_ = false; this->loss_scale_ = 128.0f; - this->num_of_not_nan_iter_th_ = 1000; + this->num_of_not_nan_iter_th_ = iter_th; } ~MixPrecisionCfg() = default; @@ -53,6 +53,5 @@ class TrainCfg { MixPrecisionCfg mix_precision_cfg_; /**< Mix precision configuration */ bool accumulate_gradients_ = false; }; - } // namespace mindspore #endif // MINDSPORE_INCLUDE_API_CFG_H diff --git a/include/api/graph.h b/include/api/graph.h index f25a6217f32..05548890cc8 100644 --- a/include/api/graph.h +++ b/include/api/graph.h @@ -24,23 +24,38 @@ #include "include/api/types.h" namespace mindspore { +class NetData; +class Net; + class MS_API Graph { public: class GraphData; + enum Type : uint32_t { + kExpressionGraph = 0, ///< graph as expression - can auto grad + kExecutableGraph = 1, ///< graph is loaded as is + kUnknownTypeGraph = 0xffffffff + }; Graph(); explicit Graph(const std::shared_ptr &graph_data); explicit Graph(std::shared_ptr &&graph_data); explicit Graph(std::nullptr_t); ~Graph(); + explicit Graph(Type executable); + explicit Graph(Net *net); enum ModelType ModelType() const; bool operator==(std::nullptr_t) const; bool operator!=(std::nullptr_t) const; + bool IsExecutable() { return graph_type_ == kExecutableGraph; } private: friend class GraphCell; friend class ModelImpl; + friend class NetImpl; + friend class Model; std::shared_ptr graph_data_; + std::shared_ptr net_data_; + Type graph_type_ = kExecutableGraph; }; } // namespace mindspore #endif // MINDSPORE_INCLUDE_API_GRAPH_H diff --git a/include/api/model.h b/include/api/model.h index b0ac1fff379..5e6d3c8325c 100644 --- a/include/api/model.h +++ b/include/api/model.h @@ -33,6 +33,9 @@ namespace mindspore { class ModelImpl; class Metrics; +class Net; +class Node; +class Expr; namespace dataset { class Dataset; @@ -109,6 +112,17 @@ class 
MS_API Model { Status Build(GraphCell graph, const std::shared_ptr &model_context = nullptr, const std::shared_ptr &train_cfg = nullptr); + /// \brief Build train model + /// + /// \param[in] graph A forward network + /// \param[in] optimizer An optimizer node + /// \param[in] inputs Inputs expression for the trained network (ex: input, label ) + /// \param[in] model_contex A context used to store options during execution. + /// \param[in] train_cfg A config used by training + /// \return Status + + Status Build(GraphCell graph, Node *optimizer, std::vector inputs, + const std::shared_ptr &model_context, const std::shared_ptr &train_cfg); /// \brief Builds a Transfer Learning model where the backbone weights are fixed and the head weights are trainable /// /// \param[in] backbone The static, non-learnable part of the graph diff --git a/include/api/net.h b/include/api/net.h new file mode 100644 index 00000000000..c7a3a9b0070 --- /dev/null +++ b/include/api/net.h @@ -0,0 +1,141 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_INCLUDE_API_NET_H +#define MINDSPORE_INCLUDE_API_NET_H + +#include +#include +#include +#include +#include "include/api/types.h" +#include "include/api/data_type.h" +#include "include/api/cfg.h" + +namespace mindspore { +/// \brief Register node or sub network +#define REG(_name) Register(_name, #_name) + +class Expr; +class NodeImpl; +class NetImpl; +class NodeSet; +class Graph; +class NetData; + +class NetBase { + public: + NetBase() = default; + virtual std::vector operator()(const std::vector &inputs) = 0; + virtual uint32_t type() = 0; +}; + +class Node : public NetBase { + public: + Node(); + virtual ~Node(); + /// \brief Create output expression from node + + /// \param[in] name Name of input (like "labels" etc.) + /// + /// \return Expression + Expr *Create(std::string name); + /// \brief Run node on inputs. This operator is used in Net::construct() + /// + /// \param[in] inputs Inputs expression for the node. + /// \return Output node expression vector + std::vector operator()(const std::vector &inputs) override; + uint32_t type() final; + + private: + friend NodeImpl; + std::shared_ptr impl_ = nullptr; +}; + +class Net : public NetBase, public std::enable_shared_from_this { + public: + Net(); + virtual ~Net(); + explicit Net(std::string name); + explicit Net(const Graph &g); + /// \brief Define the relation between network inputs and outputs + /// + /// \param[in] inputs expression vector + /// + /// \return expression vector + + virtual std::vector construct(const std::vector &inputs); + /// \brief Addition operation + /// + /// \param[in] inputs Two elements to add + /// + /// \return expression vector (single element) + + /// \brief Execution operator. 
Connect inputs to outputs via user defined construct + /// + /// \return expression vector + + std::vector operator()(const std::vector &inputs); + void Register(Net *net, std::string &&name); + void Register(Node *node, std::string &&name); + /// \brief Find the trainable params for the trained network + /// + /// \return NodeSet for all trainable nodes + std::shared_ptr trainable_params(); + virtual void Add(NetBase *element); + /// \brief Input shape + /// + /// \param[in] idx input index + /// + /// \return Specific input shape vector + const std::vector InputShape(int idx); + /// \brief Output shape + /// + /// \param[in] idx Output index + /// + /// \return Specific output shape vector + const std::vector OutputShape(int idx); + uint32_t type() final; + + private: + friend NetImpl; + friend NetData; + std::shared_ptr impl_; +}; + +class SoftMaxCrossEntropyCfg { + public: + std::string reduction = "mean"; /**< Specifies reduction mode. The optional values are "none", "mean", "sum" */ +}; + +class AdamConfig { + public: + float learning_rate_ = 1e-3; + float beta1_ = 0.9; + float beta2_ = 0.999; + float eps_ = 1e-08; + bool use_nesterov_ = false; +}; + +namespace NN { +Net *NetWithLoss(Net *net, Node *loss); +Graph *GraphWithLoss(Graph *g, Node *loss); +Node *Adam(std::shared_ptr learn, const AdamConfig &cfg); +Node *SoftmaxCrossEntropy(const SoftMaxCrossEntropyCfg &cfg); +std::unique_ptr Input(std::vector dims, DataType data_type = DataType::kNumberTypeFloat32, int fmt = NHWC); +}; // namespace NN +} // namespace mindspore +#endif // MINDSPORE_INCLUDE_API_NET_H diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/conv2d_grad_filter_infer.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/conv2d_grad_filter_infer.c index c02ba325a62..5e880f39f68 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/conv2d_grad_filter_infer.c +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/conv2d_grad_filter_infer.c @@ -14,6 +14,7 @@ * 
limitations under the License. */ +#include #include "nnacl/infer/conv2d_grad_filter_infer.h" #include "nnacl/infer/infer_register.h" @@ -26,28 +27,33 @@ int Conv2dGradFilterInferShape(const TensorC *const *inputs, size_t inputs_size, if (inputs_size < 3 || outputs_size != 1) { return NNACL_ERR; } - if (inputs[0]->format_ != Format_NHWC || inputs[1]->format_ != Format_NHWC) { + if (inputs[FIRST_INPUT]->format_ != Format_NHWC || inputs[SECOND_INPUT]->format_ != Format_NHWC) { return NNACL_FORMAT_ERROR; } - SetDataTypeFormat(outputs[0], inputs[0]); + SetDataTypeFormat(outputs[FIRST_INPUT], inputs[FIRST_INPUT]); - if (inputs[2]->shape_size_ < 1 || inputs[2]->data_ == NULL) { + if (inputs[THIRD_INPUT]->shape_size_ < DIMENSION_1D || inputs[THIRD_INPUT]->data_ == NULL) { return NNACL_ERR; } - if (inputs[2]->shape_[0] < 0) { + if (inputs[THIRD_INPUT]->shape_[kNCHW_N] < 0) { return NNACL_ERR; } - size_t filter_shape_size = (size_t)(inputs[2]->shape_[0]); - if (filter_shape_size != 4) { + size_t filter_shape_size = (size_t)(inputs[THIRD_INPUT]->shape_[kNCHW_N]); + if (filter_shape_size != DIMENSION_4D) { return NNACL_ERR; } int filter_shape[MAX_SHAPE_SIZE]; - const int nchw2nhwc[4] = {0, 2, 3, 1}; - for (size_t i = 0; i < filter_shape_size; i++) { - filter_shape[i] = *((int *)(inputs[2]->data_) + nchw2nhwc[i]); + if (inputs[THIRD_INPUT]->format_ == Format_NCHW || inputs[THIRD_INPUT]->format_ == Format_KCHW) { + const int nchw2nhwc[] = {kNCHW_N, kNCHW_H, kNCHW_W, kNCHW_C}; + for (size_t i = 0; i < filter_shape_size; i++) { + filter_shape[i] = *((int *)(inputs[THIRD_INPUT]->data_) + nchw2nhwc[i]); + } + } else if (inputs[THIRD_INPUT]->format_ == Format_NHWC || inputs[THIRD_INPUT]->format_ == Format_KHWC) { + memcpy(filter_shape, inputs[THIRD_INPUT]->data_, filter_shape_size * sizeof(int)); + } else { + return NNACL_ERR; } - SetShapeArray(outputs[0], filter_shape, filter_shape_size); return NNACL_OK; } diff --git 
a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/conv2d_grad_input_infer.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/conv2d_grad_input_infer.c index 60609c6f0e4..2651deb5970 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/conv2d_grad_input_infer.c +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/conv2d_grad_input_infer.c @@ -37,20 +37,26 @@ int Conv2dGradInputInferShape(const TensorC *const *inputs, size_t inputs_size, } SetDataTypeFormat(out, in0); - if (inputs[2]->shape_size_ < 1 || inputs[2]->data_ == NULL) { + if (inputs[THIRD_INPUT]->shape_size_ < 1 || inputs[THIRD_INPUT]->data_ == NULL) { return NNACL_ERR; } size_t data_size = (size_t)inputs[2]->shape_[0]; if (data_size != 4) { return NNACL_ERR; } + int shape[MAX_SHAPE_SIZE]; - const int nchw2nhwc[4] = {0, 2, 3, 1}; - for (size_t i = 0; i < data_size; i++) { - shape[i] = *((int *)(inputs[2]->data_) + nchw2nhwc[i]); + if (inputs[THIRD_INPUT]->format_ == Format_NCHW || inputs[THIRD_INPUT]->format_ == Format_KCHW) { + const int nchw2nhwc[4] = {kNCHW_N, kNCHW_H, kNCHW_W, kNCHW_C}; + for (size_t i = 0; i < data_size; i++) { + shape[i] = *((int *)(inputs[THIRD_INPUT]->data_) + nchw2nhwc[i]); + } + } else if (inputs[THIRD_INPUT]->format_ == Format_NHWC || inputs[THIRD_INPUT]->format_ == Format_KHWC) { + memcpy(shape, inputs[THIRD_INPUT]->data_, data_size * sizeof(int)); + } else { + return NNACL_ERR; } SetShapeArray(out, shape, data_size); - return NNACL_OK; } diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/conv2d_infer.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/conv2d_infer.c index 9b189537efd..0120d8d96a3 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/conv2d_infer.c +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/conv2d_infer.c @@ -157,7 +157,7 @@ int Conv2dInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC * param->output_h_ = out_shape[DIMENSION_1D]; param->output_w_ = 
out_shape[DIMENSION_2D]; param->output_channel_ = out_shape[DIMENSION_3D]; - + param->out_format_ = out_tensor->format_; return NNACL_OK; } diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/op_base.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/op_base.h index 7a28a0e33bd..74f5dfa3084 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/op_base.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/op_base.h @@ -25,6 +25,7 @@ #include "nnacl/intrinsics/ms_simd_instructions.h" +#define C0NUM 0 #define C1NUM 1 #define C2NUM 2 #define C3NUM 3 @@ -33,6 +34,8 @@ #define C6NUM 6 #define C7NUM 7 #define C8NUM 8 +#define C9NUM 9 +#define C10NUM 10 #define C12NUM 12 #define C13NUM 13 #define C16NUM 16 diff --git a/mindspore/core/load_mindir/load_model.cc b/mindspore/core/load_mindir/load_model.cc index bedabd4d147..c31f9db07d8 100644 --- a/mindspore/core/load_mindir/load_model.cc +++ b/mindspore/core/load_mindir/load_model.cc @@ -235,7 +235,7 @@ FuncGraphPtr MindIRLoader::LoadMindIR(const std::string &file_name) { _fullpath(abs_path_buff, file_name.c_str(), PATH_MAX); #else if (!realpath(file_name.c_str(), abs_path_buff)) { - MS_LOG(ERROR) << "Load MindIR get absolute path failed"; + MS_LOG(ERROR) << "Load MindIR get absolute path of " << file_name << " failed"; } #endif // Read graph diff --git a/mindspore/lite/src/CMakeLists.txt b/mindspore/lite/src/CMakeLists.txt index f88e43e815a..904100e1427 100644 --- a/mindspore/lite/src/CMakeLists.txt +++ b/mindspore/lite/src/CMakeLists.txt @@ -275,9 +275,34 @@ if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "full") ${CMAKE_CURRENT_SOURCE_DIR}/train/train_loop.cc ) endif() + + +file(GLOB CXX_API_EXPRESSION + ${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/expression/*.cc + ) + +file(GLOB EXPRESSION_OPS + ${CMAKE_CURRENT_SOURCE_DIR}/expression/ops/*.cc + ) + +set(EXPRESSION_SRC + ${CXX_API_EXPRESSION} + ${CMAKE_CURRENT_SOURCE_DIR}/expression/export.cc + ${CMAKE_CURRENT_SOURCE_DIR}/expression/expr.cc + 
${CMAKE_CURRENT_SOURCE_DIR}/expression/import.cc + ${CMAKE_CURRENT_SOURCE_DIR}/expression/net.cc + ${CMAKE_CURRENT_SOURCE_DIR}/expression/node.cc + ${CMAKE_CURRENT_SOURCE_DIR}/expression/ops.cc + ${CMAKE_CURRENT_SOURCE_DIR}/expression/ops_utils.cc + ${CMAKE_CURRENT_SOURCE_DIR}/expression/param.cc + ${CMAKE_CURRENT_SOURCE_DIR}/expression/sequential.cc + ${EXPRESSION_OPS} + ) + set(TRAIN_SRC ${API_TRAIN_SRC} ${TRAIN_SRC_WITH_MD} + ${EXPRESSION_SRC} ${CMAKE_CURRENT_SOURCE_DIR}/common/quant_utils.cc ${CMAKE_CURRENT_SOURCE_DIR}/train/train_populate_parameter.cc ${CMAKE_CURRENT_SOURCE_DIR}/train/train_session.cc diff --git a/mindspore/lite/src/cxx_api/expression/net.cc b/mindspore/lite/src/cxx_api/expression/net.cc new file mode 100644 index 00000000000..9ff4a53d37e --- /dev/null +++ b/mindspore/lite/src/cxx_api/expression/net.cc @@ -0,0 +1,145 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "include/api/net.h" +#include "include/api/status.h" +#include "src/cxx_api/expression/node_impl.h" +#include "src/cxx_api/expression/net_impl.h" +#include "src/expression/ops.h" +#include "src/expression/cfg.h" + +namespace mindspore { +uint32_t Node::type() { return kNodeType; } + +std::vector Node::operator()(const std::vector &inputs) { + auto in = Expr::convert(inputs); + if (impl_ == nullptr) { + MS_LOG(ERROR) << "empty implementation"; + return {}; + } + if (impl_->node() == nullptr) { + MS_LOG(ERROR) << "expression node is not attached"; + return {}; + } + auto out = impl_->node()->construct(in); + return Expr::convert(out); +} + +Expr *Node::Create(std::string name) { + auto expr = impl_->node()->create(name); + return reinterpret_cast(expr); +} + +Node::Node() { + auto impl = std::make_shared(); + impl_ = impl; + impl_->set_pnode(this); +} + +Node::~Node() { + impl_->set_pnode(nullptr); + auto node = impl_->node(); + if (node != nullptr) { + impl_->set_node(nullptr); + delete node; + } +} + +Net::Net(std::string name) { + auto impl = std::make_shared(); + if (impl == nullptr) { + MS_LOG(ERROR) << "Cannot allocate network implementation"; + return; + } + impl_ = impl; + impl_->set_pnet(std::shared_ptr(this)); + auto netl = new (std::nothrow) lite::Net(name); + if (netl == nullptr) { + MS_LOG(ERROR) << "Cannot allocate network lite"; + return; + } + netl->set_impl(impl); + impl_->set_net(netl); +} + +Net::Net() : Net("") {} + +Net::Net(const Graph &g) { + auto net = NetImpl::GetNet(g); + impl_ = net->impl_; +} + +void Net::Add(NetBase *element) { MS_LOG(WARNING) << "Only sequential can add element"; } + +uint32_t Net::type() { return kNetType; } + +std::vector Net::construct(const std::vector &inputs) { + auto in = Expr::convert(inputs); + auto out = impl_->net()->construct(in); + return Expr::convert(out); +} + +std::vector Net::operator()(const std::vector &inputs) { + auto in = Expr::convert(inputs); + auto x = construct(inputs); + 
impl_->net()->input_ = in; + auto out = Expr::convert(x); + impl_->net()->output_ = out; + impl_->net()->real_output_ = out; + return x; +} +void Net::Register(Net *net, std::string &&name) { + if (net != nullptr) { + auto net_lite = net->impl_->net(); + impl_->net()->Register(net_lite, std::move(name)); + } +} + +void Net::Register(Node *node, std::string &&name) { + if (node != nullptr) { + auto impl = NodeImpl::GetImpl(node); + if (impl == nullptr) { + MS_LOG(ERROR) << "missing implementation"; + return; + } + auto node_lite = impl->node(); + impl_->net()->Register(node_lite, std::move(name)); + } +} + +std::shared_ptr Net::trainable_params() { + auto node_set = std::make_shared(); + if (node_set == nullptr) { + MS_LOG(ERROR) << "new NodeSet failed."; + return nullptr; + } + node_set->set_ = impl_->net()->trainable_params(); + return node_set; +} + +const std::vector Net::InputShape(int idx) { return impl_->InputShape(idx); } +const std::vector Net::OutputShape(int idx) { return impl_->OutputShape(idx); } + +Net::~Net() { + if (impl_ != nullptr) { + if ((impl_->pnet() == nullptr) || (impl_->pnet() == this)) { + impl_->set_pnet(nullptr); + impl_->set_net(nullptr); + impl_.reset(); + } + } +} +} // namespace mindspore diff --git a/mindspore/lite/src/cxx_api/expression/net_impl.cc b/mindspore/lite/src/cxx_api/expression/net_impl.cc new file mode 100644 index 00000000000..c2a159a771c --- /dev/null +++ b/mindspore/lite/src/cxx_api/expression/net_impl.cc @@ -0,0 +1,222 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/cxx_api/expression/net_impl.h" +#include +#include +#include "include/api/serialization.h" +#include "src/expression/import.h" +#include "src/expression/ops.h" +#include "src/cxx_api/model/model_impl.h" + +namespace { +constexpr size_t kFlatbuffersBuilderInitSize = 1024; +}; + +namespace mindspore { +Sequential::Sequential() {} + +lite::Node *Sequential::GetNode(NetBase *element) { + lite::Node *lite_node = nullptr; + switch (element->type()) { + case kNodeType: { + Node *node = reinterpret_cast(element); + auto impl = NodeImpl::GetImpl(node); + if (impl == nullptr) { + MS_LOG(ERROR) << "cannot find node implement"; + return nullptr; + } + lite_node = impl->node(); + break; + } + case kNetType: { + auto net = reinterpret_cast(element); + auto impl = NetImpl::GetImpl(net); + if (impl == nullptr) { + MS_LOG(ERROR) << "cannot find node implement"; + return nullptr; + } + lite_node = impl->net(); + break; + } + } + return lite_node; +} + +void Sequential::Add(NetBase *element) { + lite::Node *node = GetNode(element); + auto impl = NetImpl::GetImpl(this); + if (impl == nullptr) { + MS_LOG(ERROR) << "No implementation"; + return; + } + impl->net()->Add(node); +} + +NetWithLoss::NetWithLoss(Net *net, Node *loss) : net_(net), loss_fn_(loss) { + REG(net_); + Register(loss_fn_, "_loss_fn"); +} + +std::vector NetWithLoss::construct(const std::vector &inputs) { + if (inputs.size() != C2NUM) { + MS_LOG(ERROR) << "need 2 inputs for loss"; + return {}; + } + auto input = inputs[FIRST_INPUT]; + auto label = inputs[SECOND_INPUT]; + auto x = 
(*net_)({input}); + x = (*loss_fn_)({x[FIRST_INPUT], label}); + return x; +} + +NetImpl::~NetImpl() {} + +NetImpl::NetImpl(std::shared_ptr p) { pnet_ = p; } + +NetImpl::NetImpl(Graph *g) { pnet_ = g->net_data_->net(); } + +std::vector MS_API NetImpl::construct(const std::vector &inputs) { + auto in = Expr::convert(inputs); + auto out = pnet_->construct(in); + return Expr::convert(out); +} + +Net *MS_API NetImpl::Connect(std::shared_ptr net, lite::Net *lnet) { + auto impl = GetImpl(net.get()); + if (impl == nullptr) { + MS_LOG(ERROR) << "missing implementation"; + return nullptr; + } + impl->set_pnet(net); + lnet->set_impl(impl); + impl->set_net(lnet); + return net.get(); +} + +Status NetImpl::Import(const char *model_buf, Graph *graph) { + auto mg = schema::GetMetaGraph(model_buf); + auto net = new (std::nothrow) Net(); + if (net == nullptr) { + MS_LOG(ERROR) << "Cannot allocate network"; + return kLiteMemoryFailed; + } + lite::Import import; + auto lite_net = import.ImportMs(mg); + if (lite_net == nullptr) { + MS_LOG(ERROR) << "failed to import net"; + return kLiteMemoryFailed; + } + lite_net->SetRealOutput(); + Connect(net->shared_from_this(), lite_net); + *graph = Graph(net); + return kSuccess; +} + +Status NetImpl::TrainNet(Node *optimizer, const std::vector &inputs) { + auto impl = NodeImpl::GetImpl(optimizer); + if (impl == nullptr) { + MS_LOG(ERROR) << "missing implementation "; + return kLiteNullptr; + } + auto opt = impl->node(); + auto in = Expr::convert(inputs); + auto ret_net = net()->TrainNet(opt, in); + if (ret_net == nullptr) { + MS_LOG(ERROR) << "failed to train network"; + return kLiteNullptr; + } + return kSuccess; +} + +std::unique_ptr NetImpl::MakeMs() { + auto mgraph = std::make_unique(Graph::Type::kExecutableGraph); + if (mgraph == nullptr) { + MS_LOG(ERROR) << "Cannot allocate graph"; + return nullptr; + } + auto trained_graph = net()->MakeMs(); + if (trained_graph == nullptr) { + MS_LOG(ERROR) << "cannot create flat buffer"; + return 
nullptr; + } + flatbuffers::FlatBufferBuilder builder(kFlatbuffersBuilderInitSize); + auto offset = schema::MetaGraph::Pack(builder, trained_graph.get()); + builder.Finish(offset); + schema::FinishMetaGraphBuffer(builder, offset); + auto buffer = builder.GetBufferPointer(); + size_t size = builder.GetSize(); + auto status = Serialization::Load(buffer, size, mindspore::kMindIR, mgraph.get()); + if (status != kSuccess) { + MS_LOG(ERROR) << "failed to load flatbuffer to graph"; + return nullptr; + } + return mgraph; +} + +const std::vector NetImpl::InputShape(int idx) { return net_->InputShape(idx); } + +const std::vector NetImpl::OutputShape(int idx) { return net_->OutputShape(idx); } + +void NetImpl::ReplaceNet(Graph *g, std::shared_ptr n) { g->net_data_->net().swap(n); } + +ExpressionLoader expression_registrator = CreateExpressionLoader(NetImpl::Import); + +namespace NN { +Net *Sequential() { + auto net = new (std::nothrow) mindspore::Sequential(); + if (net == nullptr) { + MS_LOG(ERROR) << "Cannot allocate "; + return nullptr; + } + auto netl = lite::NN::Sequential(); + return NetImpl::Connect(net->shared_from_this(), netl); +} + +Net *NetWithLoss(Net *net, Node *loss) { + auto loss_net = new (std::nothrow) mindspore::NetWithLoss(net, loss); + if (net == nullptr) { + MS_LOG(ERROR) << "Cannot allocate loss net"; + return nullptr; + } + return loss_net; +} + +Graph *GraphWithLoss(Graph *graph, Node *loss) { + auto net = NetImpl::GetNet(*graph); + if (net == nullptr) { + MS_LOG(ERROR) << "Cannot allocate network"; + return nullptr; + } + auto loss_net = NetWithLoss(net.get(), loss); + if (loss_net == nullptr) { + MS_LOG(ERROR) << "Cannot allocate network"; + return nullptr; + } + NetImpl::ReplaceNet(graph, loss_net->shared_from_this()); + return graph; +} + +Net *NetWithLoss(Graph *g, Node *loss) { + auto net = new (std::nothrow) Net(*g); + if (net == nullptr) { + MS_LOG(ERROR) << "Cannot allocate net"; + return nullptr; + } + return NetWithLoss(net, loss); +} +} // 
namespace NN +} // namespace mindspore diff --git a/mindspore/lite/src/cxx_api/expression/net_impl.h b/mindspore/lite/src/cxx_api/expression/net_impl.h new file mode 100644 index 00000000000..b33b3421c56 --- /dev/null +++ b/mindspore/lite/src/cxx_api/expression/net_impl.h @@ -0,0 +1,95 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_CXX_API_EXPRESSION_NET_IMPL_H_ +#define MINDSPORE_LITE_SRC_CXX_API_EXPRESSION_NET_IMPL_H_ + +#include +#include +#include +#include +#include +#include "include/api/cfg.h" +#include "include/api/data_type.h" +#include "include/api/graph.h" +#include "include/api/status.h" +#include "include/api/net.h" +#include "src/cxx_api/expression/node_impl.h" +#include "src/cxx_api/graph/net_data.h" +#include "src/expression/net.h" +#include "src/expression/ops.h" + +namespace mindspore { +constexpr uint32_t kNodeType = 1; +constexpr uint32_t kNetType = 2; +class Sequential : public Net { + public: + Sequential(); + void Add(NetBase *n) override; + + private: + std::vector ops_; + lite::Node *GetNode(NetBase *element); +}; + +class NetWithLoss : public Net { + public: + NetWithLoss(Net *net, Node *loss); + std::vector construct(const std::vector &inputs) override; + + private: + Net *net_{nullptr}; + Node *loss_fn_{nullptr}; +}; + +class MS_API NetImpl { + public: + virtual ~NetImpl(); + explicit NetImpl(std::shared_ptr p); + explicit NetImpl(Graph 
*g); + NetImpl() = default; + void set_net(lite::Net *net) { + if (net_ != nullptr) { + net_->set_impl(nullptr); + delete net_; + } + net_ = net; + } + void erase_net() { net_ = nullptr; } + void set_pnet(std::shared_ptr net) { pnet_ = net; } + Net *pnet() { return pnet_.get(); } + lite::Net *net() { return net_; } + + std::vector construct(const std::vector &inputs); + static std::shared_ptr &GetImpl(Net *net) { return net->impl_; } + static Net *Connect(std::shared_ptr net, lite::Net *lnet); + static std::shared_ptr &GetNet(const Graph &g) { return g.net_data_->net(); } + static void SetNet(Graph *g, std::shared_ptr n) { g->net_data_->set_net(n); } + static void ReplaceNet(Graph *g, std::shared_ptr n); + static Status Import(const char *model_buf, Graph *graph); + Status TrainNet(Node *optimizer, const std::vector &inputs); + const std::vector InputShape(int idx); + const std::vector OutputShape(int idx); + std::unique_ptr MakeMs(); + void Release() { pnet_.reset(); } + + private: + std::shared_ptr pnet_; + lite::Net *net_ = nullptr; +}; +} // namespace mindspore + +#endif // MINDSPORE_LITE_SRC_CXX_API_EXPRESSION_NET_IMPL_H_ diff --git a/mindspore/lite/src/cxx_api/expression/node_impl.cc b/mindspore/lite/src/cxx_api/expression/node_impl.cc new file mode 100644 index 00000000000..044f7a5a4aa --- /dev/null +++ b/mindspore/lite/src/cxx_api/expression/node_impl.cc @@ -0,0 +1,50 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/cxx_api/expression/node_impl.h" +#include +#include "include/api/net.h" +#include "src/expression/ops.h" + +namespace mindspore { +Node *NodeImpl::Connect(lite::Node *lnode) { + if (lnode == nullptr) { // validate the argument before allocating anything + MS_LOG(ERROR) << "lite node is null"; + return nullptr; + } + auto node = std::make_unique(); + if (node == nullptr) { + MS_LOG(ERROR) << "Cannot allocate node"; + return nullptr; + } + auto pnode = node.get(); // ownership stays with `node` until the wiring below succeeds + auto impl = GetImpl(pnode); + if (impl == nullptr) { + MS_LOG(ERROR) << "missing implementation"; + return nullptr; // `node` is destroyed here instead of being leaked + } + impl->set_node(lnode); + lnode->set_impl(impl); + return node.release(); // caller takes ownership of the connected node +} +namespace NN { +std::unique_ptr Input(std::vector dims, DataType data_type, int fmt) { + auto type = static_cast(data_type); + auto lite_node = lite::NN::Input(dims, type, fmt); + return std::unique_ptr(NodeImpl::Connect(lite_node)); +} +} // namespace NN +} // namespace mindspore diff --git a/mindspore/lite/src/cxx_api/expression/node_impl.h b/mindspore/lite/src/cxx_api/expression/node_impl.h new file mode 100644 index 00000000000..e4dc8d6c06c --- /dev/null +++ b/mindspore/lite/src/cxx_api/expression/node_impl.h @@ -0,0 +1,71 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#ifndef MINDSPORE_LITE_SRC_CXX_API_EXPRESSION_NODE_IMPL_H_ +#define MINDSPORE_LITE_SRC_CXX_API_EXPRESSION_NODE_IMPL_H_ + +#include +#include +#include +#include +#include "include/api/net.h" +#include "include/api/cfg.h" +#include "include/api/data_type.h" +#include "src/expression/node.h" +#include "src/expression/expr.h" + +namespace mindspore { +using lite::EXPR; +class NodeSet { + public: + std::set set_; +}; + +class Expr : public EXPR { + public: + static std::vector convert(const std::vector &input) { + std::vector vec(input.size()); + (void)std::transform(input.begin(), input.end(), vec.begin(), [](Expr *e) { return reinterpret_cast(e); }); + return vec; + } + static std::vector convert(const std::vector &input) { + std::vector vec(input.size()); + (void)std::transform(input.begin(), input.end(), vec.begin(), [](EXPR *e) { return reinterpret_cast(e); }); + return vec; + } +}; + +class MS_API NodeImpl { + public: + std::vector operator()(const std::vector &inputs) { + auto in = Expr::convert(inputs); + auto out = (*node_)(in); + return Expr::convert(out); + } + lite::Node *node() { return node_; } + void set_node(lite::Node *node) { node_ = node; } + void set_pnode(Node *node) { pnode_ = node; } + Node *pnode() { return pnode_; } + static Node *Connect(lite::Node *lnode); + static std::shared_ptr &GetImpl(Node *node) { return node->impl_; } + + private: + Node *pnode_{nullptr}; + lite::Node *node_{nullptr}; +}; +} // namespace mindspore + +#endif // MINDSPORE_LITE_SRC_CXX_API_EXPRESSION_NODE_IMPL_H_ diff --git a/mindspore/lite/src/cxx_api/graph/graph.cc b/mindspore/lite/src/cxx_api/graph/graph.cc index 2ed1c874a62..b1f3a3ddc47 100644 --- a/mindspore/lite/src/cxx_api/graph/graph.cc +++ b/mindspore/lite/src/cxx_api/graph/graph.cc @@ -16,7 +16,9 @@ #include "include/api/graph.h" #include "include/api/cell.h" +#include "include/api/net.h" #include "src/cxx_api/graph/graph_data.h" +#include "src/cxx_api/graph/net_data.h" namespace mindspore { 
Graph::Graph() : graph_data_(nullptr) {} @@ -25,8 +27,15 @@ Graph::Graph(const std::shared_ptr &graph_data) : graph_data_(graph_d Graph::Graph(std::shared_ptr &&graph_data) : graph_data_(graph_data) {} +Graph::Graph(Graph::Type type) : graph_type_(type) {} + Graph::~Graph() {} +Graph::Graph(Net *net) : graph_type_(kExpressionGraph) { + auto shared = std::make_shared(net->shared_from_this()); + net_data_ = shared; +} + Graph::Graph(std::nullptr_t) : graph_data_(nullptr) {} bool Graph::operator==(std::nullptr_t) const { return graph_data_ == nullptr; } diff --git a/mindspore/lite/src/cxx_api/graph/net_data.cc b/mindspore/lite/src/cxx_api/graph/net_data.cc new file mode 100644 index 00000000000..8b4fd026267 --- /dev/null +++ b/mindspore/lite/src/cxx_api/graph/net_data.cc @@ -0,0 +1,21 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "src/cxx_api/graph/net_data.h" +#include "src/cxx_api/expression/net_impl.h" + +namespace mindspore { +NetData::~NetData() { if (net_ != nullptr && net_->impl_ != nullptr) net_->impl_->Release(); } // net_ can be empty after set_net()/swap +} // namespace mindspore diff --git a/mindspore/lite/src/cxx_api/graph/net_data.h b/mindspore/lite/src/cxx_api/graph/net_data.h new file mode 100644 index 00000000000..b5ad1e7ea79 --- /dev/null +++ b/mindspore/lite/src/cxx_api/graph/net_data.h @@ -0,0 +1,35 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#ifndef MINDSPORE_LITE_SRC_CXX_API_GRAPH_NET_DATA_H_ +#define MINDSPORE_LITE_SRC_CXX_API_GRAPH_NET_DATA_H_ + +#include +#include "include/api/net.h" + +namespace mindspore { +class NetData { + public: + explicit NetData(const std::shared_ptr &net) : net_(net) {} + virtual ~NetData(); + void set_net(std::shared_ptr net) { net_ = net; } + std::shared_ptr &net() { return net_; } + + private: + std::shared_ptr net_; +}; +} // namespace mindspore +#endif // MINDSPORE_LITE_SRC_CXX_API_GRAPH_NET_DATA_H_ diff --git a/mindspore/lite/src/cxx_api/model/model.cc b/mindspore/lite/src/cxx_api/model/model.cc index 0e42c5f32ba..e3578c23a93 100644 --- a/mindspore/lite/src/cxx_api/model/model.cc +++ b/mindspore/lite/src/cxx_api/model/model.cc @@ -22,11 +22,15 @@ #ifdef ENABLE_LITE_ACL #include "acl/acl_base.h" #endif +#include "flatbuffers/flatbuffers.h" #include "include/api/callback/callback.h" #include "include/api/context.h" #include "include/api/dual_abi_helper.h" #include "include/api/types.h" +#include "include/api/serialization.h" +#include "include/api/graph.h" #include "src/common/log_adapter.h" +#include "src/cxx_api/expression/net_impl.h" #include "src/cxx_api/callback/callback_adapter.h" #include "src/cxx_api/callback/callback_impl.h" #include "src/cxx_api/model/model_impl.h" diff --git a/mindspore/lite/src/cxx_api/model/model_impl.cc b/mindspore/lite/src/cxx_api/model/model_impl.cc index bf01d16ef41..0e803ad0f76 100644 --- a/mindspore/lite/src/cxx_api/model/model_impl.cc +++ b/mindspore/lite/src/cxx_api/model/model_impl.cc @@ -56,6 +56,14 @@ CreateTrainSessionProto *CreateTrainSessionCallbackHolder(CreateTrainSessionProt return proto_; } +ExpressionLoader CreateExpressionLoader(ExpressionLoader loader) { + static ExpressionLoader loader_ = nullptr; + if (loader != nullptr) { + loader_ = loader; + } + return loader_; +} + Status ModelImpl::Build(const void *model_data, size_t data_size, ModelType model_type, const std::shared_ptr &ms_context) { if (model_data == 
nullptr) { diff --git a/mindspore/lite/src/cxx_api/model/model_impl.h b/mindspore/lite/src/cxx_api/model/model_impl.h index abd5e6194d3..e58393e6f53 100644 --- a/mindspore/lite/src/cxx_api/model/model_impl.h +++ b/mindspore/lite/src/cxx_api/model/model_impl.h @@ -49,6 +49,9 @@ typedef std::shared_ptr(CreateTrainSessionProto)(std::shared_ lite::InnerContext *context); CreateTrainSessionProto *CreateTrainSessionCallbackHolder(CreateTrainSessionProto *proto = nullptr); +using ExpressionLoader = std::function; +ExpressionLoader CreateExpressionLoader(ExpressionLoader loader = nullptr); + namespace session { class Metrics; class TrainLoopCallBack; @@ -90,6 +93,7 @@ class ModelImpl { static bool CheckModelSupport(const std::string &device_type, ModelType model_type); bool IsTrainModel(); + std::unique_ptr BuildTrain(Node *optimizer, std::vector inputs); Status SetupVirtualBatch(int virtual_batch_multiplier, float lr, float momentum); Status SetLearningRate(float learning_rate); float GetLearningRate(); diff --git a/mindspore/lite/src/cxx_api/serialization.cc b/mindspore/lite/src/cxx_api/serialization.cc index ecd92a6bdf3..3f2c4820c75 100644 --- a/mindspore/lite/src/cxx_api/serialization.cc +++ b/mindspore/lite/src/cxx_api/serialization.cc @@ -20,6 +20,7 @@ #include "include/api/graph.h" #include "include/api/types.h" #include "include/model.h" +#include "src/cxx_api/expression/net_impl.h" #include "src/cxx_api/graph/graph_data.h" #include "src/cxx_api/model/model_impl.h" #include "src/cxx_api/converters.h" @@ -27,6 +28,8 @@ #include "src/lite_session.h" namespace mindspore { +std::function ExpressionCallback; + Key::Key(const char *dec_key, size_t key_len) { len = 0; if (key_len >= max_key_len) { @@ -115,19 +118,31 @@ Status Serialization::Load(const std::vector &file, ModelType model_type, MS_LOG(ERROR) << "Read model file failed"; return kLiteNullptr; } - auto model = - std::shared_ptr(lite::ImportFromBuffer(static_cast(model_buf), model_size, true)); - if (model == 
nullptr) { - MS_LOG(ERROR) << "New model failed."; - return kLiteNullptr; + if (graph->IsExecutable()) { + auto model = + std::shared_ptr(lite::ImportFromBuffer(static_cast(model_buf), model_size, true)); + if (model == nullptr) { + MS_LOG(ERROR) << "New model failed."; + return kLiteNullptr; + } + auto graph_data = std::shared_ptr(new (std::nothrow) Graph::GraphData(model)); + if (graph_data == nullptr) { + MS_LOG(ERROR) << "New graph data failed."; + return kLiteMemoryFailed; + } + *graph = Graph(graph_data); + return kSuccess; + } else { + auto loader = CreateExpressionLoader(); + if (loader == nullptr) { + MS_LOG(ERROR) << "Unsupported Feature."; + delete[] model_buf; + return kLiteError; + } + auto status = loader(model_buf, graph); // keep the loader's result instead of discarding it + delete[] model_buf; + return status; // report import failures to the caller } - auto graph_data = std::shared_ptr(new (std::nothrow) Graph::GraphData(model)); - if (graph_data == nullptr) { - MS_LOG(ERROR) << "New graph data failed."; - return kLiteMemoryFailed; - } - *graph = Graph(graph_data); - return kSuccess; } Status Serialization::Load(const std::vector> &files, ModelType model_type, diff --git a/mindspore/lite/src/cxx_api/train/model.cc b/mindspore/lite/src/cxx_api/train/model.cc index 80d840f86ca..ecb5e6fe530 100644 --- a/mindspore/lite/src/cxx_api/train/model.cc +++ b/mindspore/lite/src/cxx_api/train/model.cc @@ -17,6 +17,7 @@ #include "include/api/model.h" #include "include/api/types.h" #include "include/api/context.h" +#include "include/api/net.h" #include "include/api/callback/callback.h" #include "include/api/dual_abi_helper.h" #include "src/cxx_api/model/model_impl.h" @@ -106,11 +107,39 @@ Status Model::Evaluate(std::shared_ptr ds, std::vector inputs, + const std::shared_ptr &model_context, const std::shared_ptr &train_cfg) { + std::stringstream err_msg; + if (impl_ == nullptr) { + impl_ = std::make_shared(); + if (impl_ == nullptr) { + MS_LOG(ERROR) << "Model implement is null."; + return kLiteFileError; + } + } + auto lossGraph = lossGraphCell.GetGraph(); + if
(lossGraph == nullptr) { + err_msg << "Invalid null graph"; + MS_LOG(ERROR) << err_msg.str(); + return Status(kLiteNullptr, err_msg.str()); + } + impl_->SetContext(model_context); + impl_->SetConfig(train_cfg); + impl_->SetGraph(lossGraph); + auto graph = impl_->BuildTrain(optimizer, inputs); + if (graph == nullptr) return Status(kLiteNullptr, std::string("failed to build train graph")); // BuildTrain already logged why + auto status = Build(GraphCell(*graph), model_context, train_cfg); + if (status != mindspore::kSuccess) { + MS_LOG(ERROR) << "Error " << status << " during model build"; + } + return status; // kSuccess on the happy path, the build error otherwise +} + Status Model::BuildTransferLearning(GraphCell backbone, GraphCell head, const std::shared_ptr &context, const std::shared_ptr &train_cfg) { std::stringstream err_msg; if (impl_ == nullptr) { - impl_ = std::shared_ptr(new (std::nothrow) ModelImpl()); + impl_ = std::make_shared(); if (impl_ == nullptr) { MS_LOG(ERROR) << "Model implement is null."; return kLiteFileError; @@ -137,5 +166,4 @@ Status Model::BuildTransferLearning(GraphCell backbone, GraphCell head, const st } return kSuccess; } - } // namespace mindspore diff --git a/mindspore/lite/src/cxx_api/train/model_impl.cc b/mindspore/lite/src/cxx_api/train/model_impl.cc index 505bbcb315d..bc43b009cda 100644 --- a/mindspore/lite/src/cxx_api/train/model_impl.cc +++ b/mindspore/lite/src/cxx_api/train/model_impl.cc @@ -18,15 +18,18 @@ #include #include #include +#include "src/cxx_api/expression/node_impl.h" #include "include/api/types.h" #include "include/api/context.h" #include "include/api/dual_abi_helper.h" +#include "include/api/serialization.h" #include "include/lite_session.h" #include "include/context.h" #include "include/api/callback/callback.h" #include "include/api/metrics/metrics.h" #include "src/lite_model.h" #include "src/runtime/inner_allocator.h" +#include "src/cxx_api/expression/net_impl.h" #include "src/cxx_api/converters.h" #include "src/cxx_api/graph/graph_data.h" #include "src/cxx_api/tensor/tensor_impl.h" @@ -79,6 +82,32 @@ Status ModelImpl::BuildTransferLearning(const std::shared_ptr
&backbone, return kLiteError; } +std::unique_ptr ModelImpl::BuildTrain(Node *optimizer, std::vector inputs) { + auto opt_impl = NodeImpl::GetImpl(optimizer); + if (opt_impl == nullptr) { + MS_LOG(ERROR) << "missing optimizer node implementation"; + return nullptr; + } + auto opt = opt_impl->node(); + auto in = Expr::convert(inputs); + auto net_impl = NetImpl::GetImpl(graph_->net_data_->net().get()); + if (net_impl == nullptr) { + MS_LOG(ERROR) << "missing net implementation"; + return nullptr; + } + auto trained_net = net_impl->net()->TrainNet(opt, in); + if (trained_net == nullptr) { + MS_LOG(ERROR) << "failed to train network"; + return nullptr; + } + auto mgraph = net_impl->MakeMs(); + if (mgraph == nullptr) { + MS_LOG(ERROR) << "failed to create graph"; + return nullptr; + } + return mgraph; +} + Status ModelImpl::PrepareMetrics(Model *model, std::vector *out_ms, std::vector *adapter_ms) { if (out_ms == nullptr || adapter_ms == nullptr) { diff --git a/mindspore/lite/src/expression/cfg.h b/mindspore/lite/src/expression/cfg.h new file mode 100644 index 00000000000..e590d2b7ddd --- /dev/null +++ b/mindspore/lite/src/expression/cfg.h @@ -0,0 +1,68 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_LITE_SRC_EXPRESSION_CFG_H_ +#define MINDSPORE_LITE_SRC_EXPRESSION_CFG_H_ + +#include +#include + +namespace mindspore { +namespace lite { +class ConvConfig { + public: + ConvConfig() = default; + int in_channel_ = 3; /**< The channel number of the input of the Conv2d layer */ + int out_channel_ = 3; /**< The channel number of the output tensor of the Conv2d layer */ + std::vector kernel_size_ = {3, 3}; /**< Specifies the height and width of the 2D convolution kernel. */ + std::vector stride_ = {1, 1}; /**< The movement stride of the 2D convolution kernel */ + std::vector padding_ = {0, 0, 0, 0}; /**< The top, bottom, left, and right padding input */ + std::vector dilation_ = {1, 1}; /**< Dilation height and width */ + int group_ = 1; /**< Splits filter into groups, `in_channels` and `out_channels` must be + * divisible by `group`. If the group is equal to `in_channels` and `out_channels`, + * this 2D convolution layer also can be called 2D depthwise convolution layer */ + bool has_bias = false; /**< Whether the Conv2d layer has a bias parameter */ + std::string weight_init_ = + "normal"; /**< Initialization method of weight parameter ("normal","uniform", "ones", "zeros") */ + std::string pad_mode_ = "same"; /**< Specifies padding mode. The optional values are "same", "valid", "pad" */ + + private: + std::string bias_init_ = "zeros"; + std::string data_format; +}; + +class DenseConfig { + public: + int in_channels_ = 0; /**< The number of channels in the input space */ + int out_channels_ = 0; /**< The number of channels in the output space */ + bool has_bias_ = false; /**< Specifies whether the layer uses a bias vector */ + private: + std::string weight_init_ = "normal"; + std::string bias_init_ = "zeros"; + std::string activation_ = "none"; +}; + +class PoolingConfig { + public: + PoolingConfig() = default; + std::vector kernel_size_ = {1, 1}; /**< Specifies the height and width of the 2D kernel.
*/ + std::vector stride_ = {1, 1}; /**< The movement stride of the 2D kernel */ + std::string pad_mode_ = "same"; /**< Specifies padding mode. The optional values are "same", "valid" */ +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_SRC_EXPRESSION_CFG_H_ diff --git a/mindspore/lite/src/expression/export.cc b/mindspore/lite/src/expression/export.cc new file mode 100644 index 00000000000..a86c54a6e9b --- /dev/null +++ b/mindspore/lite/src/expression/export.cc @@ -0,0 +1,76 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include "src/expression/export.h" +#include "src/expression/ops.h" +#include "src/common/utils.h" +#include "nnacl/conv_parameter.h" +#include "include/errorcode.h" + +namespace mindspore { +namespace lite { +constexpr static int kFmkVal = 3; + +int ExportSession::Init(const std::string model_name, std::string version) { + meta_graph_ = new (std::nothrow) mindspore::schema::MetaGraphT(); + if (meta_graph_ == nullptr) { + MS_LOG(ERROR) << "cannot allocate meta_graph"; + return RET_ERROR; + } + meta_graph_->fmkType = kFmkVal; + meta_graph_->name = model_name; + meta_graph_->version = version; + return RET_OK; +} + +bool ExportSession::IsToDependOnly(EXPR *expr) { + auto itr = outmap_.find(expr); + if (itr != outmap_.end() && !itr->second.empty()) { + for (auto expr : itr->second) { + auto node = expr->node(); + if (node->primitive() != schema::PrimitiveType_Depend) return false; + } + return true; + } + return false; +} + +int ExportSession::SetInputOutput(const std::vector &inputs, const std::vector &outputs) { + for (auto &in : inputs) { + auto id = GetOutput(in); + meta_graph_->inputIndex.push_back(id); + } + for (auto &out : outputs) { + auto id = GetOutput(out); + meta_graph_->outputIndex.push_back(id); + } + auto sub_graph = std::make_unique(); + if (sub_graph == nullptr) { + MS_LOG(ERROR) << "cannot allocate SubGraphT"; + return RET_ERROR; + } + auto model_name = meta_graph_->name; + sub_graph->name = model_name + "_subgraph"; + sub_graph->inputIndices = meta_graph_->inputIndex; + sub_graph->outputIndices = meta_graph_->outputIndex; + for (size_t i = 0; i < meta_graph_->nodes.size(); i++) sub_graph->nodeIndices.push_back(i); + for (size_t i = 0; i < meta_graph_->allTensors.size(); i++) sub_graph->tensorIndices.push_back(i); + meta_graph_->subGraph.emplace_back(std::move(sub_graph)); + return RET_OK; +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/expression/export.h b/mindspore/lite/src/expression/export.h 
new file mode 100644 index 00000000000..8009e080ecc --- /dev/null +++ b/mindspore/lite/src/expression/export.h @@ -0,0 +1,52 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_EXPRESSION_EXPORT_H_ +#define MINDSPORE_LITE_SRC_EXPRESSION_EXPORT_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "src/expression/expr.h" + +namespace mindspore { +namespace schema { +struct MetaGraphT; +} +namespace lite { +class ExportSession { + public: + explicit ExportSession(std::map> &outmap) : outmap_(outmap) {} + int Init(const std::string model_name, std::string version); + void UpdateOutput(EXPR *expr, int id) { output_tensors_[expr] = id; } + int GetOutput(EXPR *expr) { return output_tensors_.at(expr); } + schema::MetaGraphT *&meta_graph() { return meta_graph_; } + int SetInputOutput(const std::vector &inputs, const std::vector &outputs); + bool IsToDependOnly(EXPR *expr); + + private: + schema::MetaGraphT *meta_graph_{nullptr}; + std::unordered_map output_tensors_; // output tensors per EXPR + std::map> &outmap_; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_SRC_EXPRESSION_EXPORT_H_ diff --git a/mindspore/lite/src/expression/expr.cc b/mindspore/lite/src/expression/expr.cc new file mode 100644 index 00000000000..b27853d2692 --- /dev/null +++ b/mindspore/lite/src/expression/expr.cc @@ -0,0 +1,98 @@ +/** + * Copyright 
2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "src/expression/expr.h" +#include "src/expression/node.h" + +namespace mindspore { +namespace lite { +std::string EXPR::name() { return node_->name(); } +void EXPR::Travers(std::function cb) { /* walks params_ once per EXPR; cb(this, param) decides whether to descend */ + if (!visited) { + visited = true; /* left set on return; a Clear() is needed before the next walk */ + for (auto &itr : params_) { + if (cb(this, itr)) { + itr->Travers(cb); + } + } + } +} + +void EXPR::Replace(EXPR **old, EXPR **n, std::vector *to_delete) { /* swaps every param equal to *old for *n below this EXPR */ + if (!visited) { + visited = true; + for (auto &itr : params_) + if (itr == *old) { + to_delete->push_back(itr->node()); /* NOTE(review): queued once per matching reference - confirm a shared input cannot be deleted twice */ + itr = *n; + } + for (auto &itr : params_) itr->Replace(old, n, to_delete); + } +} + +void EXPR::Replace(const std::vector &vec, std::vector *old, std::vector *n) { /* pairs old->at(i) with n->at(i); at() throws if n is shorter */ + std::vector to_delete; + for (auto &e : vec) { + for (std::size_t i = 0; i < old->size(); i++) { + e->Replace(&old->at(i), &n->at(i), &to_delete); /* NOTE(review): e stays visited after i == 0, so later pairs look like no-ops - confirm */ + } + } + for (auto &itr : to_delete) delete itr; + for (auto e : vec) e->Clear(); /* visited flags are reset only here, after all replacements */ +} + +void EXPR::Clear() { /* resets visited flags: follows single-param chains in a loop, recurses where an EXPR has other param counts */ + EXPR *item = this; + if (visited == false) return; + visited = false; + while (item->params_.size() == 1) { + item = item->params_.front(); + if (item->visited == false) return; + item->visited = false; + } + for (auto &itr : item->params_) itr->Clear(); +} + +void EXPR::Clear(std::vector vec) { + for (auto e : vec) e->Clear(); +} + +void EXPR::CreateOutputMap(std::vector vec, std::map> *outmap) { + for (auto e : vec) { +
e->Travers([&](EXPR *e, EXPR *itr) { + (*outmap)[itr].push_back(e); + return true; + }); + } + Clear(vec); +} + +void EXPR::PrintDot(std::vector vec) { + std::cout << "digraph \"expr\" { " << std::endl; + for (auto e : vec) { + e->Travers([](EXPR *e, EXPR *itr) { + std::cout << "\"" << itr->node_->name() << "\"->" + << "\"" << e->node_->name() << "\"" << std::endl; + return true; + }); + } + std::cout << "}" << std::endl; + Clear(vec); +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/expression/expr.h b/mindspore/lite/src/expression/expr.h new file mode 100644 index 00000000000..97940d5fc10 --- /dev/null +++ b/mindspore/lite/src/expression/expr.h @@ -0,0 +1,70 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_LITE_SRC_EXPRESSION_EXPR_H_ +#define MINDSPORE_LITE_SRC_EXPRESSION_EXPR_H_ + +#include +#include +#include +#include +#include +#include +#include "include/ms_tensor.h" +#include "include/api/format.h" + +namespace mindspore { +namespace lite { +class Node; + +class EXPR { + public: + explicit EXPR(Node *node) : node_(node) { SetSize(1); } + static void PrintDot(std::vector vec); + static void Replace(const std::vector &vec, std::vector *old, std::vector *n); + static void CreateOutputMap(std::vector vec, std::map> *outmap); + static void Clear(std::vector vec); + void Travers(std::function cb); + std::string name(); + EXPR *GetInput(int idx) { return params_.at(idx); } + void set_node(Node *node) { node_ = node; } + Node *node() { return node_; } + bool visited = false; + void set_params(std::vector params) { params_ = params; } + void set_params(int idx, EXPR *expr) { params_[idx] = expr; } + void add_params(EXPR *e) { params_.push_back(e); } + std::vector ¶ms() { return params_; } + EXPR *params(int i) { return params_[i]; } + void SetSize(int n) { params_.resize(n); } + void SetDims(std::vector dims) { dims_ = dims; } + std::vector &dims() { return dims_; } + void set_format(int fmt) { format_ = fmt; } + int format() { return format_; } + void set_data_type(TypeId data_type) { data_type_ = data_type; } + TypeId data_type() { return data_type_; } + + private: + void Replace(EXPR **old, EXPR **n, std::vector *to_delete); + std::vector params_; + Node *node_{nullptr}; + void Clear(); + std::vector dims_; + int format_ = NHWC; + TypeId data_type_ = kNumberTypeFloat32; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_SRC_EXPRESSION_EXPR_H_ diff --git a/mindspore/lite/src/expression/import.cc b/mindspore/lite/src/expression/import.cc new file mode 100644 index 00000000000..8b2bbf9c8c5 --- /dev/null +++ b/mindspore/lite/src/expression/import.cc @@ -0,0 +1,180 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "src/expression/import.h" +#include "ops/populate/populate_register.h" +#include "src/expression/ops.h" +#include "src/expression/ops/activation.h" +#include "src/expression/ops/batchnorm.h" +#include "src/expression/ops/biasadd.h" +#include "src/expression/ops/conv.h" +#include "src/expression/ops/dense.h" +#include "src/expression/ops/pooling.h" +#include "src/expression/ops/reshape.h" +#include "src/expression/ops/transpose.h" + +namespace mindspore { +namespace lite { +std::unordered_map ImportReg::import_map_; + +import_func ImportReg::GetImportFunc(mindspore::schema::PrimitiveType type) { + auto f = import_map_.find(type); + if (f == import_map_.end()) { + return nullptr; + } + return f->second; +} + +OpParameter *Import::GetAttr(const schema::Primitive *prim) { + auto parameter_gen = PopulateRegistry::GetInstance()->GetParameterCreator(prim->value_type(), SCHEMA_CUR); + if (parameter_gen == nullptr) { + MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " << schema::EnumNamePrimitiveType(prim->value_type()); + return nullptr; + } + auto parameter = parameter_gen(prim); + if (parameter == nullptr) { + MS_LOG(ERROR) << "parameter is nullptr."; + return nullptr; + } + return parameter; +} + +std::unique_ptr Import::CreateNode(const schema::CNode *cnode) { + auto param = GetAttr(cnode->primitive()); + auto type = cnode->primitive()->value_type(); + auto fn = 
ImportReg::GetImportFunc(type); + if (fn == nullptr) { + MS_LOG(ERROR) << "Cannot find importer for " << schema::EnumNamePrimitiveType(type); + return nullptr; + } + auto node = fn(); + if (node == nullptr) { + MS_LOG(ERROR) << "Cannot allocate node" << cnode->name()->str(); + return nullptr; + } + node->SetOpParam(param); + node->set_name(cnode->name()->str()); + node->set_primitive(type); + return std::unique_ptr(node); +} + +Net *Import::ImportMs(std::string file_name) { + std::ifstream infile; + infile.open(file_name, std::ios::binary | std::ios::in); + if (!infile.good()) { + MS_LOG(ERROR) << "cannot read " << file_name << std::endl; + return nullptr; + } + infile.seekg(0, std::ios::end); + int length = infile.tellg(); + infile.seekg(0, std::ios::beg); + auto data_ptr = std::make_unique(length); + auto *data = data_ptr.get(); + infile.read(reinterpret_cast(data), length); + infile.close(); + flatbuffers::Verifier verifier = flatbuffers::Verifier(reinterpret_cast(data), length); + bool res = schema::VerifyMetaGraphBuffer(verifier); + if (res != true) { + MS_LOG(ERROR) << "fault file: " << file_name << "(" << length << ")\n"; + return nullptr; + } else { + MS_LOG(INFO) << "verify pass file: " << file_name << "(" << length << ")\n"; + } + buffer_ = data_ptr.get(); + auto metaGraph = schema::GetMetaGraph(data_ptr.release()); + return ImportMs(metaGraph); +} + +Net *Import::ImportMs(const schema::MetaGraph *metaGraph) { + if (metaGraph == nullptr) { + MS_LOG(ERROR) << "null input"; + return nullptr; + } + std::string NetName = "Network"; + if (metaGraph->name() != nullptr) NetName = metaGraph->name()->str(); + auto net = std::make_unique(NetName); + std::unordered_map outputs; + // save inputs + for (size_t i = 0; i < metaGraph->inputIndex()->size(); i++) { + auto tensor_id = metaGraph->inputIndex()->Get(i); + const schema::Tensor *tensor = metaGraph->allTensors()->Get(tensor_id); + auto input = new (std::nothrow) InputM(tensor); + if (input == nullptr) { + 
MS_LOG(ERROR) << "Cannot allocate input"; + return nullptr; + } + auto e = input->expr(); + outputs[tensor_id] = e; + net->PushInput(e); + } + for (size_t i = 0; i < metaGraph->nodes()->size(); i++) { + auto Cnode = metaGraph->nodes()->Get(i); + std::vector param_tensors; + for (size_t j = 0; j < Cnode->inputIndex()->size(); j++) { + int tensor_id = Cnode->inputIndex()->Get(j); + const schema::Tensor *tensor = metaGraph->allTensors()->Get(tensor_id); + auto iter = outputs.find(tensor_id); + if (iter == outputs.end()) { + // create value node if not exist + if (tensor->nodeType() != NodeType::NodeType_CNode) { + auto valnode = new (std::nothrow) InputM(tensor); + if (valnode == nullptr) { + MS_LOG(ERROR) << "Cannot allocate valnode"; + return nullptr; + } + outputs[tensor_id] = valnode->expr(); + param_tensors.push_back(valnode->expr()); + net->PushOp(valnode); + } else { + MS_LOG(ERROR) << "did not found input tensor " << tensor_id; + return nullptr; + } + } else { + param_tensors.push_back(iter->second); + } + } + // create expression from node // + auto node = CreateNode(Cnode); + if (node != nullptr) { + node->SetOutputs(Cnode->outputIndex()->size()); + std::vector e = (*node)(param_tensors); + for (size_t j = 0; j < Cnode->outputIndex()->size(); j++) { + int tensor_id = Cnode->outputIndex()->Get(j); + outputs[tensor_id] = e.at(j); + } + } else { + MS_LOG(ERROR) << "failed to create node " << Cnode->name(); + return nullptr; + } + auto node_ptr = node.release(); + net->PushOp(node_ptr); + node_ptr->SetLearn(); + } + for (size_t i = 0; i < metaGraph->outputIndex()->size(); i++) { + auto tensor_id = metaGraph->outputIndex()->Get(i); + auto iter = outputs.find(tensor_id); + if (iter == outputs.end()) { + MS_LOG(ERROR) << "could not find source for tensor " << tensor_id; + return nullptr; + } else { + net->PushOutput(iter->second); + } + } + return net.release(); +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/expression/import.h 
b/mindspore/lite/src/expression/import.h new file mode 100644 index 00000000000..3d4f301ed08 --- /dev/null +++ b/mindspore/lite/src/expression/import.h @@ -0,0 +1,61 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_EXPRESSION_IMPORT_H_ +#define MINDSPORE_LITE_SRC_EXPRESSION_IMPORT_H_ + +#include +#include +#include +#include "nnacl/op_base.h" +#include "src/expression/net.h" + +namespace mindspore { +namespace lite { +using import_func = std::function; + +template +Node *ReturnNode() { + return new (std::nothrow) T(); +} + +class ImportReg { + public: + explicit ImportReg(mindspore::schema::PrimitiveType type, import_func func) { import_map_[type] = func; } + static import_func GetImportFunc(mindspore::schema::PrimitiveType type); + + private: + static std::unordered_map import_map_; +}; + +class Import { + private: + int8_t *buffer_ = nullptr; + OpParameter *GetAttr(const schema::Primitive *prim); + std::unique_ptr CreateNode(const schema::CNode *cnode); + + public: + Net *ImportMs(const schema::MetaGraph *meta_graph); + Net *ImportMs(std::string file_name); + ~Import() { + delete[] buffer_; + buffer_ = nullptr; + } +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_SRC_EXPRESSION_IMPORT_H_ diff --git a/mindspore/lite/src/expression/net.cc b/mindspore/lite/src/expression/net.cc new file mode 100644 index 00000000000..e54acddcb22 --- 
/dev/null +++ b/mindspore/lite/src/expression/net.cc @@ -0,0 +1,268 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/expression/net.h" +#include +#include "src/cxx_api/expression/net_impl.h" +#include "src/expression/ops.h" +#include "src/expression/export.h" +#include "src/expression/ops/addn.h" +#include "src/expression/ops/arithmetic.h" +#include "src/common/storage.h" +#include "tools/common/meta_graph_serializer.h" +namespace mindspore { +namespace lite { +void Net::update_name(std::string name) { + if (!this->name().empty()) + Node::update_name(name); + else + set_name(name); + for (auto &itr : ops_) { + itr->update_name(name); + } +} + +std::vector Net::operator()(const std::initializer_list &&inputs) { + std::vector vec = inputs; + std::vector x; + if (impl_ == nullptr) { + x = construct(inputs); + } else { + x = impl_->construct(vec); + } + return x; +} + +std::vector Net::operator()(const std::vector &inputs) { + std::vector x; + if (impl_ == nullptr) { + x = construct(inputs); + } else { + x = impl_->construct(inputs); + } + input_ = inputs; + output_ = x; + real_output_ = x; + return x; +} + +std::vector Net::construct(const std::vector &inputs) { + if (!output_.empty()) { + if (input_.size() != inputs.size()) { + MS_LOG(ERROR) << "input size mismatch, should be " << input_.size() << " got " << inputs.size(); + return {}; + } + auto in_ptr = inputs; + 
EXPR::Replace(output_, &input_, &in_ptr); + } else { + MS_LOG(ERROR) << "no network construction function"; + } + return output_; +} + +void Net::TopoSortUtil(Node *node, std::stack *stack) { + visited_.insert(node); + for (size_t i = 0; i < node->OutputsNum(); i++) { + auto expr = node->expr(i); + auto itr = outmap_.find(expr); + if (itr != outmap_.end()) { + for (auto &e : itr->second) + if (visited_.find(e->node()) == visited_.end()) { + TopoSortUtil(e->node(), stack); + } + } + } + stack->push(node); +} + +std::vector Net::Sort() { + std::stack stack; + outmap_.clear(); + EXPR::CreateOutputMap(output_, &outmap_); + for (auto &itr : outmap_) { + EXPR *e = itr.first; + if (visited_.find(e->node()) == visited_.end()) { + TopoSortUtil(e->node(), &stack); + } + } + std::vector res; + while (stack.empty() == false) { + res.push_back(stack.top()); + stack.pop(); + } + visited_.clear(); + return res; +} + +std::unique_ptr Net::MakeMs() { + auto nodes = Sort(); + auto s = new (std::nothrow) ExportSession(outmap_); + if (s == nullptr) { + MS_LOG(ERROR) << "Cannot allocate export session"; + return nullptr; + } + session_.reset(s); + session_->Init(name(), Version()); + for (auto node : nodes) { + auto res = node->MakeEntry(session_.get()); + if (res != RET_OK) { + MS_LOG(ERROR) << "failed in MakeEntry: " << node->name(); + return nullptr; + } + } + session_->SetInputOutput(input_, real_output_); + auto res = session_->meta_graph(); + return std::unique_ptr(res); +} + +std::unique_ptr Net::MakeMs(const std::string file_name) { + auto graph = MakeMs(); + Save(*graph, file_name); + return graph; +} + +std::set Net::trainable_params() { + std::set res; + for (auto &node : ops_) { + res.merge(node->trainable_params()); + } + return res; +} + +int Net::BuildGrad(Node *optimizer) { + std::set learn = optimizer->trainable_params(); + auto NetOrder = Sort(); + optimizer_.reset(optimizer); + optimizer->AddNetOutput(&output_); + std::map, EXPR *> backprop; + for (auto itr = 
NetOrder.rbegin(); itr != NetOrder.rend(); itr++) { + Node *node = *itr; + EXPR *yt = nullptr; + if (node->primitive() == schema::PrimitiveType_NONE) continue; + if (outmap_.find(node->expr()) == outmap_.end() || outmap_[node->expr()].size() == 0) { + yt = node->expr(); + } else { + std::vector add_params; + for (auto &output : outmap_[node->expr()]) { + auto link = std::make_pair(node->expr(), output); + auto grad = backprop[link]; + add_params.push_back(grad); + } + if (add_params.size() == 1) { + yt = add_params.front(); + } else { + auto addn = new (std::nothrow) AddN(0); + if (addn == nullptr) { + MS_LOG(ERROR) << "Cannot allocate add operator"; + return RET_ERROR; + } + PushOp(addn); + addn->update_name(name()); + yt = (*addn)(add_params).front(); + } + } + auto inGrads = node->Grad(yt); + for (size_t i = 0; i < node->inputs().size(); i++) { + EXPR *inGrad{nullptr}; + if (i < inGrads.size()) { + inGrad = inGrads[i]; + } else { + inGrad = nullptr; + } + auto input = node->input(i); + if (learn.find(input->node()) != learn.end()) { + auto opt = optimizer->Clone(inGrad, input); + if (opt.size() == 0) { + MS_LOG(ERROR) << "failed to create optimizer"; + return RET_ERROR; + } + if (inGrad == nullptr) { + MS_LOG(ERROR) << "illegal null value for grad"; + return RET_ERROR; + } + if (opt.size() == 0) { + MS_LOG(ERROR) << "optimizer for " << input->node()->name() << " failure"; + return RET_ERROR; + } + auto opt_op = opt.at(0)->node(); + PushOp(opt_op); + opt_op->update_name(node->name()); + output_.push_back(opt.at(0)); + } + auto link = std::make_pair(input, node->expr()); + backprop[link] = inGrad; + } + } + return RET_OK; +} + +std::vector Net::add(const std::vector &input) { + auto _add = NN::Add(); + _add->set_name(name() + "/" + _add->name()); + ops_.push_back(_add); + return (*_add)(input); +} + +Net *Net::TrainNet(Node *optimizer, Node *loss_fn, const std::vector &inputs) { + auto net = new (std::nothrow) NetWithLoss(this, loss_fn); + if (net == nullptr) { + 
MS_LOG(ERROR) << "Cannot allocate loss network"; + return nullptr; + } + return net->TrainNet(optimizer, inputs); +} + +Net *Net::TrainNet(Node *optimizer, const std::vector &inputs) { + auto x = (*this)(inputs); + auto res = BuildGrad(optimizer); + if (res != RET_OK) { + MS_LOG(ERROR) << "Build gradient network failed"; + return nullptr; + } + real_output_ = x; + return this; +} + +int Net::Save(const schema::MetaGraphT &graph, std::string file_name) { return Storage::Save(graph, file_name); } + +const std::vector Net::OutputShape(int idx) { + if (static_cast(idx) >= real_output_.size()) { + MS_LOG(ERROR) << "index (" << idx << ") exceed output size (" << real_output_.size() << ")"; + return {}; + } + return real_output_.at(idx)->dims(); +} + +const std::vector Net::InputShape(int idx) { + if (static_cast(idx) >= input_.size()) { + MS_LOG(ERROR) << "index (" << idx << ") exceed input size (" << input_.size() << ")"; + return {}; + } + return input_.at(idx)->dims(); +} + +Net::~Net() { + if (impl_ != nullptr) { + impl_->erase_net(); + auto pnet = impl_->pnet(); + if (pnet != nullptr) { + impl_->set_pnet(nullptr); + } + } + impl_ = nullptr; +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/expression/net.h b/mindspore/lite/src/expression/net.h new file mode 100644 index 00000000000..f5525173e3e --- /dev/null +++ b/mindspore/lite/src/expression/net.h @@ -0,0 +1,114 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_EXPRESSION_NET_H_ +#define MINDSPORE_LITE_SRC_EXPRESSION_NET_H_ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "src/expression/node.h" +#include "inner/model_generated.h" + +namespace mindspore { +class Net; +class NetImpl; +namespace lite { +#define REG(_name) Register(_name, #_name) + +class ExportSession; + +class Net : public Node { + public: + Net() = default; + virtual ~Net(); + explicit Net(std::string name) : Node(name) {} + std::vector construct(const std::vector &inputs) override; + std::vector operator()(const std::vector &inputs) override; + std::vector operator()(const std::initializer_list &&inputs) override; + void update_name(std::string name) override; + Net *TrainNet(Node *optimizer, Node *loss_fn, const std::vector &inputs); + Net *TrainNet(Node *optimizer, const std::vector &inputs); + void PrintDot() { EXPR::PrintDot(output_); } + + void PushOutput(EXPR *e) { output_.push_back(e); } + void PushInput(EXPR *e) { input_.push_back(e); } + void SetRealOutput() { real_output_ = output_; } + std::set trainable_params() override; + std::vector Sort(); + int BuildGrad(Node *optimizer); + int BuildGrad(Node *optimizer, std::set learnable); + std::unique_ptr MakeMs(); + std::unique_ptr MakeMs(std::string file_name); + schema::MetaGraph *meta_graph() { return meta_graph_; } + int Save(const schema::MetaGraphT &graph, const std::string filename); + void set_impl(std::shared_ptr impl) { impl_ = impl; } + const std::vector InputShape(int idx); + const std::vector OutputShape(int idx); + + protected: + std::vector add(const std::vector &input); + void Register(Node *node, std::string &&name) { + if (node != nullptr) { + PushOp(node); + node->update_name(name); + } + } + + private: + friend mindspore::Net; + std::unordered_set visited_; + std::map> outmap_; // outputs 
per expression + std::map> inmap_; // inputs per expression + std::vector output_; // network output expression + std::vector real_output_; // network output for export + std::vector input_; // network input expression + schema::MetaGraph *meta_graph_; // imported meta_graph + std::unique_ptr session_; // export session + std::unique_ptr optimizer_; + void TopoSortUtil(Node *v, std::stack *stack); + void CreateOutputMap(std::vector vec, std::map> *outmap); + std::shared_ptr impl_; +}; + +class NetWithLoss : public Net { + public: + NetWithLoss(Net *net, Node *loss) : net_(net), loss_fn_(loss) { + REG(net_); + REG(loss_fn_); + loss_fn_->set_name("_loss_fn"); + } + std::vector construct(const std::vector &inputs) { + auto input = inputs[0]; + auto label = inputs[1]; + auto x = (*net_)({input}); + x = (*loss_fn_)({x[0], label}); + return {x.front()}; + } + + private: + Net *net_{nullptr}; + Node *loss_fn_{nullptr}; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_SRC_EXPRESSION_NET_H_ diff --git a/mindspore/lite/src/expression/node.cc b/mindspore/lite/src/expression/node.cc new file mode 100644 index 00000000000..9c6a68511dd --- /dev/null +++ b/mindspore/lite/src/expression/node.cc @@ -0,0 +1,271 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include +#include "src/expression/node.h" +#include "src/expression/ops.h" +#include "src/expression/export.h" +#include "src/runtime/infer_manager.h" +#include "src/common/utils.h" +#include "src/cxx_api/expression/net_impl.h" + +namespace mindspore { +namespace lite { +int Node::name_id; + +std::vector Node::construct(const std::vector &inputs) { + if (inputs.size() >= expr()->params().size()) { + expr()->set_params(inputs); + } else { + for (std::size_t i = 0; i < inputs.size(); i++) { + expr()->set_params(i, inputs[i]); + } + } + auto ret = InferShape(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "error infershape for node " << name(); + return {}; + } + std::vector res(expr_.size()); + (void)std::transform(expr_.begin(), expr_.end(), res.begin(), [](const EXPR &e) { return const_cast(&e); }); + return res; +} + +std::vector Node::Grad(EXPR *expr) { + MS_LOG(ERROR) << name() << " (" << schema::EnumNamePrimitiveType(primitive()) << ") does not have grad defined"; + return {}; +} + +int Node::CreateTensorFromExpr(const std::vector &expr, std::vector *tensors, bool is_input) { + MS_ASSERT(tensors != nullptr); + int ret = RET_OK; + for (auto e : expr) { + // Tensor -> TensorC + if (is_input && e->node()->primitive() == schema::PrimitiveType_Depend) { + continue; + } + auto type = (e->node()->primitive() != schema::PrimitiveType_NONE) ? 
Category::VAR : Category::CONST_TENSOR; + auto t = std::make_unique(e->data_type(), e->dims(), (mindspore::Format)e->format(), type); + if (t == nullptr) { + ret = RET_NULL_PTR; + break; + } + // copy data if any + if (type == Category::CONST_TENSOR) { + void *dst = t->MutableData(); + if (dst == nullptr) { + ret = RET_NULL_PTR; + break; + } + if (e->node()->data() && (e->node()->data()->data().size() > 0)) { + uint8_t *src = e->node()->data()->data().data(); + memcpy(dst, src, t->Size()); + } + } + tensors->push_back(t.release()); + } + return ret; +} + +void Node::FreeAllTensors(std::vector *tensors) { + MS_ASSERT(tensors != nullptr); + for (auto &t : *tensors) { + delete t; + } + tensors->clear(); +} + +int Node::InferShape() { + auto ret = RET_OK; + std::vector in_tensors; + std::vector out_tensors; + // build in \ out tensors + ret = CreateTensorFromExpr(expr()->params(), &in_tensors, true); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Failed in create in tensors"; + FreeAllTensors(&in_tensors); + return RET_ERROR; + } + std::vector expr(expr_.size()); + (void)std::transform(expr_.begin(), expr_.end(), expr.begin(), [](const EXPR &e) { return const_cast(&e); }); + ret = CreateTensorFromExpr(expr, &out_tensors); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Failed in create out tensors"; + FreeAllTensors(&in_tensors); + FreeAllTensors(&out_tensors); + return RET_ERROR; + } + // Do infer Shape + ret = KernelInferShape(in_tensors, out_tensors, OpParam()); + if (ret != RET_OK) { + MS_LOG(ERROR) << "failed in infer shape for " << name(); + FreeAllTensors(&in_tensors); + FreeAllTensors(&out_tensors); + return RET_ERROR; + } + // copy infer shape into expr + for (uint32_t i = 0; i < expr_.size(); i++) { + auto e = &expr_.at(i); + auto o = out_tensors.at(i); + e->set_format((o->format())); + e->set_data_type(o->data_type()); + e->SetDims(o->shape()); + } + // cleanup + FreeAllTensors(&in_tensors); + FreeAllTensors(&out_tensors); + + return ret; +} + +EXPR 
*Node::CreateWeights(std::vector dims, TypeId data_type, int format, Param::Mode mode, std::string name) { + auto weights = new (std::nothrow) InputM(dims); + if (weights == nullptr) { + MS_LOG(ERROR) << "Cannot allocate weights"; + return nullptr; + } + weights->set_name(this->name() + "/" + name); + int size = std::accumulate(dims.begin(), dims.end(), 1, std::multiplies()); + weights->data()->SetSize(size); + weights->data()->Fill(mode); + PushOp(weights); + return weights->expr(); +} + +Node *Node::CreateConstTensor(int index, std::vector dims, TypeId data_type, int format, std::string name, + const void *data) { + auto tensor = NN::Input(dims, data_type, format); + int elem_size = DataTypeSize(data_type); + tensor->set_name(this->name() + "/" + name); + int size = std::accumulate(dims.begin(), dims.end(), 1, std::multiplies()) * elem_size; + tensor->data()->SetSize(size); + tensor->data()->Copy(reinterpret_cast(data), size); + expr()->set_params(index, tensor->expr()); + PushOp(tensor); + return tensor; +} + +int Node::MakeEntry(ExportSession *session) { + std::vector input_idx; + std::vector output_idx; + std::vector empty; + if (primitive() == schema::PrimitiveType_Depend) return RET_OK; + // create node input + size_t inputs = InputsNum(); + for (size_t i = 0; i < inputs; i++) { + EXPR *ex = expr()->GetInput(i); + if (ex->node()->primitive() == schema::PrimitiveType_Depend) continue; + uint32_t id = session->GetOutput(ex); + input_idx.push_back(id); + } + size_t outputs = OutputsNum(); + size_t last_id = session->meta_graph()->allTensors.size(); + int type = (primitive() == schema::PrimitiveType_NONE) ? NodeType_ValueNode : NodeType_CNode; + auto data = (type == NodeType_ValueNode) ? this->data()->data() : empty; + if (data.empty()) type = NodeType_CNode; // input is Cnode !!? 
+ int idx = 0; + for (size_t i = 0; i < outputs; i++) { + if (session->IsToDependOnly(expr(i))) continue; + output_idx.push_back(last_id + idx); + session->UpdateOutput(expr(i), last_id + idx); + auto odims = dims(i); + auto data_type = expr(i)->data_type(); + auto format = expr(i)->format(); + std::string footer = (i > 0) ? ("-" + std::to_string(i)) : ""; + auto otensor = CreateTensor(name() + footer, type, data_type, odims, format, data); + std::cout << "tensor -" << last_id + idx << ": " << name() + footer << std::endl; + idx++; + session->meta_graph()->allTensors.emplace_back(std::move(otensor)); + } + if (primitive() != schema::PrimitiveType_NONE) { + if (output_idx.size() == 0) { + return RET_OK; + } + auto cnode = CreateCNode(input_idx, output_idx); + + auto ret = UnPopulate(cnode); + if (ret != RET_OK) { + MS_LOG(ERROR) << "failed to populate cnode"; + return RET_ERROR; + } + session->meta_graph()->nodes.emplace_back(std::move(cnode)); + } + + return RET_OK; +} + +std::unique_ptr Node::CreateCNode(std::vector inputIndex, std::vector outputIndex) { + auto cnode = std::make_unique(); + cnode->primitive = std::make_unique(); + cnode->primitive->value.type = primitive(); + cnode->name = name(); + cnode->inputIndex = inputIndex; + cnode->outputIndex = outputIndex; + return cnode; +} + +int Node::UnPopulate(const std::unique_ptr &cnode) { + MS_LOG(ERROR) << "Node " << schema::EnumNamePrimitiveType(primitive()) << " cannot be exported"; + return RET_ERROR; +} + +std::unique_ptr Node::CreateTensor(std::string name, int type, int data_type, + const std::vector dims, int format, + const std::vector &data) { + auto tensorT = std::make_unique(); + tensorT->nodeType = type; + tensorT->dims = dims; + tensorT->format = static_cast(format); + tensorT->name = name; + tensorT->refCount = 0; + tensorT->offset = 0; + tensorT->dataType = data_type; + tensorT->data = data; + tensorT->enableHuffmanCode = false; + if (tensorT->nodeType == mindspore::lite::NodeType_ValueNode) { + 
tensorT->data = data; + } + return tensorT; +} + +int Node::SetOutputs(int num) { + EXPR e(this); + e.SetSize(0); + for (auto i = expr_.size(); i < static_cast(num); i++) { + expr_.emplace_back(e); + } + return RET_OK; +} + +Node::~Node() { + for (auto &op : ops_) { + delete op; + } + ops_.clear(); + if (impl_ != nullptr) { + impl_->set_node(nullptr); + auto pnode = impl_->pnode(); + if (pnode != nullptr) { + impl_->set_pnode(nullptr); + delete pnode; + } + } + impl_ = nullptr; +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/expression/node.h b/mindspore/lite/src/expression/node.h new file mode 100644 index 00000000000..b6a820e090b --- /dev/null +++ b/mindspore/lite/src/expression/node.h @@ -0,0 +1,156 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_LITE_SRC_EXPRESSION_NODE_H_ +#define MINDSPORE_LITE_SRC_EXPRESSION_NODE_H_ + +#include +#include +#include +#include +#include +#include +#include "src/expression/export.h" +#include "inner/model_generated.h" +#include "src/expression/param.h" +#include "src/expression/expr.h" +#include "src/tensor.h" +#include "nnacl/op_base.h" + +namespace mindspore { +class NodeImpl; +namespace schema { +struct TensorT; +struct CNodeT; +} // namespace schema + +namespace lite { +class Node { + public: + const std::string kGradName = "Gradients"; + explicit Node(const std::string name) : opParam_(nullptr), name_(name) { expr_.emplace_back(this); } + virtual ~Node(); + Node() : Node("") {} + explicit Node(Node *node) : Node(*node) {} + EXPR *create(std::string name) { + name_ = name; + return &expr_[0]; + } + virtual std::vector operator()(const std::vector &inputs) { + auto x = construct(inputs); + return x; + } + virtual std::vector operator()(const std::initializer_list &&inputs) { + std::vector vec = inputs; + auto x = construct(vec); + return x; + } + virtual std::vector operator()(const std::initializer_list &inputs) { + std::vector vec = inputs; + auto x = construct(vec); + return x; + } + void set_primitive(schema::PrimitiveType primitive) { + primitive_ = primitive; + if (OpParam() != nullptr) opParam_->type_ = primitive_; + } + schema::PrimitiveType primitive() { return primitive_; } + virtual std::vector construct(const std::vector &inputs); + std::string name() { return name_; } + void set_name(std::string name) { name_ = name; } + virtual void update_name(std::string name) { set_name(name + "/" + name_); } + size_t Load(std::string file_name, size_t offset = 0) { return offset; } + OpParameter *OpParam() const { return opParam_.get(); } + virtual void Add(Node *node) {} + virtual std::vector Clone(EXPR *grad, EXPR *weight) { return {}; } + void SetOpParam(std::shared_ptr opParam) { opParam_ = opParam; } + void SetOpParam(void *opParam) { 
opParam_.reset(reinterpret_cast(opParam), free); } + static std::string UniqueName(const std::string &name) { return name + "-" + std::to_string(name_id++); } + static std::string UniqueName(std::string &&name) { return name + "-" + std::to_string(name_id++); } + template + int CloneOpParam(std::shared_ptr opParam) { + auto t = reinterpret_cast(opParam.get()); + auto obj = new (std::nothrow) T(*t); // copy content + if (obj == nullptr) { + MS_LOG(ERROR) << "Cannot allocate obj"; + return RET_ERROR; + } + opParam_.reset(reinterpret_cast(obj)); + return RET_OK; + } + template + int CloneOpParam(OpParameter *opParam) { + auto t = reinterpret_cast(opParam); + auto obj = new (std::nothrow) T(*t); // copy content + if (obj == nullptr) { + MS_LOG(ERROR) << "Cannot allocate obj"; + return RET_ERROR; + } + opParam_.reset(reinterpret_cast(obj)); + return RET_OK; + } + virtual Param *weight() { return nullptr; } + EXPR *expr(int i) { return &expr_[i]; } + EXPR *expr() { return expr(0); } + std::vector inputs() { return expr()[0].params(); } + size_t InputsNum() { return expr()[0].params().size(); } + size_t OutputsNum() { return expr_.size(); } + EXPR *input(int idx) { return expr()[0].params().at(idx); } + EXPR *output(int idx) { return expr(idx); } + EXPR *CreateWeights(std::vector dims, TypeId data_type, int format, Param::Mode mode, std::string name); + Node *CreateConstTensor(int index, std::vector dims, TypeId data_type, int format, std::string name, + const void *data); + virtual std::vector Grad(EXPR *expr); + virtual Param *data() { return nullptr; } + bool IsLearn(Node *node) { return learnable_.find(node) != learnable_.end(); } + virtual void SetLearn() {} + virtual std::set trainable_params() { return learnable_; } + std::vector &dims() { return expr()->dims(); } + std::vector &dims(int i) { return expr(i)->dims(); } + // export + int MakeEntry(ExportSession *session); + void PushOp(Node *n) { ops_.push_back(n); } + virtual void AddNetOutput(std::vector *output) 
{} + int SetOutputs(int num); + std::shared_ptr opParam_; + void set_impl(std::shared_ptr impl) { impl_ = impl; } + + protected: + std::vector expr_; // hold outputs + std::vector ops_; // all nodes or subnets + int InferShape(); + void AddLearn(Node *node) { learnable_.insert(node); } + void AssignLearn(std::set &&learn) { learnable_ = learn; } + + std::unique_ptr CreateCNode(std::vector inputIndex, std::vector outputIndex); + virtual int UnPopulate(const std::unique_ptr &cnode); + std::unique_ptr CreateTensor(std::string name, int type, int data_type, + const std::vector dims, int format, + const std::vector &data); + + private: + int CreateTensorFromExpr(const std::vector &expr, std::vector *tensors, bool is_input = false); + void FreeAllTensors(std::vector *tensors); + static int name_id; + std::set learnable_; // set of nodes with learnable parameters + std::string name_; + schema::PrimitiveType primitive_; + std::shared_ptr impl_; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_SRC_EXPRESSION_NODE_H_ diff --git a/mindspore/lite/src/expression/ops.cc b/mindspore/lite/src/expression/ops.cc new file mode 100644 index 00000000000..629fa0aec40 --- /dev/null +++ b/mindspore/lite/src/expression/ops.cc @@ -0,0 +1,66 @@ + +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include "src/expression/ops.h" +#include "src/expression/ops_utils.h" +#include "src/expression/param.h" +#include "include/api/cfg.h" +#include "src/expression/sequential.h" + +namespace mindspore { +namespace lite { +void InputM::SetUp(const std::vector &dims, TypeId data_type, int fmt) { + expr()->SetSize(C0NUM); + expr()->SetDims(dims); + expr()->set_data_type(data_type); + expr()->set_format(fmt); + set_primitive(schema::PrimitiveType_NONE); +} + +InputM::InputM(const std::vector &dims, TypeId data_type, int fmt) : Node() { SetUp(dims, data_type, fmt); } + +InputM::InputM(const schema::Tensor *tensor) : Node() { + std::vector dims(tensor->dims()->size()); + (void)std::transform(tensor->dims()->begin(), tensor->dims()->end(), dims.begin(), [](int32_t x) { return x; }); + SetUp(dims, static_cast(tensor->dataType()), tensor->format()); + if (tensor->name()) set_name(tensor->name()->str()); + if (tensor->data() != nullptr) data_.Copy(tensor->data()->data(), tensor->data()->size()); +} + +namespace NN { +Node *Input(const std::vector &dims, TypeId data_type, int fmt) { + auto i = new (std::nothrow) InputM(dims, data_type, fmt); + if (i == nullptr) { + MS_LOG(ERROR) << "Cannot allocate input expression "; + return nullptr; + } + return i; +} + +Net *Sequential() { + auto s = new (std::nothrow) mindspore::lite::Sequential(); + if (s == nullptr) { + MS_LOG(ERROR) << "Cannot allocate sequential expression"; + return nullptr; + } + return s; +} +}; // namespace NN +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/expression/ops.h b/mindspore/lite/src/expression/ops.h new file mode 100644 index 00000000000..96a006f46e3 --- /dev/null +++ b/mindspore/lite/src/expression/ops.h @@ -0,0 +1,69 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_H_ +#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_H_ + +#include +#include +#include +#include "include/api/net.h" +#include "src/expression/cfg.h" +#include "src/expression/net.h" +#include "inner/model_generated.h" + +namespace mindspore { +namespace lite { +class InputM : public Node { + public: + explicit InputM(const schema::Tensor *tensor); + explicit InputM(const std::vector &dims, TypeId data_type = kNumberTypeFloat32, int fmt = NHWC); + Param *data() override { return &data_; } + + private: + void SetUp(const std::vector &dims, TypeId data_type, int fmt); + Param data_; +}; +namespace NN { +Node *Conv2D(const ConvConfig &cfg); +Node *Relu(); +Node *Dense(const DenseConfig &cfg); +Node *Flatten(); +Node *Input(const std::vector &dims, TypeId data_type = kNumberTypeFloat32, int fmt = NHWC); +Node *Add(); +Node *Sub(); +Node *Div(); +Node *Mul(); +Node *Neg(); +Node *SoftmaxCrossEntropy(); +Net *Sequential(); +Node *Adam(std::set &&learn, const AdamConfig &cfg); + +Node *Softmax(int axis = -1); +Node *BatchNorm2D(int outp, float momentum = 0.1, float epsilon = 1e-5f); +Node *Sigmoid(); +Node *DropOut(float ration = 0.5); +Node *ReLU6(); +Node *Reshape(const std::vector &shape); +Node *ReduceMean(bool keep_dims, const std::vector &dims); +Node *ReduceSum(bool keep_dims, const std::vector &dims); +Node *Tile(const std::vector &multiples); +Node *MaxPool2D(const PoolingConfig &cfg); +Node *AvgPool2D(const PoolingConfig &cfg); +} // namespace NN +} // namespace lite +} // namespace mindspore +#endif 
// MINDSPORE_LITE_SRC_EXPRESSION_OPS_H_ diff --git a/mindspore/lite/src/expression/ops/activation.cc b/mindspore/lite/src/expression/ops/activation.cc new file mode 100644 index 00000000000..b272a3b0512 --- /dev/null +++ b/mindspore/lite/src/expression/ops/activation.cc @@ -0,0 +1,133 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/expression/ops/activation.h" +#include "nnacl/fp32/activation_fp32.h" +#include "src/expression/import.h" +#include "src/expression/ops.h" +#include "src/cxx_api/expression/node_impl.h" + +namespace mindspore { +namespace lite { +ActM::ActM(schema::ActivationType type) : Node() { + auto op_param = malloc(sizeof(ActivationParameter)); + if (op_param == nullptr) { + MS_LOG(ERROR) << " cannot allocate ActivationParameter"; + return; + } + SetOpParam(op_param); + set_primitive(schema::PrimitiveType_Activation); + ActivationParameter *act_param = reinterpret_cast(opParam_.get()); + act_param->type_ = type; + act_param->alpha_ = 0.f; + act_param->min_val_ = 0.f; + act_param->max_val_ = 0.f; +} + +std::vector ActM::Grad(EXPR *yt) { + auto actGrad = new (std::nothrow) ActGradM(this); + if (actGrad == nullptr) { + MS_LOG(ERROR) << "Cannot allocate activation grad node"; + return {}; + } + PushOp(actGrad); + auto param = reinterpret_cast(actGrad->OpParam()); + EXPR *ag = nullptr; + actGrad->expr()->SetSize(C2NUM); + if ((param->type_ == 
schema::ActivationType_SIGMOID) || (param->type_ == schema::ActivationType_TANH)) { + ag = (*actGrad)({output(0), yt}).front(); + } else if ((param->type_ == schema::ActivationType_HSWISH) || (param->type_ == schema::ActivationType_HSIGMOID) || + (param->type_ == schema::ActivationType_RELU6)) { + ag = (*actGrad)({yt, input(0)}).front(); + } else if (param->type_ == schema::ActivationType_GELU) { + actGrad->expr()->SetSize(C3NUM); + ag = (*actGrad)({yt, input(0), output(0)}).front(); + } else { + ag = (*actGrad)({yt, output(0)}).front(); + } + std::vector res = {ag}; + return res; +} + +int ActM::UnPopulate(const std::unique_ptr &cnode) { + auto act_param = reinterpret_cast(OpParam()); + auto prim = new (std::nothrow) schema::ActivationT; + if (prim == nullptr) { + MS_LOG(ERROR) << "Cannot allocate activation primitive"; + return RET_ERROR; + } + prim->activation_type = static_castactivation_type)>(act_param->type_); + prim->alpha = act_param->alpha_; + prim->min_val = act_param->min_val_; + prim->max_val = act_param->max_val_; + cnode->primitive->value.value = prim; + return RET_OK; +} + +static ImportReg reg(schema::PrimitiveType_Activation, ReturnNode); + +ActGradM::ActGradM(Node *node) { + CloneOpParam(node->OpParam()); + set_primitive(schema::PrimitiveType_ActivationGrad); + set_name(node->name() + "/" + kGradName + "/actGrad"); +} + +std::vector ActGradM::Grad(EXPR *yt) { return {}; } + +int ActGradM::UnPopulate(const std::unique_ptr &cnode) { + auto act_param = reinterpret_cast(OpParam()); + auto prim = new (std::nothrow) schema::ActivationGradT; + if (prim == nullptr) { + MS_LOG(ERROR) << "Cannot allocate activation grad primitive"; + return RET_ERROR; + } + prim->activation_type = static_castactivation_type)>(act_param->type_); + prim->alpha = act_param->alpha_; + cnode->primitive->value.value = prim; + return RET_OK; +} + +static ImportReg regGrad(schema::PrimitiveType_ActivationGrad, ReturnNode); +namespace NN { +Node *ReLU6() { + auto r = new 
(std::nothrow) ActM(schema::ActivationType_RELU6); + if (r == nullptr) { + MS_LOG(ERROR) << "Cannot allocate relu6"; + return nullptr; + } + r->set_name(Node::UniqueName("ReLU6")); + return r; +} +Node *Sigmoid() { + auto s = new (std::nothrow) ActM(schema::ActivationType_SIGMOID); + if (s == nullptr) { + MS_LOG(ERROR) << "Cannot allocate sigmoid"; + return nullptr; + } + s->set_name(Node::UniqueName("Sigmoid")); + return s; +} +Node *Relu() { + auto r = new (std::nothrow) ActM(schema::ActivationType_RELU); + if (r == nullptr) { + MS_LOG(ERROR) << "Cannot allocate relu"; + return nullptr; + } + r->set_name(r->UniqueName("Relu")); + return r; +} +} // namespace NN +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/expression/ops/activation.h b/mindspore/lite/src/expression/ops/activation.h new file mode 100644 index 00000000000..14271c091d5 --- /dev/null +++ b/mindspore/lite/src/expression/ops/activation.h @@ -0,0 +1,44 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_ACTIVATION_H_ +#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_ACTIVATION_H_ + +#include +#include +#include "src/expression/net.h" +#include "inner/model_generated.h" + +namespace mindspore { +namespace lite { +class ActM : public Node { + public: + ActM() = default; + explicit ActM(schema::ActivationType type); + std::vector Grad(EXPR *expr) override; + int UnPopulate(const std::unique_ptr &cnode) override; +}; + +class ActGradM : public Node { + public: + ActGradM() : Node() {} // for Import + explicit ActGradM(Node *act); // for Grad + std::vector Grad(EXPR *expr) override; + int UnPopulate(const std::unique_ptr &cnode) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_ACTIVATION_H_ diff --git a/mindspore/lite/src/expression/ops/adam.cc b/mindspore/lite/src/expression/ops/adam.cc new file mode 100644 index 00000000000..330f9071a06 --- /dev/null +++ b/mindspore/lite/src/expression/ops/adam.cc @@ -0,0 +1,142 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "src/expression/ops/adam.h" +#include +#include +#include +#include "src/expression/ops.h" +#include "src/expression/ops/assign.h" +#include "src/expression/ops/arithmetic.h" +#include "nnacl/fp32_grad/optimizer.h" +#include "include/api/net.h" +#include "src/cxx_api/expression/node_impl.h" + +namespace mindspore { +namespace NN { +Node *Adam(std::shared_ptr learn, const AdamConfig &cfg) { + auto lite_node = lite::NN::Adam(std::move(learn->set_), cfg); + return NodeImpl::Connect(lite_node); +} +} // namespace NN + +namespace lite { +std::vector AdamM::Clone(EXPR *grad, EXPR *weight) { + auto adam = new (std::nothrow) AdamM(); + if (adam == nullptr) { + MS_LOG(ERROR) << "Cannot allocate adam"; + return {}; + } + adam->set_name("optimizer-Adam"); + adam->CloneOpParam(OpParam()); + adam->update_name(weight->node()->name()); + adam->set_primitive(primitive()); + adam->expr()->SetSize(C10NUM); + // setup weight and momentum + adam->expr()->set_params(C0NUM, weight); + auto dims = grad->dims(); + auto m = adam->CreateWeights(dims, kNumberTypeFloat32, KHWC, Param::ZEROS, "m"); + adam->expr()->set_params(C1NUM, m); + auto v = adam->CreateWeights(dims, kNumberTypeFloat32, KHWC, Param::ZEROS, "v"); + adam->expr()->set_params(C2NUM, v); + // copy parameters + for (int i = C3NUM; i < C9NUM; i++) { + adam->expr()->set_params(i, this->input(i)); + } + adam->expr()->set_params(C9NUM, grad); + return (*adam)(adam->inputs()); +} + +AdamM::AdamM(std::set &&learn, const AdamConfig &cfg) { + auto op_param = reinterpret_cast(malloc(sizeof(AdamParameter))); + if (op_param == nullptr) { + MS_LOG(ERROR) << " cannot allocate ActivationParameter"; + return; + } + AssignLearn(std::move(learn)); + memset(op_param, 0, sizeof(AdamParameter)); + op_param->use_nesterov_ = cfg.use_nesterov_; + SetOpParam(op_param); + set_primitive(schema::PrimitiveType_Adam); + set_name("optimizer-Adam"); + // Adam Network + expr()->SetSize(C10NUM); + auto assign1 = new (std::nothrow) AssignM(0); 
+ if (assign1 == nullptr) { + MS_LOG(ERROR) << "Cannot allocate assign"; + return; + } + PushOp(assign1); + auto assign2 = new (std::nothrow) AssignM(0); + if (assign2 == nullptr) { + MS_LOG(ERROR) << "Cannot allocate assign"; + return; + } + PushOp(assign2); + auto mul1 = NN::Mul(); + if (mul1 == nullptr) { + MS_LOG(ERROR) << "Cannot allocate mul"; + return; + } + PushOp(mul1); + auto mul2 = NN::Mul(); + if (mul2 == nullptr) { + MS_LOG(ERROR) << "Cannot allocate mul"; + return; + } + PushOp(mul2); + auto tmp = 1.0f; + mul1->CreateConstTensor(C0NUM, {1}, kNumberTypeFloat32, KHWC, "beta1-power", &tmp); + mul1->CreateConstTensor(C1NUM, {1}, kNumberTypeFloat32, KHWC, "beta1-data", &cfg.beta1_); + auto o1 = (*mul1)({}); + assign1_ = (*assign1)({mul1->input(0), o1.front()}).front(); + mul2->CreateConstTensor(C0NUM, {1}, kNumberTypeFloat32, KHWC, "beta2-power", &tmp); + mul2->CreateConstTensor(C1NUM, {1}, kNumberTypeFloat32, KHWC, "beta2-data", &cfg.beta2_); + auto o2 = (*mul2)({}); + assign2_ = (*assign2)({mul2->input(0), o2.front()}).front(); + expr()->set_params(C3NUM, o1.front()); + expr()->set_params(C4NUM, o2.front()); + CreateConstTensor(C5NUM, {1}, kNumberTypeFloat32, KHWC, "learning-rate", &cfg.learning_rate_); + CreateConstTensor(C6NUM, {1}, kNumberTypeFloat32, KHWC, "beta1", &cfg.beta1_); + CreateConstTensor(C7NUM, {1}, kNumberTypeFloat32, KHWC, "beta2", &cfg.beta2_); + CreateConstTensor(C8NUM, {1}, kNumberTypeFloat32, KHWC, "epsilon", &cfg.eps_); +} + +int AdamM::UnPopulate(const std::unique_ptr &cnode) { + auto param = reinterpret_cast(OpParam()); + auto prim = new (std::nothrow) schema::AdamT; + if (prim == nullptr) { + MS_LOG(ERROR) << "Cannot allocate " << cnode->name; + return RET_ERROR; + } + prim->use_nesterov = param->use_nesterov_; + prim->use_locking = false; + cnode->primitive->value.value = prim; + return RET_OK; +} + +namespace NN { +Node *Adam(std::set &&learn, const AdamConfig &cfg) { + auto a = new (std::nothrow) AdamM(std::move(learn), cfg); 
+ if (a == nullptr) { + MS_LOG(ERROR) << "Cannot allocate adam"; + return nullptr; + } + return a; +} +} // namespace NN +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/expression/ops/adam.h b/mindspore/lite/src/expression/ops/adam.h new file mode 100644 index 00000000000..ee1e2f56282 --- /dev/null +++ b/mindspore/lite/src/expression/ops/adam.h @@ -0,0 +1,48 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_ADAM_H_ +#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_ADAM_H_ + +#include +#include +#include +#include "include/api/net.h" +#include "src/expression/net.h" +#include "inner/model_generated.h" + +namespace mindspore { +namespace lite { +class AdamM : public Node { + public: + AdamM() = default; + AdamM(std::set &&learn, const AdamConfig &cfg); + std::vector Clone(EXPR *grad, EXPR *weight) override; + void AddNetOutput(std::vector *output) override { + output->push_back(assign1_); + output->push_back(assign2_); + } + int UnPopulate(const std::unique_ptr &cnode) override; + + private: + float weight_decay_; + float loss_scale_; + EXPR *assign1_{nullptr}; + EXPR *assign2_{nullptr}; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_ADAM_H_ diff --git a/mindspore/lite/src/expression/ops/addn.cc b/mindspore/lite/src/expression/ops/addn.cc new file mode 100644 index 00000000000..bd614f1d5d5 --- /dev/null +++ b/mindspore/lite/src/expression/ops/addn.cc @@ -0,0 +1,42 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "src/expression/ops/addn.h" + +namespace mindspore { +namespace lite { +AddN::AddN(int dummy) { + auto op_param = calloc(1, sizeof(OpParameter)); + if (op_param == nullptr) { + MS_LOG(ERROR) << "Cannot allocate parameter "; + return; + } + set_name(UniqueName("addN")); + SetOpParam(op_param); + set_primitive(schema::PrimitiveType_AddN); +} + +int AddN::UnPopulate(const std::unique_ptr &cnode) { + auto prim = new (std::nothrow) schema::AddNT; + if (prim == nullptr) { + MS_LOG(ERROR) << "Cannot allocate primitive"; + return RET_ERROR; + } + cnode->primitive->value.value = prim; + return RET_OK; +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/expression/ops/addn.h b/mindspore/lite/src/expression/ops/addn.h new file mode 100644 index 00000000000..3ed96319146 --- /dev/null +++ b/mindspore/lite/src/expression/ops/addn.h @@ -0,0 +1,34 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_ADDN_H_ +#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_ADDN_H_ + +#include +#include "src/expression/node.h" +#include "inner/model_generated.h" + +namespace mindspore { +namespace lite { +class AddN : public Node { + public: + explicit AddN(int dummy); + int UnPopulate(const std::unique_ptr &cnode) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_ADDN_H_ diff --git a/mindspore/lite/src/expression/ops/arithmetic.cc b/mindspore/lite/src/expression/ops/arithmetic.cc new file mode 100644 index 00000000000..988c151b038 --- /dev/null +++ b/mindspore/lite/src/expression/ops/arithmetic.cc @@ -0,0 +1,223 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "src/expression/ops/arithmetic.h" +#include +#include "src/expression/ops/reduce.h" +#include "src/expression/ops/reshape.h" +#include "src/expression/ops_utils.h" +#include "src/expression/ops/arithmetic_self.h" +#include "src/expression/ops.h" +#include "nnacl/arithmetic.h" +#include "src/expression/import.h" +#include "src/cxx_api/expression/node_impl.h" + +namespace mindspore { +namespace lite { +// Common Arithmetic Functionality +ArithmeticM::ArithmeticM(schema::PrimitiveType type) : Node() { + auto op_param = malloc(sizeof(ArithmeticParameter)); + if (op_param == nullptr) { + MS_LOG(ERROR) << " cannot allocate ActivationParameter"; + return; + } + SetOpParam(op_param); + expr()->SetSize(C2NUM); + set_primitive(type); +} + +std::vector ArithmeticM::binop_grad_common(EXPR *x, EXPR *y, EXPR *dx, EXPR *dy) { + auto shape_of_x = x->dims(); + auto shape_of_y = y->dims(); + auto reduce_dx = dx; + auto reduce_dy = dy; + auto rx = (BroadcastGradientArgs(shape_of_x, shape_of_y))(); + if (rx[0].size()) { + auto reduce_sum = NN::ReduceSum(false, rx[0]); + PushOp(reduce_sum); + reduce_dx = (*reduce_sum)({reduce_dx}).front(); + auto reshape = NN::Reshape(shape_of_x); + PushOp(reshape); + reduce_dx = (*reshape)({reduce_dx}).front(); + } + if (rx[1].size()) { + auto reduce_sum = NN::ReduceSum(false, rx[1]); + PushOp(reduce_sum); + reduce_dy = (*reduce_sum)({reduce_dy}).front(); + auto reshape = NN::Reshape(shape_of_y); + PushOp(reshape); + reduce_dy = (*reshape)({reduce_dy}).front(); + } + std::vector out = {reduce_dx, reduce_dy}; + return out; +} + +// Add Op +AddM::AddM(int dummy) : ArithmeticM(schema::PrimitiveType_AddFusion) { set_name(UniqueName("Add")); } + +std::vector AddM::Grad(EXPR *yt) { return binop_grad_common(input(0), input(1), yt, yt); } + +int AddM::UnPopulate(const std::unique_ptr &cnode) { + auto prim = new (std::nothrow) schema::AddFusionT; + if (prim == nullptr) { + MS_LOG(ERROR) << "Cannot allocate prim"; + return RET_ERROR; + } + 
prim->activation_type = schema::ActivationType_NO_ACTIVATION; + cnode->primitive->value.value = prim; + return RET_OK; +} + +static ImportReg AddReg(schema::PrimitiveType_AddFusion, ReturnNode); + +// Div op +DivM::DivM(int dummy) : ArithmeticM(schema::PrimitiveType_RealDiv) { set_name(UniqueName("RealDiv")); } +std::vector DivM::Grad(EXPR *yt) { + auto x = input(0); + auto y = input(1); + auto o = output(0); + auto div_op = NN::Div(); + if (div_op == nullptr) { + MS_LOG(ERROR) << "Cannot allocate div_op"; + return {}; + } + PushOp(div_op); + auto neg_op = NN::Neg(); + if (neg_op == nullptr) { + MS_LOG(ERROR) << "Cannot allocate neg_op"; + return {}; + } + PushOp(neg_op); + auto mul_op = NN::Mul(); + if (mul_op == nullptr) { + MS_LOG(ERROR) << "Cannot allocate mul_op"; + return {}; + } + PushOp(mul_op); + auto bc_x = (*div_op)({yt, y}).front(); + auto bc_y = (*neg_op)((*mul_op)({bc_x, o})).front(); + return binop_grad_common(x, y, bc_x, bc_y); +} +int DivM::UnPopulate(const std::unique_ptr &cnode) { + auto prim = new (std::nothrow) schema::RealDivT; + if (prim == nullptr) { + MS_LOG(ERROR) << "Cannot allocate primitive"; + return RET_ERROR; + } + cnode->primitive->value.value = prim; + return RET_OK; +} +static ImportReg DivReg(schema::PrimitiveType_DivFusion, ReturnNode); + +// Mul op +MulM::MulM(int dummy) : ArithmeticM(schema::PrimitiveType_MulFusion) { set_name(UniqueName("Mul")); } + +std::vector MulM::Grad(EXPR *yt) { + auto mul_dx = NN::Mul(); + if (mul_dx == nullptr) { + MS_LOG(ERROR) << "Cannot allocate mul dx"; + return {}; + } + PushOp(mul_dx); + auto mul_dy = NN::Mul(); + if (mul_dy == nullptr) { + MS_LOG(ERROR) << "Cannot allocate mul_dy"; + return {}; + } + PushOp(mul_dy); + auto x = input(0); + auto y = input(1); + auto bc_dx = (*mul_dx)({y, yt}).front(); + auto bc_dy = (*mul_dy)({x, yt}).front(); + return binop_grad_common(x, y, bc_dx, bc_dy); +} + +int MulM::UnPopulate(const std::unique_ptr &cnode) { + auto prim = new (std::nothrow) 
schema::MulFusionT; + if (prim == nullptr) { + MS_LOG(ERROR) << "Cannot allocate prim"; + return RET_ERROR; + } + prim->activation_type = schema::ActivationType_NO_ACTIVATION; + cnode->primitive->value.value = prim; + return RET_OK; +} +static ImportReg MulReg(schema::PrimitiveType_MulFusion, ReturnNode); + +// Sub op +SubM::SubM(int dummy) : ArithmeticM(schema::PrimitiveType_SubFusion) { set_name(UniqueName("Sub")); } + +std::vector SubM::Grad(EXPR *yt) { + auto x = input(0); + auto y = input(1); + auto neg = NN::Neg(); + if (neg == nullptr) { + MS_LOG(ERROR) << "Cannot allocate neg"; + return {}; + } + PushOp(neg); + auto neg_grad = (*neg)({yt}).front(); + return binop_grad_common(x, y, yt, neg_grad); +} +int SubM::UnPopulate(const std::unique_ptr &cnode) { + auto prim = new (std::nothrow) schema::SubFusionT; + if (prim == nullptr) { + MS_LOG(ERROR) << "Cannot allocate prim"; + return RET_ERROR; + } + prim->activation_type = schema::ActivationType_NO_ACTIVATION; + cnode->primitive->value.value = prim; + return RET_OK; +} +static ImportReg SubReg(schema::PrimitiveType_SubFusion, ReturnNode); + +namespace NN { +Node *Add() { + auto a = new (std::nothrow) AddM(0); + if (a == nullptr) { + MS_LOG(ERROR) << "Cannot allocate a"; + return nullptr; + } + return a; +} +Node *Sub() { + auto a = new (std::nothrow) SubM(0); + if (a == nullptr) { + MS_LOG(ERROR) << "Cannot allocate a"; + return nullptr; + } + return a; +} + +Node *Mul() { + auto a = new (std::nothrow) MulM(0); + if (a == nullptr) { + MS_LOG(ERROR) << "Cannot allocate a"; + return nullptr; + } + return a; +} +Node *Div() { + auto a = new (std::nothrow) DivM(0); + if (a == nullptr) { + MS_LOG(ERROR) << "Cannot allocate a"; + return nullptr; + } + return a; +} +} // namespace NN +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/expression/ops/arithmetic.h b/mindspore/lite/src/expression/ops/arithmetic.h new file mode 100644 index 00000000000..b15092457ec --- /dev/null +++ 
b/mindspore/lite/src/expression/ops/arithmetic.h @@ -0,0 +1,75 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_ARITHMETIC_H_ +#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_ARITHMETIC_H_ + +#include +#include +#include "src/expression/net.h" +#include "inner/model_generated.h" + +namespace mindspore { +namespace lite { +class ArithmeticM : public Node { + public: + ArithmeticM() = default; + explicit ArithmeticM(schema::PrimitiveType type); + + protected: + std::vector binop_grad_common(EXPR *x, EXPR *y, EXPR *dx, EXPR *dy); +}; + +class AddM : public ArithmeticM { + public: + AddM() = default; + explicit AddM(int dummy); + std::vector Grad(EXPR *expr) override; + int UnPopulate(const std::unique_ptr &cnode) override; +}; + +class DivM : public ArithmeticM { + public: + DivM() = default; + explicit DivM(int dummy); + std::vector Grad(EXPR *expr) override; + int UnPopulate(const std::unique_ptr &cnode) override; +}; + +class MulM : public ArithmeticM { + public: + MulM() = default; + explicit MulM(int dummy); + std::vector Grad(EXPR *expr) override; + int UnPopulate(const std::unique_ptr &cnode) override; +}; + +class SubM : public ArithmeticM { + public: + SubM() = default; + explicit SubM(int dummy); + std::vector Grad(EXPR *expr) override; + int UnPopulate(const std::unique_ptr &cnode) override; +}; +namespace NN { +Node *Add(); +Node *Sub(); +Node *Mul(); 
+Node *Div(); +} // namespace NN +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_ARITHMETIC_H_ diff --git a/mindspore/lite/src/expression/ops/arithmetic_self.cc b/mindspore/lite/src/expression/ops/arithmetic_self.cc new file mode 100644 index 00000000000..e5a84f7503c --- /dev/null +++ b/mindspore/lite/src/expression/ops/arithmetic_self.cc @@ -0,0 +1,72 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "src/expression/ops/arithmetic_self.h" +#include +#include "src/expression/ops_utils.h" +#include "src/expression/ops.h" +#include "nnacl/arithmetic_self_parameter.h" +#include "src/expression/import.h" + +namespace mindspore { +namespace lite { +// Common Arithmetic Self Functionality +ArithmeticSelfM::ArithmeticSelfM(schema::PrimitiveType type) : Node() { + auto op_param = malloc(sizeof(ArithmeticSelfParameter)); + if (op_param == nullptr) { + MS_LOG(ERROR) << " cannot allocate ArithmeticSelfParameter"; + return; + } + memset(op_param, 0, sizeof(ArithmeticSelfParameter)); + SetOpParam(op_param); + expr()->SetSize(C1NUM); + set_primitive(type); +} + +// NEG OP +NegM::NegM(int dummy) : ArithmeticSelfM(schema::PrimitiveType_NegGrad) { set_name(UniqueName("Neg")); } + +std::vector NegM::Grad(EXPR *yt) { + auto grad_neg = new (std::nothrow) NegM(0); + if (grad_neg == nullptr) { + MS_LOG(ERROR) << "Cannot allocate neg gradient"; + return {}; + } + return (*grad_neg)({yt}); +} + +int NegM::UnPopulate(const std::unique_ptr &cnode) { + auto prim = new (std::nothrow) schema::NegT; + if (prim == nullptr) { + MS_LOG(ERROR) << "Cannot allocate primitive"; + return RET_ERROR; + } + cnode->primitive->value.value = prim; + return RET_OK; +} + +namespace NN { +Node *Neg() { + auto a = new (std::nothrow) NegM(0); + if (a == nullptr) { + MS_LOG(ERROR) << "Cannot allocate neg node"; + return nullptr; + } + return a; +} +} // namespace NN +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/expression/ops/arithmetic_self.h b/mindspore/lite/src/expression/ops/arithmetic_self.h new file mode 100644 index 00000000000..e64ba024005 --- /dev/null +++ b/mindspore/lite/src/expression/ops/arithmetic_self.h @@ -0,0 +1,46 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_ARITHMETIC_SELF_H_ +#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_ARITHMETIC_SELF_H_ + +#include +#include +#include "src/expression/net.h" +#include "inner/model_generated.h" + +namespace mindspore { +namespace lite { +class ArithmeticSelfM : public Node { + public: + explicit ArithmeticSelfM(schema::PrimitiveType type); + + protected: +}; + +class NegM : public ArithmeticSelfM { + public: + explicit NegM(int dummy); + std::vector Grad(EXPR *expr) override; + int UnPopulate(const std::unique_ptr &cnode) override; +}; + +namespace NN { +Node *Neg(); +} +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_ARITHMETIC_SELF_H_ diff --git a/mindspore/lite/src/expression/ops/assign.cc b/mindspore/lite/src/expression/ops/assign.cc new file mode 100644 index 00000000000..acf1095028a --- /dev/null +++ b/mindspore/lite/src/expression/ops/assign.cc @@ -0,0 +1,60 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "src/expression/ops/assign.h" +#include +#include "nnacl/reshape_parameter.h" +#include "src/expression/import.h" + +namespace mindspore { +namespace lite { +AssignM::AssignM(int dummy) {  // Builds an Assign node: zeroed OpParameter, expression sized to C2NUM, primitive type Assign, unique "Assign" name. `dummy` is unused. + auto op_param = calloc(1, sizeof(OpParameter)); + if (op_param == nullptr) { + MS_LOG(ERROR) << " cannot allocate OpParameter"; + return; + } + expr()->SetSize(C2NUM); + SetOpParam(op_param); + set_primitive(schema::PrimitiveType_Assign); + set_name(UniqueName("Assign")); +} + +int AssignM::UnPopulate(const std::unique_ptr &cnode) {  // Stores a freshly allocated schema::AssignT in cnode->primitive->value.value; RET_ERROR if the allocation fails, RET_OK otherwise. + auto prim = new (std::nothrow) schema::AssignT; + if (prim == nullptr) { + MS_LOG(ERROR) << "Cannot allocate primitive"; + return RET_ERROR; + } + cnode->primitive->value.value = prim; + return RET_OK; +} + +static ImportReg reg(schema::PrimitiveType_Assign, ReturnNode);  // Register this node's factory under its own primitive type (was PrimitiveType_Reshape, a copy-paste from reshape.cc). + +namespace NN { +Node *Assign() {  // Factory: returns a new AssignM, or nullptr (after logging) when allocation fails. + auto node = new (std::nothrow) AssignM(0); + if (node == nullptr) { + MS_LOG(ERROR) << "Cannot allocate node"; + return nullptr; + } + node->set_name(Node::UniqueName("Assign"));  // NOTE(review): the constructor already assigned a UniqueName("Assign"); this overwrites it with a second one - confirm it is needed. + return node; +} +} // namespace NN +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/expression/ops/assign.h b/mindspore/lite/src/expression/ops/assign.h new file mode 100644 index 00000000000..0dfd2c67982 --- /dev/null +++ b/mindspore/lite/src/expression/ops/assign.h @@ -0,0 +1,35 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_ASSIGN_H_ +#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_ASSIGN_H_ + +#include +#include +#include "src/expression/net.h" +#include "inner/model_generated.h" + +namespace mindspore { +namespace lite { +class AssignM : public Node { + public: + AssignM() = default; + explicit AssignM(int dummy); + int UnPopulate(const std::unique_ptr &cnode) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_ASSIGN_H_ diff --git a/mindspore/lite/src/expression/ops/batchnorm.cc b/mindspore/lite/src/expression/ops/batchnorm.cc new file mode 100644 index 00000000000..6d6dfb0521a --- /dev/null +++ b/mindspore/lite/src/expression/ops/batchnorm.cc @@ -0,0 +1,135 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "src/expression/ops/batchnorm.h" +#include +#include "nnacl/batchnorm_parameter.h" +#include "nnacl/fp32_grad/batch_norm.h" +#include "src/expression/import.h" +#include "src/expression/ops.h" +#include "src/cxx_api/expression/node_impl.h" + +namespace mindspore { +namespace lite { +BatchNorm2dM::BatchNorm2dM(int outp, float momentum, float epsilon) { + constexpr int bn_inputs = 5; + constexpr int bn_outputs = 5; + + auto op_param = calloc(1, sizeof(BatchNormParameter)); + if (op_param == nullptr) { + MS_LOG(ERROR) << " cannot allocate BatchNormParameter"; + return; + } + expr()->SetSize(bn_inputs); + set_name(UniqueName("BatchNorm2D")); + auto bn_param = reinterpret_cast(op_param); + bn_param->channel_ = outp; + bn_param->momentum_ = momentum; + bn_param->epsilon_ = epsilon; + SetOpParam(op_param); + set_primitive(schema::PrimitiveType_FusedBatchNorm); + std::vector dims = {outp}; + auto scale = Node::CreateWeights(dims, kNumberTypeFloat32, KHWC, Param::Mode::ONES, "scale"); + expr()->set_params(C1NUM, scale); + auto offset = Node::CreateWeights(dims, kNumberTypeFloat32, KHWC, Param::Mode::ZEROS, "offset"); + expr()->set_params(C2NUM, offset); + auto mean = Node::CreateWeights(dims, kNumberTypeFloat32, KHWC, Param::Mode::ZEROS, "mean"); + expr()->set_params(C3NUM, mean); + auto var = Node::CreateWeights(dims, kNumberTypeFloat32, KHWC, Param::Mode::ONES, "var"); + expr()->set_params(C4NUM, var); + SetOutputs(bn_outputs); + SetLearn(); +} + +BatchNorm2dGradM::BatchNorm2dGradM(BatchNorm2dM *bn_node) : Node() { + auto op_param = calloc(1, sizeof(BNGradParameter)); + if (op_param == nullptr) { + MS_LOG(ERROR) << " cannot allocate BNGradParameter"; + return; + } + expr()->SetSize(C6NUM); + set_name(bn_node->name() + "/" + kGradName + "/bnGrad"); + auto bn_grad_param = reinterpret_cast(op_param); + auto bn_param = reinterpret_cast(bn_node->OpParam()); + bn_param->is_training_ = true; + bn_grad_param->epsilon_ = bn_param->epsilon_; + 
bn_grad_param->is_training_ = true; + SetOpParam(op_param); + set_primitive(schema::PrimitiveType_BatchNormGrad); + EXPR e(this); + e.SetSize(0); + // Dgamma + expr_.emplace_back(e); + // Doffset + expr_.emplace_back(e); +} + +std::vector BatchNorm2dM::Grad(EXPR *yt) { + auto bn_grad_node = new (std::nothrow) BatchNorm2dGradM(this); + if (bn_grad_node == nullptr) { + MS_LOG(ERROR) << "Cannot allocate batchnorm grad"; + return {}; + } + PushOp(bn_grad_node); + auto bn_grad = (*bn_grad_node)({yt, input(0), output(1), output(3), output(4), output(2)}); + return bn_grad; +} + +int BatchNorm2dM::UnPopulate(const std::unique_ptr &cnode) { + auto bn_param = reinterpret_cast(OpParam()); + auto prim = new (std::nothrow) schema::FusedBatchNormT; + if (prim == nullptr) { + MS_LOG(ERROR) << "Cannot allocate primitive"; + return RET_ERROR; + } + prim->epsilon = bn_param->epsilon_; + prim->momentum = bn_param->momentum_; + prim->mode = (bn_param->is_training_ == false) ? 0 : 1; + cnode->primitive->value.value = prim; + return RET_OK; +} + +void BatchNorm2dM::SetLearn() { + AddLearn(input(C1NUM)->node()); + AddLearn(input(C2NUM)->node()); +} + +int BatchNorm2dGradM::UnPopulate(const std::unique_ptr &cnode) { + auto param = reinterpret_cast(OpParam()); + auto prim = new (std::nothrow) schema::BatchNormGradT; + if (prim == nullptr) { + MS_LOG(ERROR) << "Cannot allocate primitive"; + return RET_ERROR; + } + prim->epsilon = param->epsilon_; + prim->is_training = param->is_training_; + cnode->primitive->value.value = prim; + return RET_OK; +} + +static ImportReg reg(schema::PrimitiveType_FusedBatchNorm, ReturnNode); +namespace NN { +Node *BatchNorm2D(int outp, float momentum, float epsilon) { + auto node = new (std::nothrow) BatchNorm2dM(outp, momentum, epsilon); + if (node == nullptr) { + MS_LOG(ERROR) << "Cannot allocate node"; + return nullptr; + } + return node; +} +} // namespace NN +} // namespace lite +} // namespace mindspore diff --git 
a/mindspore/lite/src/expression/ops/batchnorm.h b/mindspore/lite/src/expression/ops/batchnorm.h new file mode 100644 index 00000000000..2891f35a874 --- /dev/null +++ b/mindspore/lite/src/expression/ops/batchnorm.h @@ -0,0 +1,43 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_BATCHNORM_H_ +#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_BATCHNORM_H_ + +#include +#include +#include "src/expression/net.h" +#include "inner/model_generated.h" + +namespace mindspore { +namespace lite { +class BatchNorm2dM : public Node { + public: + BatchNorm2dM() = default; + BatchNorm2dM(int outp, float momentum, float epsilon); + std::vector Grad(EXPR *expr) override; + int UnPopulate(const std::unique_ptr &cnode) override; + void SetLearn() override; +}; + +class BatchNorm2dGradM : public Node { + public: + explicit BatchNorm2dGradM(BatchNorm2dM *bn_node); + int UnPopulate(const std::unique_ptr &cnode) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_BATCHNORM_H_ diff --git a/mindspore/lite/src/expression/ops/biasadd.cc b/mindspore/lite/src/expression/ops/biasadd.cc new file mode 100644 index 00000000000..bb1d0a72d2f --- /dev/null +++ b/mindspore/lite/src/expression/ops/biasadd.cc @@ -0,0 +1,93 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + 
* you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/expression/ops/biasadd.h" +#include "src/expression/ops/transpose.h" +#include "nnacl/arithmetic.h" +#include "src/expression/import.h" + +namespace mindspore { +namespace lite { +BiasAddM::BiasAddM(Format data_format) {  // Builds a BiasAdd node backed by a zeroed ArithmeticParameter; `data_format` is not read by this constructor. + auto op_param = calloc(1, sizeof(ArithmeticParameter)); + if (op_param == nullptr) { + MS_LOG(ERROR) << " cannot allocate ArithmeticParameter"; + return; + } + auto bias_param = reinterpret_cast(op_param); + SetOpParam(bias_param); + set_primitive(schema::PrimitiveType_BiasAdd); +} + +std::vector BiasAddM::construct(const std::vector &inputs) {  // Delegates to Node::construct, then passes the node of inputs.at(C1NUM) to AddLearn(). + auto x = Node::construct(inputs); + AddLearn(inputs.at(C1NUM)->node()); + return x; +} + +void BiasAddM::SetLearn() { AddLearn(input(C1NUM)->node()); } + +std::vector BiasAddM::Grad(EXPR *yt) {  // Returns {gradient for input 0, output of a BiasAddGradM applied to it}; yt is first transposed with TransposeCHW2HWC when it is 4-D and not NHWC. Returns {} if the grad node cannot be allocated. + auto in = yt; + if (yt->format() != NHWC && yt->dims().size() == C4NUM) { + in = TransposeM::TransposeCHW2HWC(yt); + in->node()->set_name(kGradName + "/" + name() + "/" + in->node()->name()); + PushOp(in->node()); + } + auto grad_node = new (std::nothrow) BiasAddGradM(*this); + if (grad_node == nullptr) { + MS_LOG(ERROR) << "Cannot allocate Bias Grad"; + return {}; + } + PushOp(grad_node); + auto bias_grad = (*grad_node)({in}); + return {in, bias_grad.front()}; +} +int BiasAddM::UnPopulate(const std::unique_ptr &cnode) {  // Stores a new schema::BiasAddT (format fixed to KHWC) in cnode->primitive->value.value; RET_ERROR if the allocation fails. + auto prim = new (std::nothrow) schema::BiasAddT; + if (prim == nullptr) { + MS_LOG(ERROR) << "cannot allocate prim"; + return RET_ERROR; + } + prim->format = static_cast(KHWC); +
cnode->primitive->value.value = prim; + return RET_OK; +} + +BiasAddGradM::BiasAddGradM(const BiasAddM &bias) {  // Builds the BiasAddGrad node with a zeroed OpParameter; `bias` is not read by this constructor. + auto op_param = calloc(1, sizeof(OpParameter)); + if (op_param == nullptr) {  // NOTE(review): unlike the other constructors in this patch there is no early return, so SetOpParam() below still runs with nullptr - confirm intended. + MS_LOG(ERROR) << "Cannot allocate op_param"; + } + SetOpParam(op_param); + set_primitive(schema::PrimitiveType_BiasAddGrad); +} + +int BiasAddGradM::UnPopulate(const std::unique_ptr &cnode) {  // Stores a new schema::BiasAddGradT in cnode->primitive->value.value; RET_ERROR if the allocation fails. + auto prim = new (std::nothrow) schema::BiasAddGradT; + if (prim == nullptr) { + MS_LOG(ERROR) << "Cannot allocate primitive"; + return RET_ERROR; + } + cnode->primitive->value.value = prim; + return RET_OK; +} + +static ImportReg reg(schema::PrimitiveType_BiasAdd, ReturnNode); + +namespace NN {} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/expression/ops/biasadd.h b/mindspore/lite/src/expression/ops/biasadd.h new file mode 100644 index 00000000000..bb23e0ff220 --- /dev/null +++ b/mindspore/lite/src/expression/ops/biasadd.h @@ -0,0 +1,44 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_BIASADD_H_ +#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_BIASADD_H_ + +#include +#include +#include "src/expression/net.h" +#include "inner/model_generated.h" + +namespace mindspore { +namespace lite { +class BiasAddM : public Node { + public: + BiasAddM() = default; + explicit BiasAddM(Format data_format); + std::vector construct(const std::vector &inputs) override; + std::vector Grad(EXPR *yt) override; + int UnPopulate(const std::unique_ptr &cnode) override; + void SetLearn() override; +}; + +class BiasAddGradM : public Node { + public: + explicit BiasAddGradM(const BiasAddM &bias); + int UnPopulate(const std::unique_ptr &cnode) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_BIASADD_H_ diff --git a/mindspore/lite/src/expression/ops/conv.cc b/mindspore/lite/src/expression/ops/conv.cc new file mode 100644 index 00000000000..aedbc34cf77 --- /dev/null +++ b/mindspore/lite/src/expression/ops/conv.cc @@ -0,0 +1,241 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "src/expression/ops/conv.h" +#include +#include "src/expression/ops/biasadd.h" +#include "src/expression/ops/depend.h" +#include "src/expression/ops/transpose.h" +#include "nnacl/conv_parameter.h" +#include "inner/model_generated.h" +#include "src/expression/import.h" +#include "src/expression/ops.h" +#include "src/cxx_api/expression/node_impl.h" + +namespace mindspore { +namespace lite { +ConvM::ConvM(const ConvConfig &cfg) : Node() { + auto op_param = calloc(1, sizeof(ConvParameter)); + if (op_param == nullptr) { + MS_LOG(ERROR) << " cannot allocate ConvParameter"; + return; + } + SetOpParam(op_param); + ConvParameter *conv_param = reinterpret_cast(OpParam()); + conv_param->input_channel_ = cfg.in_channel_; + conv_param->output_channel_ = cfg.out_channel_; + conv_param->kernel_h_ = cfg.kernel_size_[0]; + conv_param->kernel_w_ = cfg.kernel_size_[1]; + conv_param->stride_h_ = cfg.stride_[0]; + conv_param->stride_w_ = cfg.stride_[1]; + auto pad_mode = GetMode(cfg.pad_mode_); + if (pad_mode == -1) { + MS_LOG(ERROR) << "bad pad mode"; + return; + } + conv_param->pad_mode_ = static_cast(pad_mode); + conv_param->pad_u_ = cfg.padding_[C0NUM]; + conv_param->pad_d_ = cfg.padding_[C1NUM]; + conv_param->pad_l_ = cfg.padding_[C2NUM]; + conv_param->pad_r_ = cfg.padding_[C3NUM]; + conv_param->dilation_h_ = cfg.dilation_[C0NUM]; + conv_param->dilation_w_ = cfg.dilation_[C1NUM]; + conv_param->group_ = cfg.group_; + conv_param->out_format_ = NHWC; + conv_param->act_type_ = ActType_No; + expr()->SetSize(C2NUM); + set_primitive(schema::PrimitiveType_Conv2DFusion); + set_name(UniqueName("Conv")); + Param::Mode mode = Param::String2Enum(cfg.weight_init_); + std::vector dims = {conv_param->output_channel_, conv_param->kernel_h_, conv_param->kernel_w_, + conv_param->input_channel_ / conv_param->group_}; + auto w = CreateWeights(dims, kNumberTypeFloat32, KHWC, mode, "weights"); + expr()->set_params(C1NUM, w); + if (cfg.has_bias) { + bias_ = new (std::nothrow) 
BiasAddM(KHWC); + if (bias_ == nullptr) { + MS_LOG(ERROR) << "Cannot allocate bias"; + return; + } + bias_->update_name(name()); + std::vector dim_bias = {conv_param->output_channel_}; + wbias_ = CreateWeights(dim_bias, kNumberTypeFloat32, KHWC, Param::Mode::ZEROS, "weights"); + AddLearn(wbias_->node()); + PushOp(bias_); + } + SetLearn(); +} + +std::vector ConvM::construct(const std::vector &inputs) { + auto in = inputs; + auto x = in.front(); + if (x->format() != NHWC && x->dims().size() == C4NUM) { + x = TransposeM::TransposeCHW2HWC(x); + x->node()->set_name(name() + "/" + x->node()->name()); + PushOp(x->node()); + in.at(0) = x; + } + auto y = Node::construct(in); + if (bias_ != nullptr) { + y = (*bias_)({y.front(), wbias_}); + } + return y; +} + +void ConvM::SetLearn() { AddLearn(input(C1NUM)->node()); } + +int ConvM::GetMode(std::string mode) { + const std::vector list = {"pad", "same", "valid"}; + auto itr = std::find(list.begin(), list.end(), mode); + if (itr == list.end()) { + MS_LOG(ERROR) << "illegal mode" << mode; + return -1; + } + return std::distance(list.begin(), itr); +} + +std::vector ConvM::Grad(EXPR *yt) { + // Generate Input Grad + EXPR *in = yt; + if (yt->format() != NHWC && yt->dims().size() == C4NUM) { + in = TransposeM::TransposeCHW2HWC(yt); + in->node()->set_name(kGradName + "/" + name() + "/" + in->node()->name()); + PushOp(in->node()); + } + auto inGrad = new (std::nothrow) ConvInputGradM(this); + if (inGrad == nullptr) { + MS_LOG(ERROR) << "Cannot allocate convolution input grad"; + return {}; + } + PushOp(inGrad); + auto ig = (*inGrad)({in, input(1), inGrad->input(2)}); + // Execution Control Flow ! 
+ auto depend = NN::Depend(); + if (depend == nullptr) { + MS_LOG(ERROR) << "Cannot allocate depend"; + return {}; + } + PushOp(depend); + depend->update_name(name()); + auto de = (*depend)({inGrad->expr()}); + // Generate Filter Grad + auto filterGrad = new (std::nothrow) ConvFilterGradM(this); + if (filterGrad == nullptr) { + MS_LOG(ERROR) << "Cannot allocate convolution filter grad"; + return {}; + } + PushOp(filterGrad); + filterGrad->update_name(name()); + auto fg = (*filterGrad)({in, input(0), filterGrad->input(2), de[0]}); + std::vector res = {ig[0], fg[0]}; + return res; +} + +int ConvM::UnPopulate(const std::unique_ptr &cnode) { + auto conv_param = reinterpret_cast(OpParam()); + auto prim = new (std::nothrow) schema::Conv2DFusionT; + if (prim == nullptr) { + MS_LOG(ERROR) << "Cannot allocate primitive"; + return RET_ERROR; + } + prim->activation_type = static_cast(conv_param->act_type_); + prim->format = static_cast(conv_param->out_format_); + prim->stride = {conv_param->stride_h_, conv_param->stride_w_}; + prim->kernel_size = {conv_param->kernel_h_, conv_param->kernel_w_}; + prim->dilation = {conv_param->dilation_h_, conv_param->dilation_w_}; + prim->out_channel = conv_param->output_channel_; + prim->in_channel = conv_param->input_channel_; + prim->group = conv_param->group_; + prim->pad_mode = static_cast(conv_param->pad_mode_); + prim->pad_list = {conv_param->pad_u_, conv_param->pad_d_, conv_param->pad_l_, conv_param->pad_r_}; + prim->mode = 1; + cnode->primitive->value.value = prim; + return RET_OK; +} + +ConvInputGradM::ConvInputGradM(ConvM *conv_node) : Node() { + CloneOpParam(conv_node->OpParam()); + set_primitive(schema::PrimitiveType_Conv2DBackpropInputFusion); + set_name(kGradName + "/conv2DBackpropInput"); + expr()->SetSize(C3NUM); + auto const x = conv_node->input(0); + CreateConstTensor(C2NUM, {static_cast(x->dims().size())}, kNumberTypeInt32, KHWC, "shape", x->dims().data()); +} + +int ConvInputGradM::UnPopulate(const std::unique_ptr &cnode) 
{ + auto conv_param = reinterpret_cast(OpParam()); + auto prim = new (std::nothrow) schema::Conv2DBackpropInputFusionT; + if (prim == nullptr) { + MS_LOG(ERROR) << "Cannot allocate primitive"; + return RET_ERROR; + } + prim->activation_type = static_cast(conv_param->act_type_); + prim->format = static_cast(conv_param->out_format_); + prim->stride = {conv_param->stride_h_, conv_param->stride_w_}; + prim->kernel_size = {conv_param->kernel_h_, conv_param->kernel_w_}; + prim->dilation = {conv_param->dilation_h_, conv_param->dilation_w_}; + prim->out_channel = conv_param->output_channel_; + prim->in_channel = conv_param->input_channel_; + prim->group = conv_param->group_; + prim->pad_mode = static_cast(conv_param->pad_mode_); + prim->pad_list = {conv_param->pad_u_, conv_param->pad_d_, conv_param->pad_l_, conv_param->pad_r_}; + cnode->primitive->value.value = prim; + return RET_OK; +} + +ConvFilterGradM::ConvFilterGradM(ConvM *conv_node) : Node() { + CloneOpParam(conv_node->OpParam()); + set_primitive(schema::PrimitiveType_Conv2DBackpropFilterFusion); + set_name(kGradName + "/conv2DBackpropFilter"); + expr()->SetSize(C4NUM); + auto w = conv_node->input(1); + CreateConstTensor(C2NUM, {static_cast(w->dims().size())}, kNumberTypeInt32, KHWC, "shape", w->dims().data()); +} +int ConvFilterGradM::UnPopulate(const std::unique_ptr &cnode) { + auto conv_param = reinterpret_cast(OpParam()); + auto prim = new (std::nothrow) schema::Conv2DBackpropFilterFusionT; + if (prim == nullptr) { + MS_LOG(ERROR) << "Cannot allocate primitive"; + return RET_ERROR; + } + prim->activation_type = static_cast(conv_param->act_type_); + prim->format = static_cast(conv_param->out_format_); + prim->stride = {conv_param->stride_h_, conv_param->stride_w_}; + prim->kernel_size = {conv_param->kernel_h_, conv_param->kernel_w_}; + prim->dilation = {conv_param->dilation_h_, conv_param->dilation_w_}; + prim->out_channel = conv_param->output_channel_; + prim->in_channel = conv_param->input_channel_; + 
prim->group = conv_param->group_; + prim->pad_mode = static_cast(conv_param->pad_mode_); + prim->pad_list = {conv_param->pad_u_, conv_param->pad_d_, conv_param->pad_l_, conv_param->pad_r_}; + cnode->primitive->value.value = prim; + return RET_OK; +} + +static ImportReg reg(schema::PrimitiveType_Conv2DFusion, ReturnNode); + +namespace NN { +Node *Conv2D(const ConvConfig &cfg) { + auto c = new (std::nothrow) ConvM(cfg); + if (c == nullptr) { + MS_LOG(ERROR) << "Cannot allocate Convolution object"; + return nullptr; + } + return c; +} +} // namespace NN +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/expression/ops/conv.h b/mindspore/lite/src/expression/ops/conv.h new file mode 100644 index 00000000000..32fc6632255 --- /dev/null +++ b/mindspore/lite/src/expression/ops/conv.h @@ -0,0 +1,58 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_CONV_H_ +#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_CONV_H_ + +#include +#include +#include +#include "src/expression/cfg.h" +#include "src/expression/node.h" + +namespace mindspore { +namespace lite { +class ConvM : public Node { + public: + ConvM() = default; + explicit ConvM(const ConvConfig &cfg); + std::vector construct(const std::vector &inputs) override; + Param *weight() override { return input(1)->node()->data(); } + std::vector Grad(EXPR *expr) override; + int UnPopulate(const std::unique_ptr &cnode) override; + void SetLearn() override; + + private: + int GetMode(std::string mode); + Node *bias_ = nullptr; + EXPR *wbias_ = nullptr; +}; + +class ConvInputGradM : public Node { + public: + explicit ConvInputGradM(ConvM *conv_node); + int UnPopulate(const std::unique_ptr &cnode) override; +}; + +class ConvFilterGradM : public Node { + public: + explicit ConvFilterGradM(ConvM *conv_node); + int UnPopulate(const std::unique_ptr &cnode) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_CONV_H_ diff --git a/mindspore/lite/src/expression/ops/dense.cc b/mindspore/lite/src/expression/ops/dense.cc new file mode 100644 index 00000000000..02624a708e3 --- /dev/null +++ b/mindspore/lite/src/expression/ops/dense.cc @@ -0,0 +1,151 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "src/expression/ops/dense.h" +#include +#include "include/api/cfg.h" +#include "src/expression/ops/biasadd.h" +#include "src/expression/ops/depend.h" +#include "src/expression/ops.h" +#include "nnacl/matmul_parameter.h" +#include "src/expression/import.h" +#include "inner/model_generated.h" +#include "src/cxx_api/expression/node_impl.h" + +namespace mindspore { +namespace lite { +DenseM::DenseM(const DenseConfig &cfg) : Node() { + auto op_param = calloc(1, sizeof(MatMulParameter)); + if (op_param == nullptr) { + MS_LOG(ERROR) << " cannot allocate MatMulParameter"; + return; + } + set_name(UniqueName("Dense")); + SetOpParam(op_param); + expr()->SetSize(C2NUM); + set_primitive(schema::PrimitiveType_MatMulFusion); + auto param = reinterpret_cast(opParam_.get()); + param->row_ = cfg.out_channels_; + param->col_ = cfg.in_channels_; + param->a_transpose_ = false; + param->b_transpose_ = true; + std::vector dims = {param->row_, param->col_}; + auto w = Node::CreateWeights(dims, kNumberTypeFloat32, KHWC, Param::Mode::NORMAL, "weights"); + expr()->set_params(C1NUM, w); + if (cfg.has_bias_) { + wbias_ = CreateWeights({cfg.out_channels_}, kNumberTypeFloat32, KHWC, Param::Mode::ZEROS, "bias_weights"); + bias_ = new (std::nothrow) BiasAddM(KHWC); + if (bias_ == nullptr) { + MS_LOG(ERROR) << "Cannot allocate bias"; + return; + } + bias_->update_name(name()); + AddLearn(wbias_->node()); + PushOp(bias_); + } + SetLearn(); +} + +std::vector DenseM::construct(const std::vector &inputs) { + auto x = Node::construct(inputs); + if (bias_ != nullptr) { + x = (*bias_)({x.front(), wbias_}); + } + return x; +} + +std::vector DenseM::Grad(EXPR *yt) { + auto src_param = reinterpret_cast(opParam_.get()); + bool ta = src_param->a_transpose_; + bool tb = src_param->b_transpose_; + + // dx grad op + auto dxGrad = new (std::nothrow) DenseM(); + if (dxGrad == nullptr) { + MS_LOG(ERROR) << "Cannot allocate dxGrad "; + return {}; + } + PushOp(dxGrad); + 
dxGrad->CloneOpParam(opParam_); + dxGrad->set_primitive(schema::PrimitiveType_MatMulFusion); + auto dxGradParam = reinterpret_cast(dxGrad->OpParam()); + dxGradParam->a_transpose_ = (ta && tb); + dxGradParam->b_transpose_ = (ta || !tb); + dxGrad->set_name(name() + kGradName + "/dxGrad"); + EXPR *dx = nullptr; + if (ta) { + dx = (*dxGrad)({input(1), yt}).front(); + } else { + dx = (*dxGrad)({yt, input(1)}).front(); + } + // Control execution flow + auto depend = NN::Depend(); + if (depend == nullptr) { + MS_LOG(ERROR) << "Cannot allocate depend "; + return {}; + } + PushOp(depend); + auto de = (*depend)({dxGrad->expr()}).front(); + + // dw grad op + auto dwGrad = new (std::nothrow) DenseM(); + if (dwGrad == nullptr) { + MS_LOG(ERROR) << "Cannot allocate dwGrad "; + return {}; + } + PushOp(dwGrad); + dwGrad->CloneOpParam(opParam_); + dwGrad->set_primitive(schema::PrimitiveType_MatMulFusion); + auto dwGradParam = reinterpret_cast(dwGrad->OpParam()); + dwGradParam->a_transpose_ = (!ta || tb); + dwGradParam->b_transpose_ = ta && tb; + dwGrad->set_name(name() + kGradName + "/dwGrad"); + EXPR *dw = nullptr; + if (tb) { + dw = (*dwGrad)({yt, input(0), de}).front(); + } else { + dw = (*dwGrad)({input(0), yt, de}).front(); + } + return {dx, dw}; +} +int DenseM::UnPopulate(const std::unique_ptr &cnode) { + auto dense_param = reinterpret_cast(OpParam()); + auto prim = new (std::nothrow) schema::MatMulFusionT; + if (prim == nullptr) { + MS_LOG(ERROR) << "Cannot allocate primitive"; + return RET_ERROR; + } + prim->transpose_a = dense_param->a_transpose_; + prim->transpose_b = dense_param->b_transpose_; + cnode->primitive->value.value = prim; + return RET_OK; +} + +void DenseM::SetLearn() { AddLearn(input(C1NUM)->node()); } + +static ImportReg reg(schema::PrimitiveType_MatMulFusion, ReturnNode); + +namespace NN { +Node *Dense(const DenseConfig &cfg) { + auto l = new (std::nothrow) DenseM(cfg); + if (l == nullptr) { + MS_LOG(ERROR) << "Cannot allocate Dense object"; + } + return l; 
+} +} // namespace NN +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/expression/ops/dense.h b/mindspore/lite/src/expression/ops/dense.h new file mode 100644 index 00000000000..10734336afc --- /dev/null +++ b/mindspore/lite/src/expression/ops/dense.h @@ -0,0 +1,44 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_DENSE_H_ +#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_DENSE_H_ + +#include +#include +#include "src/expression/node.h" +#include "src/expression/cfg.h" +#include "inner/model_generated.h" + +namespace mindspore { +namespace lite { +class DenseM : public Node { + public: + DenseM() = default; + explicit DenseM(const DenseConfig &cfg); + std::vector construct(const std::vector &inputs) override; + std::vector Grad(EXPR *expr) override; + int UnPopulate(const std::unique_ptr &cnode) override; + void SetLearn() override; + + private: + Param *weight() override { return input(1)->node()->data(); } + Node *bias_{nullptr}; + EXPR *wbias_{nullptr}; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_DENSE_H_ diff --git a/mindspore/lite/src/expression/ops/depend.cc b/mindspore/lite/src/expression/ops/depend.cc new file mode 100644 index 00000000000..c6aee1532e1 --- /dev/null +++ b/mindspore/lite/src/expression/ops/depend.cc @@ -0,0 +1,43 @@ +/** + * Copyright 2022 Huawei Technologies 
Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/expression/ops/depend.h" +#include "inner/model_generated.h" + +namespace mindspore { +namespace lite { +DependM::DependM() : Node() { /* installs a zero-filled OpParameter, the Depend primitive type and the name returned by UniqueName("Depend") */ + auto param = calloc(1, sizeof(OpParameter)); + if (param == nullptr) { + MS_LOG(ERROR) << "Cannot allocate parameter"; + return; /* early exit: SetOpParam, set_primitive and set_name below are not reached */ + } + SetOpParam(param); + set_primitive(schema::PrimitiveType_Depend); + set_name(UniqueName("Depend")); +} +namespace NN { +Node *Depend() { /* factory: returns a heap-allocated DependM, or nullptr after logging when allocation fails; NOTE(review): who releases the pointer is not visible here */ + auto d = new (std::nothrow) DependM(); + if (d == nullptr) { + MS_LOG(ERROR) << "Cannot allocate depend object"; + return nullptr; + } + return d; +} +} // namespace NN +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/expression/ops/depend.h b/mindspore/lite/src/expression/ops/depend.h new file mode 100644 index 00000000000..0995e6644dd --- /dev/null +++ b/mindspore/lite/src/expression/ops/depend.h @@ -0,0 +1,37 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_DEPEND_H_ +#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_DEPEND_H_ + +#include +#include "src/expression/node.h" +#include "src/expression/ops.h" +#include "inner/model_generated.h" + +namespace mindspore { +namespace lite { +class DependM : public Node { /* Node whose constructor (depend.cc) sets the Depend primitive type */ + public: + DependM(); +}; + +namespace NN { +Node *Depend(); /* factory defined in depend.cc: new DependM, or nullptr when allocation fails */ +} // namespace NN +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_DEPEND_H_ diff --git a/mindspore/lite/src/expression/ops/dropout.cc b/mindspore/lite/src/expression/ops/dropout.cc new file mode 100644 index 00000000000..e1c33638072 --- /dev/null +++ b/mindspore/lite/src/expression/ops/dropout.cc @@ -0,0 +1,91 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#include "src/expression/ops/dropout.h" +#include +#include "nnacl/fp32_grad/dropout_parameter.h" +#include "inner/model_generated.h" +#include "src/expression/import.h" +#include "src/expression/ops.h" +#include "src/cxx_api/expression/node_impl.h" + +namespace mindspore { +namespace lite { +DropOutM::DropOutM(float ratio) { /* stores ratio in a zero-filled DropoutParameter and tags the node as a Dropout primitive */ + auto param = reinterpret_cast(calloc(1, sizeof(DropoutParameter))); + if (param == nullptr) { + MS_LOG(ERROR) << "Cannot allocate parameter"; + return; /* early exit: the node keeps no op parameter, primitive type or name */ + } + param->ratio_ = ratio; + SetOpParam(param); + set_primitive(schema::PrimitiveType_Dropout); + set_name(UniqueName("DropOut")); +} + +std::vector DropOutM::Grad(EXPR *yt) { /* builds a DropOutGradM from this node and applies it to {yt, expr()}; empty result when allocation fails */ + auto inGrad = new (std::nothrow) DropOutGradM(this); + if (inGrad == nullptr) { + MS_LOG(ERROR) << "Cannot allocate drop grad"; + return {}; + } + return (*inGrad)({yt, expr()}); /* NOTE(review): unlike the Grad() of the pooling, reduce and reshape ops in this patch, inGrad is never passed to PushOp() - confirm that is intended */ +} + +int DropOutM::UnPopulate(const std::unique_ptr &cnode) { /* exports this node as a schema::DropoutT stored in cnode->primitive->value.value */ + auto param = reinterpret_cast(OpParam()); + auto prim = new (std::nothrow) schema::DropoutT; + if (prim == nullptr) { + MS_LOG(ERROR) << "Cannot allocate primitive"; + return RET_ERROR; + } + prim->keep_prob = param->ratio_; /* ratio_ is written to keep_prob unchanged */ + cnode->primitive->value.value = prim; + return RET_OK; +} + +DropOutGradM::DropOutGradM(DropOutM *node) { /* clones the forward node's op parameter; the name is kGradName followed by "/DropOutGrad" */ + CloneOpParam(node->OpParam()); + set_primitive(schema::PrimitiveType_DropoutGrad); + set_name(kGradName + "/DropOutGrad"); +} + +int DropOutGradM::UnPopulate(const std::unique_ptr &cnode) { /* same export as DropOutM::UnPopulate but with schema::DropoutGradT */ + auto param = reinterpret_cast(OpParam()); + auto prim = new (std::nothrow) schema::DropoutGradT; + if (prim == nullptr) { + MS_LOG(ERROR) << "Cannot allocate primitive"; + return RET_ERROR; + } + prim->keep_prob = param->ratio_; + cnode->primitive->value.value = prim; + return RET_OK; +} + +static ImportReg reg(schema::PrimitiveType_Dropout, ReturnNode); + +namespace NN { +Node *DropOut(float ratio) { /* factory: heap-allocated DropOutM(ratio), or nullptr after logging when allocation fails */ + auto node = new (std::nothrow) DropOutM(ratio); + if (node == nullptr) { + MS_LOG(ERROR) << "Cannot allocate dropout node"; + return nullptr; + } +
return node; +} +} // namespace NN +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/expression/ops/dropout.h b/mindspore/lite/src/expression/ops/dropout.h new file mode 100644 index 00000000000..dce87f188bd --- /dev/null +++ b/mindspore/lite/src/expression/ops/dropout.h @@ -0,0 +1,42 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_DROPOUT_H_ +#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_DROPOUT_H_ + +#include +#include +#include "src/expression/node.h" + +namespace mindspore { +namespace lite { +class DropOutM : public Node { /* forward dropout node; the constructor argument ends up in DropoutParameter::ratio_ (see dropout.cc) */ + public: + DropOutM() = default; + explicit DropOutM(float ratio); + std::vector Grad(EXPR *expr) override; + int UnPopulate(const std::unique_ptr &cnode) override; +}; + +class DropOutGradM : public Node { /* gradient node built from a DropOutM; dropout.cc clones that node's op parameter */ + public: + explicit DropOutGradM(DropOutM *node); + int UnPopulate(const std::unique_ptr &cnode) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_DROPOUT_H_ diff --git a/mindspore/lite/src/expression/ops/flatten.cc b/mindspore/lite/src/expression/ops/flatten.cc new file mode 100644 index 00000000000..846a350f95b --- /dev/null +++ b/mindspore/lite/src/expression/ops/flatten.cc @@ -0,0 +1,71 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this
file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/expression/ops/flatten.h" +#include +#include "inner/model_generated.h" +#include "src/expression/import.h" +#include "src/expression/ops.h" + +#include "src/cxx_api/expression/node_impl.h" +#include "nnacl/op_base.h" + +namespace mindspore { +namespace lite { +FlattenM::FlattenM(int dummy) { + auto param = reinterpret_cast(calloc(C1NUM, sizeof(OpParameter))); + if (param == nullptr) { + MS_LOG(ERROR) << "Cannot allocate parameter"; + return; + } + SetOpParam(param); + set_primitive(schema::PrimitiveType_Flatten); + set_name(UniqueName("Flatten")); +} + +std::vector FlattenM::construct(const std::vector &inputs) { + auto in = inputs; + auto y = Node::construct(in); + return y; +} + +std::vector FlattenM::Grad(EXPR *yt) { + auto shape_of_x = input(0)->dims(); + auto reshape = NN::Reshape(shape_of_x); + PushOp(reshape); + return (*reshape)({yt}); +} + +int FlattenM::UnPopulate(const std::unique_ptr &cnode) { + auto prim = new (std::nothrow) schema::DropoutT; + if (prim == nullptr) { + MS_LOG(ERROR) << "Cannot allocate primitive"; + return RET_ERROR; + } + cnode->primitive->value.value = prim; + return RET_OK; +} + +static ImportReg reg(schema::PrimitiveType_Flatten, ReturnNode); + +namespace NN { +Node *Flatten() { + auto node = new (std::nothrow) FlattenM(0); + return node; +} +} // namespace NN +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/expression/ops/flatten.h b/mindspore/lite/src/expression/ops/flatten.h new file mode 100644 index 
00000000000..0be7d5bbd03 --- /dev/null +++ b/mindspore/lite/src/expression/ops/flatten.h @@ -0,0 +1,36 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_FLATTEN_H_ +#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_FLATTEN_H_ + +#include +#include +#include "src/expression/node.h" + +namespace mindspore { +namespace lite { +class FlattenM : public Node { /* Node subclass for the Flatten primitive; definitions are in flatten.cc */ + public: + FlattenM() = default; + explicit FlattenM(int dummy); /* flatten.cc never reads dummy */ + std::vector Grad(EXPR *expr) override; + int UnPopulate(const std::unique_ptr &cnode) override; + std::vector construct(const std::vector &inputs) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_FLATTEN_H_ diff --git a/mindspore/lite/src/expression/ops/pooling.cc b/mindspore/lite/src/expression/ops/pooling.cc new file mode 100644 index 00000000000..ad23ee9d918 --- /dev/null +++ b/mindspore/lite/src/expression/ops/pooling.cc @@ -0,0 +1,215 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/expression/ops/pooling.h" +#include "src/expression/ops.h" +#include "src/expression/import.h" +#include "src/cxx_api/expression/node_impl.h" +#include "src/expression/ops/transpose.h" + +namespace mindspore { +namespace lite { +PoolingM::PoolingM(const PoolingConfig &cfg) : Node() { + auto op_param = calloc(1, sizeof(PoolingParameter)); + if (op_param == nullptr) { + MS_LOG(ERROR) << " cannot allocate PoolingParameter"; + return; + } + SetOpParam(op_param); + PoolingParameter *pool_param = reinterpret_cast(OpParam()); + + pool_param->window_h_ = cfg.kernel_size_[0]; + pool_param->window_w_ = cfg.kernel_size_[1]; + pool_param->stride_h_ = cfg.stride_[0]; + pool_param->stride_w_ = cfg.stride_[1]; + auto pad_mode = GetMode(cfg.pad_mode_); + if (pad_mode == -1) { + MS_LOG(ERROR) << "bad pad mode"; + return; + } + pool_param->pad_mode_ = static_cast(pad_mode + Pad_pad); + pool_param->round_mode_ = RoundMode_Floor; + pool_param->act_type_ = ActType_No; +} + +std::vector PoolingM::construct(const std::vector &inputs) { + auto in = inputs; + auto x = in.front(); + if (x->format() != NHWC && x->dims().size() == C4NUM) { + x = TransposeM::TransposeCHW2HWC(x); + x->node()->set_name(name() + "/" + x->node()->name()); + PushOp(x->node()); + in.at(0) = x; + } + auto y = Node::construct(in); + return y; +} + +int PoolingM::GetMode(std::string mode) { + const std::vector list = {"same", "valid"}; + auto itr = std::find(list.begin(), list.end(), mode); + if (itr == list.end()) { + MS_LOG(ERROR) << "illegal mode" << mode; + return -1; + } + 
return std::distance(list.begin(), itr); +} + +void PoolingM::UpdateRoundMode(const PoolingParameter *param, schema::RoundMode *round_mode) { + switch (param->round_mode_) { + case RoundMode_Floor: + *round_mode = schema::RoundMode_FLOOR; + break; + case RoundMode_Ceil: + *round_mode = schema::RoundMode_CEIL; + break; + default: + *round_mode = schema::RoundMode_FLOOR; + break; + } +} + +template +int PoolingM::UnPopulate(const std::unique_ptr &cnode) { + auto param = reinterpret_cast(OpParam()); + auto prim = new (std::nothrow) T; + if (prim == nullptr) { + MS_LOG(ERROR) << "Cannot allocate primitive"; + return RET_ERROR; + } + prim->kernel_size = {param->window_h_, param->window_w_}; + prim->strides = {param->stride_h_, param->stride_w_}; + prim->pad = {param->pad_u_, param->pad_d_, param->pad_l_, param->pad_r_}; + prim->pad_mode = static_cast(param->pad_mode_); + UpdateRoundMode(param, &prim->round_mode); + prim->global = param->global_; + prim->activation_type = schema::ActivationType_NO_ACTIVATION; + cnode->primitive->value.value = prim; + return RET_OK; +} + +template +int PoolingM::UnPopulateGrad(const std::unique_ptr &cnode) { + auto param = reinterpret_cast(OpParam()); + auto prim = new (std::nothrow) T; + if (prim == nullptr) { + MS_LOG(ERROR) << "Cannot allocate primitive"; + return RET_ERROR; + } + prim->kernel_size = {param->window_h_, param->window_w_}; + prim->strides = {param->stride_h_, param->stride_w_}; + prim->pad_mode = static_cast(param->pad_mode_); + cnode->primitive->value.value = prim; + return RET_OK; +} + +// Max pooling Definition +MaxPoolM::MaxPoolM(const PoolingConfig &cfg) : PoolingM(cfg) { + auto param = reinterpret_cast(OpParam()); + param->pool_mode_ = PoolMode_MaxPool; + set_primitive(schema::PrimitiveType_MaxPoolFusion); + set_name(UniqueName("MaxPool")); +} + +int MaxPoolM::UnPopulate(const std::unique_ptr &cnode) { + return PoolingM::UnPopulate(cnode); +} + +std::vector MaxPoolM::Grad(EXPR *yt) { + auto in = yt; + if 
(yt->format() != NHWC && yt->dims().size() == C4NUM) { + in = TransposeM::TransposeCHW2HWC(yt); + in->node()->set_name(kGradName + "/" + name() + "/" + in->node()->name()); + PushOp(in->node()); + } + auto pool_grad = new (std::nothrow) MaxPoolGradM(this); + PushOp(pool_grad); + return (*pool_grad)({input(0), output(0), in}); +} + +static ImportReg maxPoolReg(schema::PrimitiveType_MaxPoolFusion, ReturnNode); + +// Avg pooling Definition +AvgPoolM::AvgPoolM(const PoolingConfig &cfg) : PoolingM(cfg) { + auto param = reinterpret_cast(OpParam()); + param->pool_mode_ = PoolMode_AvgPool; + set_primitive(schema::PrimitiveType_AvgPoolFusion); + set_name(UniqueName("AvgPool")); +} + +int AvgPoolM::UnPopulate(const std::unique_ptr &cnode) { + return PoolingM::UnPopulate(cnode); +} + +std::vector AvgPoolM::Grad(EXPR *yt) { + auto in = yt; + if (yt->format() != NHWC && yt->dims().size() == C4NUM) { + in = TransposeM::TransposeCHW2HWC(yt); + in->node()->set_name(kGradName + "/" + name() + "/" + in->node()->name()); + PushOp(in->node()); + } + auto pool_grad = new (std::nothrow) AvgPoolGradM(this); + PushOp(pool_grad); + return (*pool_grad)({input(0), output(0), in}); +} + +static ImportReg avgPoolReg(schema::PrimitiveType_AvgPoolFusion, ReturnNode); + +// Max Pool Grad Definition +MaxPoolGradM::MaxPoolGradM(MaxPoolM *node) { + Node(); + CloneOpParam(node->OpParam()); + set_primitive(schema::PrimitiveType_MaxPoolGrad); + set_name(kGradName + "/" + node->name() + "/MaxPoolGrad"); +} + +int MaxPoolGradM::UnPopulate(const std::unique_ptr &cnode) { + return PoolingM::UnPopulateGrad(cnode); +} + +// Avg Pool Grad Definition +AvgPoolGradM::AvgPoolGradM(AvgPoolM *node) { + Node(); + CloneOpParam(node->OpParam()); + set_primitive(schema::PrimitiveType_AvgPoolGrad); + set_name(kGradName + "/" + node->name() + "/AvgPoolGrad"); +} + +int AvgPoolGradM::UnPopulate(const std::unique_ptr &cnode) { + return PoolingM::UnPopulateGrad(cnode); +} + +namespace NN { +Node *MaxPool2D(const 
PoolingConfig &cfg) { + auto c = new (std::nothrow) MaxPoolM(cfg); + if (c == nullptr) { + MS_LOG(ERROR) << "Cannot allocate max pool object"; + return nullptr; + } + return c; +} + +Node *AvgPool2D(const PoolingConfig &cfg) { + auto c = new (std::nothrow) AvgPoolM(cfg); + if (c == nullptr) { + MS_LOG(ERROR) << "Cannot allocate average pool object"; + return nullptr; + } + return c; +} +} // namespace NN +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/expression/ops/pooling.h b/mindspore/lite/src/expression/ops/pooling.h new file mode 100644 index 00000000000..881996a3ba3 --- /dev/null +++ b/mindspore/lite/src/expression/ops/pooling.h @@ -0,0 +1,74 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_POOLING_H_ +#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_POOLING_H_ + +#include +#include +#include +#include "src/expression/node.h" +#include "inner/model_generated.h" +#include "src/expression/cfg.h" +#include "nnacl/pooling_parameter.h" + +namespace mindspore { +namespace lite { +class PoolingM : public Node { /* shared base of the max/avg pooling nodes and their gradient nodes */ + public: + PoolingM() = default; + explicit PoolingM(const PoolingConfig &cfg); + template + int UnPopulate(const std::unique_ptr &cnode); /* non-virtual template helper called by the derived UnPopulate overrides in pooling.cc */ + template + int UnPopulateGrad(const std::unique_ptr &cnode); /* gradient counterpart of the helper above */ + std::vector construct(const std::vector &inputs) override; /* marked override like DenseM::construct and FlattenM::construct, which override the same Node virtual */ + + private: + void UpdateRoundMode(const PoolingParameter *param, enum schema::RoundMode *round_mode); + int GetMode(std::string mode); /* index of mode in {"same", "valid"} or -1, per pooling.cc */ +}; + +class MaxPoolM : public PoolingM { /* forward max pooling node */ + public: + MaxPoolM() = default; + explicit MaxPoolM(const PoolingConfig &cfg); + std::vector Grad(EXPR *expr) override; + int UnPopulate(const std::unique_ptr &cnode) override; +}; + +class AvgPoolM : public PoolingM { /* forward average pooling node */ + public: + AvgPoolM() = default; + explicit AvgPoolM(const PoolingConfig &cfg); + std::vector Grad(EXPR *expr) override; + int UnPopulate(const std::unique_ptr &cnode) override; +}; + +class MaxPoolGradM : public PoolingM { /* gradient node built from a MaxPoolM */ + public: + explicit MaxPoolGradM(MaxPoolM *node); + int UnPopulate(const std::unique_ptr &cnode) override; +}; + +class AvgPoolGradM : public PoolingM { /* gradient node built from an AvgPoolM */ + public: + explicit AvgPoolGradM(AvgPoolM *node); + int UnPopulate(const std::unique_ptr &cnode) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_POOLING_H_ diff --git a/mindspore/lite/src/expression/ops/reduce.cc b/mindspore/lite/src/expression/ops/reduce.cc new file mode 100644 index 00000000000..d7527b3bdbc --- /dev/null +++ b/mindspore/lite/src/expression/ops/reduce.cc @@ -0,0 +1,126 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this
file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/expression/ops/reduce.h" +#include +#include "src/expression/ops/tile.h" +#include "src/expression/ops/reshape.h" +#include "src/expression/ops/arithmetic.h" +#include "src/expression/ops.h" +#include "src/expression/ops_utils.h" +#include "src/expression/import.h" +#include "nnacl/reduce_parameter.h" +#include "src/cxx_api/expression/node_impl.h" + +namespace mindspore { +namespace lite { +ReduceM::ReduceM(schema::ReduceMode mode, bool keep_dims, const std::vector &axis) : Node() { + expr()->SetSize(C2NUM); + ReduceParameter *param = reinterpret_cast(calloc(1, sizeof(ReduceParameter))); + if (param == nullptr) { + MS_LOG(ERROR) << "Cannot allocate parameter"; + return; + } + param->mode_ = mode; + param->keep_dims_ = keep_dims; + param->reduce_to_end_ = false; + param->coeff = 1.f; + SetOpParam(param); + set_name(UniqueName("Reduce")); + set_primitive(schema::PrimitiveType_ReduceFusion); + Node::CreateConstTensor(C1NUM, {static_cast(axis.size())}, kNumberTypeInt32, KHWC, "axis", axis.data()); +} + +int ReduceM::UnPopulate(const std::unique_ptr &cnode) { + auto reduce_param = reinterpret_cast(OpParam()); + auto prim = new (std::nothrow) schema::ReduceFusionT; + if (prim == nullptr) { + MS_LOG(ERROR) << "Cannot allocate primitive"; + return RET_ERROR; + } + prim->keep_dims = reduce_param->keep_dims_; + prim->mode = static_cast(reduce_param->mode_); + prim->coeff = reduce_param->coeff; + prim->reduce_to_end = reduce_param->reduce_to_end_; + cnode->primitive->value.value = prim; + 
return RET_OK; +} + +std::vector ReduceM::Grad(EXPR *yt) { + auto shape_of_x = input(0)->dims(); + std::vector shape_of_axis; + + auto data = input(1)->node()->data()->data().data(); + int size = input(1)->dims().at(0); + auto int_data = reinterpret_cast(data); + for (int i = 0; i < size; i++) { + shape_of_axis.push_back(int_data[i]); + } + + // assume no dynamic shape + ShapeReduce reduce_shape; + auto output_shape_kept_dims = ShapeReduce()(shape_of_x, shape_of_axis); + auto tile_scaling = VectorDiv()(shape_of_x, output_shape_kept_dims); + auto reshape = NN::Reshape(output_shape_kept_dims); + PushOp(reshape); + reshape->set_name(name() + "/reshape"); + auto g = (*reshape)({yt}).front(); + auto tile = NN::Tile(tile_scaling); + PushOp(tile); + tile->set_name(name() + "/tile"); + auto sum_grad = (*tile)({g}).front(); + auto reduce_param = reinterpret_cast(OpParam()); + if (reduce_param->mode_ == schema::ReduceMode_ReduceSum) { + return {sum_grad}; + } else if (reduce_param->mode_ == schema::ReduceMode_ReduceMean) { + auto shape_of_y = output(0)->dims(); + auto shape_x_mul = std::accumulate(shape_of_x.begin(), shape_of_x.end(), 1, std::multiplies()); + auto shape_y_mul = std::accumulate(shape_of_y.begin(), shape_of_y.end(), 1, std::multiplies()); + auto div_shape = static_cast(shape_x_mul) / static_cast(shape_y_mul); + auto div_op = NN::Div(); + PushOp(div_op); + auto d = div_op->CreateConstTensor(C1NUM, {1}, kNumberTypeFloat32, KHWC, "div_shape", &div_shape); + auto dx = (*div_op)({sum_grad, d->expr()}); + return dx; + } else { + return {}; + } +} + +static ImportReg reg(schema::PrimitiveType_ReduceFusion, ReturnNode); + +namespace NN { +Node *ReduceSum(bool keep_dims, const std::vector &axis) { + auto node = new (std::nothrow) ReduceM(schema::ReduceMode_ReduceSum, keep_dims, axis); + if (node == nullptr) { + MS_LOG(ERROR) << "Cannot allocate reduce sum node"; + return nullptr; + } + node->set_name(Node::UniqueName("ReduceSum")); + return node; +} +Node 
*ReduceMean(bool keep_dims, const std::vector &axis) { + auto node = new (std::nothrow) ReduceM(schema::ReduceMode_ReduceMean, keep_dims, axis); + if (node == nullptr) { + MS_LOG(ERROR) << "Cannot allocate reduce mean node"; + return nullptr; + } + node->set_name(Node::UniqueName("ReduceMean")); + return node; +} +} // namespace NN +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/expression/ops/reduce.h b/mindspore/lite/src/expression/ops/reduce.h new file mode 100644 index 00000000000..1ca0e9216fb --- /dev/null +++ b/mindspore/lite/src/expression/ops/reduce.h @@ -0,0 +1,42 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_REDUCE_H_ +#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_REDUCE_H_ + +#include +#include +#include "src/expression/node.h" +#include "inner/model_generated.h" + +namespace mindspore { +namespace lite { +class ReduceM : public Node { /* reduce node parameterised by a schema::ReduceMode, a keep_dims flag and an axis list; definitions are in reduce.cc */ + public: + ReduceM() = default; + ReduceM(schema::ReduceMode mode, bool keep_dims, const std::vector &axis); + Param *weight() override { return input(1)->node()->data(); } /* data of the node that produces input(1) */ + int UnPopulate(const std::unique_ptr &cnode) override; + std::vector Grad(EXPR *expr) override; +}; + +namespace NN { +Node *ReduceMean(bool keep_dims, const std::vector &axis); /* factory for a ReduceM in ReduceMean mode (reduce.cc) */ +Node *ReduceSum(bool keep_dims, const std::vector &axis); /* factory for a ReduceM in ReduceSum mode (reduce.cc) */ +} // namespace NN +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_REDUCE_H_ diff --git a/mindspore/lite/src/expression/ops/reshape.cc b/mindspore/lite/src/expression/ops/reshape.cc new file mode 100644 index 00000000000..d6cc04338cd --- /dev/null +++ b/mindspore/lite/src/expression/ops/reshape.cc @@ -0,0 +1,74 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#include "src/expression/ops/reshape.h" +#include "src/expression/ops.h" +#include "nnacl/reshape_parameter.h" +#include "src/expression/import.h" + +namespace mindspore { +namespace lite { +ReshapeM::ReshapeM(const std::vector &shape) : Node() { /* copies shape into a zero-filled ReshapeParameter and also adds it as an int32 const tensor named "shape" */ + auto op_param = calloc(1, sizeof(ReshapeParameter)); + if (op_param == nullptr) { + MS_LOG(ERROR) << " cannot allocate ReshapeParameter"; + return; + } + set_name(UniqueName("Reshape")); + expr()->SetSize(C2NUM); + SetOpParam(op_param); + set_primitive(schema::PrimitiveType_Reshape); + + ReshapeParameter *reshape_param = reinterpret_cast(opParam_.get()); + reshape_param->shape_dim_ = shape.size(); + for (int i = 0; i < reshape_param->shape_dim_; i++) { /* NOTE(review): no bound check against the capacity of shape_ - presumably a fixed-size array in nnacl; confirm and clamp if so */ + reshape_param->shape_[i] = shape.at(i); + } + Node::CreateConstTensor(C1NUM, {static_cast(shape.size())}, kNumberTypeInt32, KHWC, "shape", shape.data()); +} + +int ReshapeM::UnPopulate(const std::unique_ptr &cnode) { /* exports this node as a default-constructed schema::ReshapeT; no field of it is set here */ + auto prim = new (std::nothrow) schema::ReshapeT; + if (prim == nullptr) { + MS_LOG(ERROR) << "Cannot allocate primitive"; + return RET_ERROR; + } + cnode->primitive->value.value = prim; + return RET_OK; +} + +std::vector ReshapeM::Grad(EXPR *yt) { /* gradient is yt reshaped back to the dims of input(0) */ + auto shape_of_x = input(0)->dims(); + auto reshape = NN::Reshape(shape_of_x); /* NOTE(review): may be nullptr (see the factory below) and is used unchecked */ + PushOp(reshape); + return (*reshape)({yt}); +} + +static ImportReg reg(schema::PrimitiveType_Reshape, ReturnNode); + +namespace NN { +Node *Reshape(const std::vector &shape) { /* factory: heap-allocated ReshapeM(shape), or nullptr after logging when allocation fails */ + auto node = new (std::nothrow) ReshapeM(shape); + if (node == nullptr) { + MS_LOG(ERROR) << "Cannot allocate reshape node"; + return nullptr; + } + node->set_name(Node::UniqueName("Reshape")); /* the constructor already called set_name(UniqueName("Reshape")); this call replaces that name */ + return node; +} +} // namespace NN +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/expression/ops/reshape.h b/mindspore/lite/src/expression/ops/reshape.h new file mode 100644 index 00000000000..a7c153772a8 --- /dev/null +++ b/mindspore/lite/src/expression/ops/reshape.h @@ -0,0 +1,37 @@ +/** + * Copyright 2022 Huawei Technologies
Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_RESHAPE_H_ +#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_RESHAPE_H_ + +#include +#include + +#include "src/expression/node.h" +#include "inner/model_generated.h" + +namespace mindspore { +namespace lite { +class ReshapeM : public Node { /* reshape node constructed from a target shape; definitions are in reshape.cc */ + public: + ReshapeM() = default; + explicit ReshapeM(const std::vector &shape); + int UnPopulate(const std::unique_ptr &cnode) override; + std::vector Grad(EXPR *expr) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_RESHAPE_H_ diff --git a/mindspore/lite/src/expression/ops/softmax.cc b/mindspore/lite/src/expression/ops/softmax.cc new file mode 100644 index 00000000000..7cbaebca57d --- /dev/null +++ b/mindspore/lite/src/expression/ops/softmax.cc @@ -0,0 +1,119 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/expression/ops/softmax.h" +#include "nnacl/softmax_parameter.h" +#include "inner/model_generated.h" +#include "src/expression/import.h" +#include "src/expression/ops/reshape.h" +#include "src/expression/ops/reduce.h" +#include "src/expression/ops/arithmetic.h" +#include "src/expression/ops/transpose.h" +#include "src/cxx_api/expression/node_impl.h" +#include "src/expression/ops.h" + +namespace mindspore { +namespace lite { +SoftmaxM::SoftmaxM(int axis) { /* stores axis in a zero-filled SoftmaxParameter and tags the node as a Softmax primitive */ + auto param = reinterpret_cast(calloc(1, sizeof(SoftmaxParameter))); + if (param == nullptr) { + MS_LOG(ERROR) << "Cannot allocate parameter"; + return; + } + param->axis_ = axis; + SetOpParam(param); + set_primitive(schema::PrimitiveType_Softmax); + set_name(UniqueName("Softmax")); +} + +std::vector SoftmaxM::getTransposeAxis(const std::vector &shape, int axis) { /* identity permutation of length rank with axis and the last position swapped; a negative axis counts from the end */ + int rank = shape.size(); + if (axis < 0) { + axis += rank; + } + std::vector reverse_axis(rank); + std::iota(reverse_axis.begin(), reverse_axis.end(), 0); + reverse_axis.at(axis) = rank - 1; /* at() is the only range check: an axis still outside the range after the adjustment is not handled here */ + reverse_axis.at(rank - 1) = axis; + return reverse_axis; +} + +std::vector SoftmaxM::Grad(EXPR *yt) { /* computes y * (yt - sum(y * yt)) on tensors transposed so the softmax axis is last, then transposes back with the same permutation */ + auto x = input(0); + auto out = output(0); + auto shape_of_x = x->dims(); + auto param = reinterpret_cast(OpParam()); + auto reverse_axis = getTransposeAxis(shape_of_x, param->axis_); + + auto transpose_out = NN::Transpose(reverse_axis); /* NOTE(review): none of the NN:: factory results in this function is checked for nullptr before use */ + transpose_out->set_name(kGradName + "/" + name() + "/" + transpose_out->name() + "/out/"); + PushOp(transpose_out); + auto y_trn = (*transpose_out)({out}).front(); + + auto transpose_dout = NN::Transpose(reverse_axis); + transpose_dout->set_name(kGradName + "/" + name() + "/" + transpose_dout->name() + "/dout/"); + PushOp(transpose_dout); + auto yt_trn = (*transpose_dout)({yt}).front(); + + auto mul0 = NN::Mul(); + mul0->set_name(kGradName + "/" + name() + "/" + mul0->name() + "0"); + PushOp(mul0);
+ auto tmp0 = (*mul0)({y_trn, yt_trn}).front(); /* tmp0 = y * yt */ + + auto sum_func = NN::ReduceSum(true, {-1}); + sum_func->set_name(kGradName + "/" + name() + "/" + sum_func->name()); + PushOp(sum_func); + auto tmp1 = (*sum_func)({tmp0}).front(); /* tmp1 = sum of tmp0 over the last axis, keep_dims = true */ + + auto sub = NN::Sub(); + sub->set_name(kGradName + "/" + name() + "/" + sub->name()); + PushOp(sub); + auto tmp2 = (*sub)({yt_trn, tmp1}).front(); /* tmp2 = yt - tmp1 */ + + auto mul1 = NN::Mul(); + mul1->set_name(kGradName + "/" + name() + "/" + mul1->name() + "1"); + PushOp(mul1); + auto tmp3 = (*mul1)({y_trn, tmp2}); /* tmp3 = y * tmp2; kept as the whole result vector, not front() */ + + auto transpose_dx = NN::Transpose(reverse_axis); /* swapping two positions is its own inverse, so the same permutation restores the layout */ + transpose_dx->set_name(kGradName + "/" + name() + "/" + transpose_dx->name() + "/dx"); + PushOp(transpose_dx); + auto dx = (*transpose_dx)({tmp3}); + return dx; +} + +int SoftmaxM::UnPopulate(const std::unique_ptr &cnode) { /* exports this node as a schema::SoftmaxT whose axis list holds the single configured axis */ + auto param = reinterpret_cast(OpParam()); + auto prim = new (std::nothrow) schema::SoftmaxT; + if (prim == nullptr) { + MS_LOG(ERROR) << "Cannot allocate primitive"; + return RET_ERROR; + } + prim->axis.push_back(param->axis_); + cnode->primitive->value.value = prim; + return RET_OK; +} + +static ImportReg reg(schema::PrimitiveType_Softmax, ReturnNode); + +namespace NN { +Node *Softmax(int axis) { /* factory: heap-allocated SoftmaxM(axis); a failed allocation is returned as nullptr without a log entry */ + auto node = new (std::nothrow) SoftmaxM(axis); + return node; +} +} // namespace NN +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/expression/ops/softmax.h b/mindspore/lite/src/expression/ops/softmax.h new file mode 100644 index 00000000000..7347984e91d --- /dev/null +++ b/mindspore/lite/src/expression/ops/softmax.h @@ -0,0 +1,39 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_SOFTMAX_H_ +#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_SOFTMAX_H_ + +#include +#include +#include "src/expression/node.h" + +namespace mindspore { +namespace lite { +class SoftmaxM : public Node { + public: + SoftmaxM() = default; + explicit SoftmaxM(int axis); + std::vector Grad(EXPR *expr) override; + int UnPopulate(const std::unique_ptr &cnode) override; + + private: + std::vector getTransposeAxis(const std::vector &shape, int axis); +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_SOFTMAX_H_ diff --git a/mindspore/lite/src/expression/ops/softmaxCE.cc b/mindspore/lite/src/expression/ops/softmaxCE.cc new file mode 100644 index 00000000000..facee1c15e4 --- /dev/null +++ b/mindspore/lite/src/expression/ops/softmaxCE.cc @@ -0,0 +1,93 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "src/expression/ops/softmaxCE.h" +#include "include/api/net.h" +#include "src/cxx_api/expression/node_impl.h" +#include "src/expression/ops/reduce.h" +namespace mindspore { +namespace NN { +Node *SoftmaxCrossEntropy(const SoftMaxCrossEntropyCfg &cfg) { + auto lite_node = lite::NN::SoftmaxCrossEntropy(cfg); + return NodeImpl::Connect(lite_node); +} +} // namespace NN + +namespace lite { +SoftmaxCrossEntropyM::SoftmaxCrossEntropyM() { + auto param = calloc(1, sizeof(OpParameter)); + if (param == nullptr) { + MS_LOG(ERROR) << "Cannot allocate parameter"; + return; + } + expr()->SetSize(C2NUM); + SetOpParam(param); + set_name("SoftmaxCrossEntropy"); + set_primitive(schema::PrimitiveType_SoftmaxCrossEntropyWithLogits); + EXPR e(this); + e.SetSize(0); + expr_.emplace_back(e); +} + +Node *SoftmaxCrossEntropyM::GetReductionNode(const std::string &mode, const std::vector &axis) { + if (mode == "mean") { + return NN::ReduceMean(false, axis); + } else if (mode == "sum") { + return NN::ReduceSum(false, axis); + } else { + return nullptr; + } +} + +SoftmaxCrossEntropyM::SoftmaxCrossEntropyM(const SoftMaxCrossEntropyCfg &cfg) : SoftmaxCrossEntropyM() { + std::vector axis = {0}; + reduce_ = GetReductionNode(cfg.reduction, axis); + if (reduce_ != nullptr) { + PushOp(reduce_); + } +} + +std::vector SoftmaxCrossEntropyM::construct(const std::vector &inputs) { + auto y = Node::construct(inputs); + if (reduce_ != nullptr) { + y = (*reduce_)({y.front()}); + } + return y; +} + +std::vector SoftmaxCrossEntropyM::Grad(EXPR *expr) { return {this->expr(1)}; } + +int SoftmaxCrossEntropyM::UnPopulate(const std::unique_ptr &cnode) { + auto prim = new (std::nothrow) schema::SoftmaxCrossEntropyWithLogitsT; + if (prim == nullptr) { + MS_LOG(ERROR) << "Cannot allocate primitive"; + return RET_ERROR; + } + cnode->primitive->value.value = prim; + return RET_OK; +} + +namespace NN { +Node *SoftmaxCrossEntropy(const SoftMaxCrossEntropyCfg &cfg) { + auto s = new (std::nothrow) 
SoftmaxCrossEntropyM(cfg); + if (s == nullptr) { + MS_LOG(ERROR) << "Cannot allocate softmax node"; + } + return s; +} +} // namespace NN +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/expression/ops/softmaxCE.h b/mindspore/lite/src/expression/ops/softmaxCE.h new file mode 100644 index 00000000000..1c0c516d52c --- /dev/null +++ b/mindspore/lite/src/expression/ops/softmaxCE.h @@ -0,0 +1,47 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_SOFTMAXCE_H_ +#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_SOFTMAXCE_H_ + +#include +#include +#include +#include "src/expression/node.h" +#include "inner/model_generated.h" +#include "include/api/net.h" + +namespace mindspore { +namespace lite { +class SoftmaxCrossEntropyM : public Node { + public: + SoftmaxCrossEntropyM(); + explicit SoftmaxCrossEntropyM(const SoftMaxCrossEntropyCfg &cfg); + std::vector Grad(EXPR *expr) override; + int UnPopulate(const std::unique_ptr &cnode) override; + std::vector construct(const std::vector &inputs) override; + + private: + Node *GetReductionNode(const std::string &mode, const std::vector &axis); + Node *reduce_ = nullptr; +}; + +namespace NN { +Node *SoftmaxCrossEntropy(const SoftMaxCrossEntropyCfg &cfg); +} // namespace NN +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_SOFTMAXCE_H_ diff --git a/mindspore/lite/src/expression/ops/tile.cc b/mindspore/lite/src/expression/ops/tile.cc new file mode 100644 index 00000000000..4008da12570 --- /dev/null +++ b/mindspore/lite/src/expression/ops/tile.cc @@ -0,0 +1,62 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "src/expression/ops/tile.h" +#include +#include "src/expression/ops.h" +#include "nnacl/base/tile_base.h" + +namespace mindspore { +namespace lite { +TileM::TileM(const std::vector &multiples) : Node() { + expr()->SetSize(C2NUM); + TileParameter *param = reinterpret_cast(calloc(1, sizeof(TileParameter))); // calloc: every field starts zeroed + if (param == nullptr) { + MS_LOG(ERROR) << "Cannot allocate TileParameter"; + return; + } + SetOpParam(param); + set_name(UniqueName("Tile")); + set_primitive(schema::PrimitiveType_TileFusion); + Node::CreateConstTensor(C1NUM, {static_cast(multiples.size())}, kNumberTypeInt32, KHWC, "axis", + multiples.data()); +} + +int TileM::UnPopulate(const std::unique_ptr &cnode) { + auto tile_param = reinterpret_cast(OpParam()); + auto prim = new (std::nothrow) schema::TileFusionT; + if (prim == nullptr) { + MS_LOG(ERROR) << "Cannot allocate primitive"; + return RET_ERROR; + } + for (size_t i = 0; i < tile_param->dims_size_; i++) { // NOTE(review): nothing in this file writes dims_/dims_size_ after calloc - confirm this loop is expected to copy anything + prim->dims.push_back(tile_param->dims_[i]); + } + cnode->primitive->value.value = prim; // prim is handed to cnode->primitive here; it is not freed in this function + return RET_OK; +} + +namespace NN { +Node *Tile(const std::vector &multiples) { + auto node = new (std::nothrow) TileM(multiples); + if (node == nullptr) { + MS_LOG(ERROR) << "Cannot allocate tile node"; + } + return node; // nullptr on allocation failure; callers must check +} +} // namespace NN +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/expression/ops/tile.h b/mindspore/lite/src/expression/ops/tile.h new file mode 100644 index 00000000000..0d793c0e269 --- /dev/null +++ b/mindspore/lite/src/expression/ops/tile.h @@ -0,0 +1,40 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_TILE_H_ +#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_TILE_H_ + +#include +#include +#include "src/expression/node.h" +#include "inner/model_generated.h" + +namespace mindspore { +namespace lite { +class TileM : public Node { + public: + TileM() = default; + explicit TileM(const std::vector &multiples); + Param *weight() override { return input(1)->node()->data(); } + int UnPopulate(const std::unique_ptr &cnode) override; +}; + +namespace NN { +Node *Tile(const std::vector &multiples); +} +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_TILE_H_ diff --git a/mindspore/lite/src/expression/ops/transpose.cc b/mindspore/lite/src/expression/ops/transpose.cc new file mode 100644 index 00000000000..22e20b061b4 --- /dev/null +++ b/mindspore/lite/src/expression/ops/transpose.cc @@ -0,0 +1,88 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "src/expression/ops/transpose.h" +#include +#include "nnacl/transpose.h" +#include "inner/model_generated.h" +#include "src/expression/import.h" + +namespace mindspore { +namespace lite { +TransposeM::TransposeM(const std::vector &vector) { + auto param = calloc(1, sizeof(TransposeParameter)); + if (param == nullptr) { + MS_LOG(ERROR) << "Cannot allocate transpose parameter"; + return; + } + SetOpParam(param); + expr()->SetSize(C2NUM); + set_primitive(schema::PrimitiveType_Transpose); + std::vector dims = {static_cast(vector.size())}; + set_name(UniqueName("Transpose")); + CreateConstTensor(C1NUM, dims, kNumberTypeInt32, KHWC, "axis", vector.data()); +} + +std::vector TransposeM::Invert(const std::vector &vector) { + std::vector res; + for (size_t i = 0; i < vector.size(); i++) { + int idx = static_cast(i); + auto val = std::find_if(vector.begin(), vector.end(), [idx](int x) { return (x == idx) ? true : false; }); + if (val == vector.end()) { + MS_LOG(ERROR) << "Wrong index for " << idx; + return {}; + } + res.push_back(std::distance(vector.begin(), val)); + } + return res; +} + +std::vector TransposeM::Grad(EXPR *yt) { + auto tensor = input(1)->node(); + auto data = tensor->data(); + auto vec = data->Extract(); + auto invert = Invert(vec); + auto tran = new (std::nothrow) TransposeM(invert); + if (tran == nullptr) { + MS_LOG(ERROR) << "Cannot allocate transpose grad"; + return {}; + } + tran->set_name(kGradName + "/" + name() + "/" + tran->name()); + PushOp(tran); + auto grad = (*tran)({yt}); + return grad; +} + +int TransposeM::UnPopulate(const std::unique_ptr &cnode) { + auto prim = new (std::nothrow) schema::TransposeT; + if (prim == nullptr) { + MS_LOG(ERROR) << "Cannot allocate primitive"; + return RET_ERROR; + } + cnode->primitive->value.value = prim; + return RET_OK; +} + +static ImportReg reg(schema::PrimitiveType_Transpose, ReturnNode); + +namespace NN { +Node *Transpose(const std::vector &permute) { + auto node = new (std::nothrow) 
TransposeM(permute); + return node; +} +} // namespace NN +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/expression/ops/transpose.h b/mindspore/lite/src/expression/ops/transpose.h new file mode 100644 index 00000000000..d4c9a1c11e8 --- /dev/null +++ b/mindspore/lite/src/expression/ops/transpose.h @@ -0,0 +1,59 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_TRANSPOSE_H_ +#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_TRANSPOSE_H_ + +#include +#include +#include "src/expression/node.h" +#include "inner/model_generated.h" + +namespace mindspore { +namespace lite { +class TransposeM : public Node { + public: + TransposeM() = default; + explicit TransposeM(const std::vector &vector); + static EXPR *TransposeCHW2HWC(EXPR *in) { + std::vector res = {0, 2, 3, 1}; + auto trans = new (std::nothrow) TransposeM(res); + if (trans == nullptr) { + return nullptr; + } + return (*trans)({in}).front(); + } + static EXPR *TransposeHWC2CHW(EXPR *in) { + std::vector res = {0, 3, 1, 2}; + auto trans = new (std::nothrow) TransposeM(res); + if (trans == nullptr) { + return nullptr; + } + return (*trans)({in}).front(); + } + int UnPopulate(const std::unique_ptr &cnode) override; + std::vector Grad(EXPR *yt) override; + + private: + std::vector Invert(const std::vector &vec); +}; + +namespace NN { +Node *Transpose(const std::vector &permute); +} // 
namespace NN +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_TRANSPOSE_H_ diff --git a/mindspore/lite/src/expression/ops_utils.cc b/mindspore/lite/src/expression/ops_utils.cc new file mode 100644 index 00000000000..e63b903daea --- /dev/null +++ b/mindspore/lite/src/expression/ops_utils.cc @@ -0,0 +1,275 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/expression/ops_utils.h" +#include +#include + +namespace mindspore { +namespace lite { +enum class State { + SAME, + X_ONE, + Y_ONE, +}; + +bool CompareShape(const std::vector &x_shape, const std::vector &y_shape) { + if (x_shape.size() != y_shape.size()) { + return false; + } + + for (size_t i = 0; i < x_shape.size(); ++i) { + if (x_shape.at(i) != y_shape.at(i)) { + return false; + } + } + + return true; +} + +void ComputeReduceIndex(const std::vector &reverse_x, const std::vector &reverse_y, + std::vector *grad_x_reduce_idx, std::vector *grad_y_reduce_idy) { + MS_ASSERT(grad_x_reduce_idx != nullptr); + MS_ASSERT(grad_y_reduce_idy != nullptr); + const size_t n = reverse_x.size(); + if (reverse_y.size() < n) { + MS_LOG_ERROR << "The size of reverse_y is less than the size of reverse_x."; + } + for (size_t i = 0; i < n; ++i) { + State curr = State::SAME; + const int x_i = reverse_x[i]; + const int y_i = reverse_y[i]; + const int reduce_idx = (n - 1 - i); + if (x_i == y_i) { + curr = 
State::SAME; + } else if (x_i == 1) { + grad_x_reduce_idx->push_back(reduce_idx); + curr = State::X_ONE; + } else if (y_i == 1) { + grad_y_reduce_idy->push_back(reduce_idx); + curr = State::Y_ONE; + } else { + MS_LOG_ERROR << "not compatible shape input for BroadcastGradientArgs"; + } + if (curr == State::SAME && x_i == 1) { + grad_x_reduce_idx->push_back(reduce_idx); + grad_y_reduce_idy->push_back(reduce_idx); + continue; + } + } + + std::reverse(grad_x_reduce_idx->begin(), grad_x_reduce_idx->end()); + std::reverse(grad_y_reduce_idy->begin(), grad_y_reduce_idy->end()); +} + +std::vector> BroadcastGradientArgs::operator()() { + std::vector> input_dim(kInNum); + input_dim[0] = dim0_; + input_dim[1] = dim1_; + auto same_shape = CompareShape(dim0_, dim1_); + if (same_shape) { + return {{}, {}}; + } + + std::vector reverse_x; + std::vector reverse_y; + + (void)std::transform(dim0_.rbegin(), dim0_.rend(), std::back_inserter(reverse_x), [](const int &v) { return v; }); + (void)std::transform(dim1_.rbegin(), dim1_.rend(), std::back_inserter(reverse_y), [](const int &v) { return v; }); + + if (reverse_x.size() > reverse_y.size()) { + reverse_y.resize(reverse_x.size(), 1); + } else { + reverse_x.resize(reverse_y.size(), 1); + } + + std::vector grad_x_reduce_idx; + std::vector grad_y_reduce_idy; + ComputeReduceIndex(reverse_x, reverse_y, &grad_x_reduce_idx, &grad_y_reduce_idy); + return {grad_x_reduce_idx, grad_y_reduce_idy}; +} + +void DynamicBroadcastGradientArgs::AddElementToGradReduceIdx(std::vector> *grad_reduce_idx, + std::vector current_is_one, bool none_is_one, + const size_t largest_rank, size_t j) { + for (size_t i = 0; i < kInNum; ++i) { + if (current_is_one[i] && !none_is_one) { + (void)(*grad_reduce_idx)[i].emplace_back(largest_rank - 1 - j); + } + } +} + +void DynamicBroadcastGradientArgs::UpdatePreIsOne(std::vector *prev_is_one, std::vector current_is_one) { + for (size_t i = 0; i < kInNum; ++i) { + (*prev_is_one)[i] = current_is_one[i]; + } +} + +std::vector> 
DynamicBroadcastGradientArgs::GetGradientIndices( + const std::vector> &reverse_shape, const size_t largest_rank) { + std::vector> grad_reduce_idx(kInNum); + // indices of j-th component of each input. + std::vector prev_is_one(kInNum); + std::vector current_is_one(kInNum); + for (size_t i = 0; i < kInNum; ++i) { + prev_is_one[i] = false; + current_is_one[i] = false; + } + + bool set_one = false; + for (size_t j = 0; j < largest_rank; ++j) { + int output_dim = -1; + bool output_dim_set = false; + bool none_is_one = true; + // Find which indices are 1. + for (size_t i = 0; i < kInNum; ++i) { + if (reverse_shape[i][j] == 1) { + current_is_one[i] = true; + none_is_one = false; + } else { + current_is_one[i] = false; + if (!output_dim_set || reverse_shape[i][j] == static_cast(output_dim)) { + output_dim = reverse_shape[i][j]; + output_dim_set = true; + } else { + std::cout << "Input[0] and input[1] Cannot broadcast!"; + } + } + } + // All dimensions are 1. + if (!output_dim_set) { + for (size_t i = 0; i < kInNum; ++i) { + (void)grad_reduce_idx[i].emplace_back(largest_rank - 1 - j); + } + continue; + } else if (std::equal(current_is_one.begin(), current_is_one.end(), prev_is_one.begin()) && set_one) { + AddElementToGradReduceIdx(&grad_reduce_idx, current_is_one, none_is_one, largest_rank, j); + } else { + AddElementToGradReduceIdx(&grad_reduce_idx, current_is_one, none_is_one, largest_rank, j); + } + set_one = true; + UpdatePreIsOne(&prev_is_one, current_is_one); + } + return grad_reduce_idx; +} + +std::vector> DynamicBroadcastGradientArgs::CalculateOutput(const std::vector> &x) { + std::vector> grad_reduce_idx(kInNum); + bool all_equal = true; + size_t largest_rank = 0; + for (size_t i = 0; i < kInNum; ++i) { + if (x[i] != x[0]) { + all_equal = false; + } + if (x[i].size() > largest_rank) { + largest_rank = x[i].size(); + } + } + if (all_equal) { + return grad_reduce_idx; + } + + // Reverse input the shapes + std::vector> reverse_shape(kInNum); + for (size_t i = 0; i < 
kInNum; ++i) { + reverse_shape[i] = x[i]; + std::reverse(reverse_shape[i].begin(), reverse_shape[i].end()); + } + + // 1-extend and align all vectors. + for (size_t i = 0; i < kInNum; ++i) { + if (reverse_shape[i].size() < largest_rank) { + reverse_shape[i].resize(largest_rank, 1); + } + } + grad_reduce_idx = GetGradientIndices(reverse_shape, largest_rank); + return grad_reduce_idx; +} + +std::vector> DynamicBroadcastGradientArgs::SetOutputValue( + const std::vector> &grad_reduce_idx, const std::vector> &input_dim) { + std::vector> output(kInNum); + for (size_t index = 0; index < kInNum; ++index) { + auto idx_num = grad_reduce_idx[index].size(); + for (size_t k = 0; k < idx_num; ++k) { + output[index].push_back(grad_reduce_idx[index][idx_num - 1 - k]); + } + if (idx_num == 0) { + auto input_num = input_dim[index].size(); + for (size_t k = 0; k < input_num; ++k) { + output[index].push_back(k); + } + } + } + return output; +} + +std::vector> DynamicBroadcastGradientArgs::operator()() { + std::vector> input_dim(kInNum); + input_dim[0] = dim0_; + input_dim[1] = dim1_; + auto grad_reduce_idx = CalculateOutput(input_dim); + auto output = SetOutputValue(grad_reduce_idx, input_dim); + return output; +} + +std::vector VectorDiv::operator()(const std::vector &x, const std::vector &d) { + if (d.size() != x.size()) { + MS_LOG(ERROR) << "x and divider must have same size"; + return {}; + } + std::vector res; + for (size_t i = 0; i < d.size(); i++) { + auto x_value = x.at(i); + auto d_value = d.at(i); + if (d_value == 0) { + MS_LOG(ERROR) << "Divisor is zero"; + return {}; + } + if ((x_value % d_value) != 0) { + MS_LOG(ERROR) << "x and d and not dividable"; + } + auto r = x_value / d_value; + res.push_back(r); + } + return res; +} + +std::vector ShapeReduce::operator()(const std::vector &x_shape, const std::vector &axis) { + int x_rank = x_shape.size(); + std::set axis_set; + + auto min = -x_rank; + auto max = x_rank - 1; + for (auto &elem : axis) { + if (elem > max || elem < 
min) { + MS_LOG(ERROR) << "illegal axis value"; + return {}; + } + axis_set.insert(elem); + } + std::vector res; + for (int i = 0; i < x_rank; i++) { + if (axis_set.count(i) || axis_set.count(i - x_rank)) { + res.push_back(1); + } else { + res.push_back(x_shape.at(i)); + } + } + return res; +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/expression/ops_utils.h b/mindspore/lite/src/expression/ops_utils.h new file mode 100644 index 00000000000..6c62de11cc7 --- /dev/null +++ b/mindspore/lite/src/expression/ops_utils.h @@ -0,0 +1,69 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_UTILS_H_ +#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_UTILS_H_ + +#include "include/api/cfg.h" +#include "src/expression/net.h" +#include <vector> + +namespace mindspore { +namespace lite { +class BroadcastGradientArgs { + public: + BroadcastGradientArgs(const std::vector &dim0, const std::vector &dim1) : dim0_(dim0), dim1_(dim1) {} + std::vector> operator()(); + + private: + static const int kInNum = 2; + const std::vector dim0_; // stored by value: a reference member would dangle if a temporary shape were passed + const std::vector dim1_; // stored by value, same reason as dim0_ +}; + +class DynamicBroadcastGradientArgs { + public: + DynamicBroadcastGradientArgs(const std::vector &dim0, const std::vector &dim1) : dim0_(dim0), dim1_(dim1) {} + std::vector> operator()(); + + private: + void AddElementToGradReduceIdx(std::vector> *grad_reduce_idx, std::vector current_is_one, + bool none_is_one, const size_t largest_rank, size_t j); + void UpdatePreIsOne(std::vector *prev_is_one, std::vector current_is_one); + std::vector> GetGradientIndices(const std::vector> &reverse_shape, + const size_t largest_rank); + std::vector> CalculateOutput(const std::vector> &x); + std::vector> SetOutputValue(const std::vector> &grad_reduce_idx, + const std::vector> &input_dim); + static const int kInNum = 2; + const std::vector dim0_; // stored by value: a reference member would dangle if a temporary shape were passed + const std::vector dim1_; // stored by value, same reason as dim0_ +}; + +class VectorDiv { + public: + VectorDiv() {} + std::vector operator()(const std::vector &x, const std::vector &d); +}; + +class ShapeReduce { + public: + ShapeReduce() {} + std::vector operator()(const std::vector &x_shape, const std::vector &axis); +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_UTILS_H_ diff --git a/mindspore/lite/src/expression/param.cc b/mindspore/lite/src/expression/param.cc new file mode 100644 index 00000000000..9f316e5aabf --- /dev/null +++ b/mindspore/lite/src/expression/param.cc @@ -0,0 +1,70 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + *
you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/expression/param.h" +#include +#include +#include +#include "include/errorcode.h" + +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; + +constexpr float kZero = 0.0f; +constexpr float kOne = 1.0f; + +namespace mindspore { +namespace lite { +int Param::Fill(Mode mode) { + std::default_random_engine engine{static_cast(0)}; // constant seed 0: the engine is seeded identically on every call + std::vector data(size_); + switch (mode) { + case NORMAL: { + constexpr float scale = 0.01; + std::normal_distribution n{0, 1}; + std::generate_n(data.begin(), size_, [&]() { return n(engine); }); + (void)std::transform(data.begin(), data.end(), data.begin(), [](float x) { return x * scale; }); + break; + } + case UNIFORM: { + constexpr float scale = 0.07; + std::uniform_real_distribution u{-1.0, 1.0}; + std::generate_n(data.begin(), size_, [&]() { return u(engine) * scale; }); + break; + } + case ZEROS: + std::fill_n(data.begin(), size_, kZero); + break; + case ONES: + std::fill_n(data.begin(), size_, kOne); + break; + case NOT_SUPPORTED: + return RET_ERROR; // nothing is copied into the parameter for an unsupported mode + } + Copy(data); + return RET_OK; +} + +Param::Mode Param::String2Enum(std::string mode) { + (void)std::transform(mode.begin(), mode.end(), mode.begin(), ::tolower); // lower-case first, so matching is case-insensitive + if (mode == "normal") return NORMAL; + if (mode == "uniform") return UNIFORM; + if (mode == "ones") return ONES; + if (mode == "zeros") return ZEROS; // was misspelled "zeors", which made "zeros" resolve to NOT_SUPPORTED + return NOT_SUPPORTED; +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/expression/param.h
b/mindspore/lite/src/expression/param.h new file mode 100644 index 00000000000..201e69fc7b5 --- /dev/null +++ b/mindspore/lite/src/expression/param.h @@ -0,0 +1,60 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_EXPRESSION_PARAM_H_ +#define MINDSPORE_LITE_SRC_EXPRESSION_PARAM_H_ + +#include +#include +#include +#include + +namespace mindspore { +namespace lite { +class Param { + public: + enum Mode { NORMAL, UNIFORM, ONES, ZEROS, NOT_SUPPORTED }; + int Fill(Mode type); + static Mode String2Enum(std::string); + std::vector &data() { return data_; } + size_t Load(std::string file_name, size_t offset = 0) { return data_.size() * sizeof(float); } // stub: file_name/offset are not read here + size_t Load(std::ifstream &s, int offset = 0) { return data_.size() * sizeof(float); } // stub: the stream is not read here + void SetSize(size_t size) { size_ = size; } + template + void Copy(const T *data, size_t size) { + auto cast_data = reinterpret_cast(data); + data_ = decltype(data_)(cast_data, cast_data + size * sizeof(T) / sizeof(uint8_t)); // replaces data_ with the raw bytes of the input + } + template + void Copy(const std::vector &data) { + Copy(data.data(), data.size()); + } + + template + std::vector Extract() { + T *num = reinterpret_cast(data_.data()); + std::vector res(num, num + data_.size() / sizeof(T)); + return res; + } + + private: + size_t size_ = 0; // zero until SetSize() is called, so Fill() never reads an indeterminate count + std::vector data_; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_SRC_EXPRESSION_PARAM_H_ diff --git
a/mindspore/lite/src/expression/sequential.cc b/mindspore/lite/src/expression/sequential.cc new file mode 100644 index 00000000000..5f3a8a76486 --- /dev/null +++ b/mindspore/lite/src/expression/sequential.cc @@ -0,0 +1,30 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/expression/sequential.h" + +namespace mindspore { +namespace lite { +void Sequential::Add(Node *node) { PushOp(node); } + +std::vector Sequential::construct(const std::vector &inputs) { + auto x = inputs; + for (auto &node : ops_) { + x = (*node)({x.front()}); + } + return x; +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/expression/sequential.h b/mindspore/lite/src/expression/sequential.h new file mode 100644 index 00000000000..9b1a69e52cd --- /dev/null +++ b/mindspore/lite/src/expression/sequential.h @@ -0,0 +1,32 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_EXPRESSION_SEQUENTIAL_H_ +#define MINDSPORE_LITE_SRC_EXPRESSION_SEQUENTIAL_H_ +#include +#include "src/expression/net.h" + +namespace mindspore { +namespace lite { +class Sequential : public Net { + public: + std::vector construct(const std::vector &inputs) override; + void Add(Node *node) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_SRC_EXPRESSION_SEQUENTIAL_H_ diff --git a/mindspore/lite/src/runtime/kernel/cpu/fp32/convolution_1x1_fp32.cc b/mindspore/lite/src/runtime/kernel/cpu/fp32/convolution_1x1_fp32.cc index 3e751d19e67..b91de269a0f 100644 --- a/mindspore/lite/src/runtime/kernel/cpu/fp32/convolution_1x1_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/cpu/fp32/convolution_1x1_fp32.cc @@ -25,6 +25,7 @@ using mindspore::lite::RET_OK; namespace mindspore::kernel { Convolution1x1CPUKernel::~Convolution1x1CPUKernel() { FreeTmpBuffer(); + if (matmul_param_ != nullptr) { delete matmul_param_; matmul_param_ = nullptr; @@ -185,7 +186,7 @@ int Convolution1x1CPUKernel::DoConv1x1Hw(int task_id) { float *thread_input_ptr = input_ptr_ + task_id * thread_stride_ * matmul_param_->deep_; float *thread_pack_input = pack_input_ + task_id * row_tile_ * matmul_param_->deep_; - float *thread_output_ptr; + float *thread_output_ptr = nullptr; if (out_tensors()[0]->format() != NC4HW4) { thread_output_ptr = output_ptr_ + task_id * thread_stride_ * matmul_param_->col_; } else { diff --git a/mindspore/lite/src/runtime/kernel/cpu/fp32/convolution_depthwise_fp32.cc b/mindspore/lite/src/runtime/kernel/cpu/fp32/convolution_depthwise_fp32.cc index a0777e21d76..d169076092b 100644 --- a/mindspore/lite/src/runtime/kernel/cpu/fp32/convolution_depthwise_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/cpu/fp32/convolution_depthwise_fp32.cc @@ -127,10 +127,12 @@ int 
ConvolutionDepthwiseCPUKernel::MallocWeightBiasData() { } } CHECK_LESS_RETURN(MAX_MALLOC_SIZE, channel * sizeof(float)); - bias_data_ = malloc(channel * sizeof(float)); if (bias_data_ == nullptr) { - MS_LOG(ERROR) << "Malloc buffer failed."; - return RET_ERROR; + bias_data_ = malloc(channel * sizeof(float)); + if (bias_data_ == nullptr) { + MS_LOG(ERROR) << "Malloc buffer failed."; + return RET_ERROR; + } } memset(bias_data_, 0, channel * sizeof(float)); return RET_OK; diff --git a/mindspore/lite/test/config_level0/cropped_size.cfg b/mindspore/lite/test/config_level0/cropped_size.cfg index 383224406cc..44faa904f3b 100644 --- a/mindspore/lite/test/config_level0/cropped_size.cfg +++ b/mindspore/lite/test/config_level0/cropped_size.cfg @@ -1,2 +1,2 @@ Note: This is the mindspore Lite inference framework size threshold. Offline review is required before modify this value!!! -1034600 +1040696 diff --git a/mindspore/lite/test/config_level0/models_ms_train.cfg b/mindspore/lite/test/config_level0/models_ms_train.cfg index ee65f9a0070..1fc89f2ee4d 100644 --- a/mindspore/lite/test/config_level0/models_ms_train.cfg +++ b/mindspore/lite/test/config_level0/models_ms_train.cfg @@ -48,5 +48,7 @@ vae unified_api code_example train_lenet code_example train_lenet_java code_example +lenet expression +mobilenetv2 expression noarm32 # LAST #test_resize inputShapes 16,10,10,1:16,10,10,1 0.5 diff --git a/mindspore/lite/test/config_level1/cropped_size.cfg b/mindspore/lite/test/config_level1/cropped_size.cfg index ea703e157c7..fbcbc160bb7 100644 --- a/mindspore/lite/test/config_level1/cropped_size.cfg +++ b/mindspore/lite/test/config_level1/cropped_size.cfg @@ -1,2 +1,2 @@ Note: This is the mindspore Lite inference framework size threshold. Modifying this threshold requires meeting review. 
-1034600 +1040696 diff --git a/mindspore/lite/test/st/scripts/run_net_train.sh b/mindspore/lite/test/st/scripts/run_net_train.sh index d27d443c8ee..edf7e511d8d 100755 --- a/mindspore/lite/test/st/scripts/run_net_train.sh +++ b/mindspore/lite/test/st/scripts/run_net_train.sh @@ -1,6 +1,6 @@ #!/bin/bash source ./scripts/base_functions.sh -version=1.5.0 +version=1.6.1 # Run Export on x86 platform and create output test files: docker_image=mindspore_build:210301 @@ -69,6 +69,9 @@ function Run_Converter() { model_prefix="${line_array[0]}_head" model_name=${line_array[0]}'_head' fi + if [[ "${expression}" == "1" ]]; then + model_prefix="${model_name}_fwd" + fi if [[ ${check_convert} == "1" ]]; then ms_file=$ms_models_path'/'$model_name'.ms' if [ -f "$ms_file" ]; then @@ -80,7 +83,7 @@ function Run_Converter() { { echo ${model_name} >> "${run_converter_log_file}" echo './converter_lite --fmk=MINDIR --modelFile='${models_path}'/'${model_prefix}'.mindir --outputFile='${ms_models_path}'/'${model_name}' --trainModel=true' ${WEIGHT_QUANT} >> "${run_converter_log_file}" - ./converter_lite --fmk=MINDIR --modelFile=${models_path}/${model_prefix}.mindir --outputFile=${ms_models_path}/${model_name} --trainModel=true ${WEIGHT_QUANT} + ./converter_lite --fmk=MINDIR --modelFile=${models_path}/${model_prefix}.mindir --outputFile=${ms_models_path}/${model_name} --trainModel=true ${no_opt} ${WEIGHT_QUANT} if [ $? 
= 0 ]; then converter_result='converter mindspore '${model_name}' pass';echo ${converter_result} >> ${run_converter_result_file} else @@ -118,9 +121,12 @@ function should_run_example() { continue fi if [[ $model_name == "$1" ]]; then - if [[ ${line_array[1]} == "code_example" ]]; then + if [[ "${line_array[1]}" == "code_example" ]]; then ret=1 fi + if [[ "${line_array[2]}" == "expression" ]]; then + ret=2 + fi fi done < ${models_ms_train_config} return $ret @@ -146,6 +152,9 @@ function parse_line() { enable_transfer=0 suffix_print="" check_convert=0 + do_api=false + no_opt="--NoFusion=false" + expression=0 model_name=${line_array[0]}_train while [[ $i < ${#line_array[@]} ]] ; do case ${line_array[i]} in @@ -189,6 +198,12 @@ function parse_line() { i=$(($i+1)) inputShapes=${line_array[i]} ;; + "expression") + model_name="${line_array[0]}_expr" + do_api=true + no_opt="--NoFusion=true" + expression=1 + ;; *) check=`echo "${line_array[i]}" | grep -E '^\-?[0-9]*\.?[0-9]+$'` if [ "${check}" != "" ] ; then @@ -214,12 +229,14 @@ function Run_x86() { continue fi local model_prefix=${line_array[0]} + local bb_model_file="" local log_suffix="_train" parse_line x86 if [[ "$?" 
== "1" ]]; then continue; fi + if [[ "${expression}" == "1" ]]; then + model_prefix=${model_prefix}_expr + fi local model_file="${ms_models_path}/${model_name}.ms" - local bb_model_file="" - local export_file="${ms_models_path}/${model_name}_tod" local inference_file="${ms_models_path}/${model_name}_infer" if [[ "${enable_transfer}" == "1" ]]; then model_file="${ms_models_path}/${model_prefix}_head.ms" @@ -247,7 +264,8 @@ function Run_x86() { --exportFile=${export_file} \ --virtualBatch=${virtual_batch} \ --inputShapes=${inputShapes} \ - --lossName=${loss_name} " >> ${run_x86_log_file} + --lossName=${loss_name} \ + --unifiedApi=${do_api}" >> ${run_x86_log_file} ${run_valgrind} ./tools/benchmark_train/benchmark_train \ --modelFile=${model_file} \ --bbModelFile=${bb_model_file} \ @@ -257,7 +275,8 @@ function Run_x86() { --exportFile=${export_file} \ --virtualBatch=${virtual_batch} \ --inputShapes=${inputShapes} \ - --lossName=${loss_name} >> "${run_x86_log_file}" + --lossName=${loss_name} \ + --unifiedApi=${do_api} >> "${run_x86_log_file}" if [ $? = 0 ]; then run_result='x86'${log_suffix}': '${model_name}''${suffix_print}' pass'; echo ${run_result} >> ${run_benchmark_train_result_file} else @@ -327,24 +346,27 @@ function Run_arm() { echo "cd ${tmp_dir}" > ${adb_cmd_file} echo 'chmod 777 benchmark_train' >> ${adb_cmd_file} adb -s ${device_id} shell < ${adb_cmd_file} - local fail=0 # Run mindir converted train models: while read line; do local line_array + local bb_model_file="" LFS=" " read -r -a line_array <<< ${line} if [[ ${line_array[0]} == \#* || ${line_array[0]} == "" ]]; then continue fi + parse_line $1 + if [[ "$?" == "1" ]]; then continue; fi local model_prefix=${line_array[0]} + if [[ ${expression} == "1" ]]; then + model_prefix=${model_prefix}_expr + fi local run_result="" local log_suffix="_train" - parse_line $1 if [[ "$?" 
== "1" ]]; then continue; fi local export_file="${tmp_dir}/${model_name}_tod" local inference_file="${tmp_dir}/${model_name}_infer" local model_file="${model_name}.ms" - local bb_model_file="" if [[ "${enable_transfer}" == "1" ]]; then model_file="${model_prefix}_head.ms" bb_model_file="${model_prefix}_bb.ms" @@ -388,7 +410,8 @@ function Run_arm() { --exportFile=${export_file} \ --virtualBatch=${virtual_batch} \ --inputShapes=${inputShapes} \ - --lossName=${loss_name} + --lossName=${loss_name} \ + --unifiedApi=${do_api} ENDM ) echo "${adb_cmd}" >> ${run_arm_log_file} @@ -443,11 +466,16 @@ function Run_CodeExamples() { should_run_example "unified_api" should_run=$? + local expression_flag="" + if [[ "$should_run" == "2" ]]; then + expression_flag="-x" + should_run=1 + fi if [[ "$should_run" == "1" ]]; then cd ${basepath}/../../examples/unified_api || exit 1 chmod 777 ./prepare_and_run.sh chmod 777 ./*/*.sh - ./prepare_and_run.sh -D ${datasets_path}/mnist -r ${tarball_path} -t ${target} -m ${models_path}/code_example.mindir -e 1 >> ${run_code_examples_log_file} + ./prepare_and_run.sh -D ${datasets_path}/mnist -r ${tarball_path} -t ${target} -m ${models_path}/code_example.mindir -e 1 ${expression_flag} >> ${run_code_examples_log_file} if [ "$?" 
!= "0" ]; then echo "Unified API prepare_and_run.sh failed" exit 1 @@ -492,6 +520,10 @@ while getopts "r:c:m:d:i:e:vt:q:D:M:l:" opt; do models_path=${OPTARG} echo "models_path is ${OPTARG}" ;; + c) + models_ms_train_config=${OPTARG} + echo "models_ms_train_config ${models_ms_train_config}" + ;; i) train_io_path=${OPTARG} echo "train_io_path is ${OPTARG}" @@ -539,7 +571,9 @@ config_folder="config_level0" if [[ ${level} == "level1" ]];then config_folder="config_level1" fi -models_ms_train_config=${basepath}/../${config_folder}/models_ms_train.cfg +if [[ "${models_ms_train_config}" == "" ]]; then + models_ms_train_config=${basepath}/../${config_folder}/models_ms_train.cfg +fi if [[ $train_io_path == "" ]]; then echo "train_io path is empty" diff --git a/mindspore/lite/tools/benchmark_train/CMakeLists.txt b/mindspore/lite/tools/benchmark_train/CMakeLists.txt index fe4d13bfcaf..986ffac5770 100644 --- a/mindspore/lite/tools/benchmark_train/CMakeLists.txt +++ b/mindspore/lite/tools/benchmark_train/CMakeLists.txt @@ -5,10 +5,23 @@ set(COMMON_SRC ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/utils.cc ) + +set(TEST_SRC + ${CMAKE_CURRENT_SOURCE_DIR}/main.cc + ${CMAKE_CURRENT_SOURCE_DIR}/net_train.cc + ) + +if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "full") + set(TEST_SRC + ${TEST_SRC} + ${CMAKE_CURRENT_SOURCE_DIR}/net_runner.cc + ) +endif() + add_executable(benchmark_train - ${CMAKE_CURRENT_SOURCE_DIR}/main.cc - ${CMAKE_CURRENT_SOURCE_DIR}/net_train.cc + ${TEST_SRC} ${COMMON_SRC}) + if(WIN32) add_dependencies(benchmark_train fbs_src mindspore-lite_static mindspore-lite-train_static) else() diff --git a/mindspore/lite/tools/benchmark_train/net_runner.cc b/mindspore/lite/tools/benchmark_train/net_runner.cc new file mode 100644 index 00000000000..9b63d29ff6b --- /dev/null +++ b/mindspore/lite/tools/benchmark_train/net_runner.cc @@ -0,0 +1,371 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not 
use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/benchmark_train/net_runner.h" +#include "tools/benchmark_train/net_train.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "include/api/types.h" +#include "include/api/context.h" +#include "include/api/serialization.h" +#include "include/api/callback/loss_monitor.h" +#include "include/api/metrics/accuracy.h" +#include "include/api/callback/ckpt_saver.h" +#include "include/api/callback/train_accuracy.h" +#include "include/api/callback/lr_scheduler.h" +#include "include/dataset/datasets.h" +#include "include/dataset/vision_lite.h" +#include "include/dataset/transforms.h" +#include "include/api/cfg.h" +#include "include/api/net.h" + +using mindspore::AccuracyMetrics; +using mindspore::Model; +using mindspore::TrainAccuracy; +using mindspore::TrainCallBack; +using mindspore::TrainCallBackData; +using mindspore::dataset::Dataset; +using mindspore::dataset::Mnist; +using mindspore::dataset::SequentialSampler; +using mindspore::dataset::TensorOperation; +using mindspore::dataset::transforms::TypeCast; +using mindspore::dataset::vision::Normalize; +using mindspore::dataset::vision::Resize; + +constexpr int kNCHWCDim = 2; +constexpr int kPrintTimes = 100; +constexpr float kBetta1 = 0.9f; +constexpr float kBetta2 = 0.999f; + +class Rescaler : public mindspore::TrainCallBack { + public: + explicit Rescaler(float scale) : scale_(scale) { + if (scale_ == 0) { + scale_ = 1.0; + } + } + ~Rescaler() override = default; + void 
StepBegin(const mindspore::TrainCallBackData &cb_data) override { + auto inputs = cb_data.model_->GetInputs(); + auto *input_data = reinterpret_cast(inputs.at(0).MutableData()); + for (int k = 0; k < inputs.at(0).ElementNum(); k++) input_data[k] /= scale_; + } + + private: + float scale_ = 1.0; +}; + +/* This is an example of a user defined Callback to measure memory and latency of execution */ +class Measurement : public mindspore::TrainCallBack { + public: + explicit Measurement(unsigned int epochs) + : time_avg_(std::chrono::duration(0)), epochs_(epochs) {} + ~Measurement() override = default; + void EpochBegin(const mindspore::TrainCallBackData &cb_data) override { + start_time_ = std::chrono::high_resolution_clock::now(); + } + mindspore::CallbackRetValue EpochEnd(const mindspore::TrainCallBackData &cb_data) override { + end_time_ = std::chrono::high_resolution_clock::now(); + auto time = std::chrono::duration(end_time_ - start_time_); + time_avg_ += time; + return mindspore::kContinue; + } + void End(const mindspore::TrainCallBackData &cb_data) override { + if (epochs_ > 0) { + std::cout << "AvgRunTime: " << time_avg_.count() / epochs_ << " ms" << std::endl; + } + + struct mallinfo info = mallinfo(); + std::cout << "Total allocation: " << info.arena + info.hblkhd << std::endl; + } + + private: + std::chrono::time_point start_time_; + std::chrono::time_point end_time_; + std::chrono::duration time_avg_; + unsigned int epochs_; +}; + +NetRunner::~NetRunner() { + if (model_ != nullptr) { + delete model_; + } + if (graph_ != nullptr) { + delete graph_; + } +} + +mindspore::Status NetRunner::InitAndFigureInputs() { + auto context = std::make_shared(); + auto cpu_context = std::make_shared(); + cpu_context->SetEnableFP16(enable_fp16_); + context->MutableDeviceInfo().push_back(cpu_context); + + graph_ = new (std::nothrow) mindspore::Graph(mindspore::Graph::Type::kExpressionGraph); + if (graph_ == nullptr) { + std::cout << "Cannot allocate graph" << std::endl; + 
return mindspore::kLiteMemoryFailed; + } + auto status = mindspore::Serialization::Load(ms_file_, mindspore::kMindIR, graph_); + if (status != mindspore::kSuccess) { + std::cout << "Error " << status << " during serialization of graph " << ms_file_; + return status; + } + auto net = std::make_unique(*graph_); + auto input_shape = net->InputShape(0); + auto label_shape = net->OutputShape(0); + auto inputM = mindspore::NN::Input(input_shape); + auto labelM = mindspore::NN::Input(label_shape); + auto label = labelM->Create("label"); + auto input = inputM->Create("input"); + + auto cfg = std::make_shared(); + if (enable_fp16_) { + cfg.get()->optimization_level_ = mindspore::kO2; + } + + model_ = new (std::nothrow) mindspore::Model(); + if (model_ == nullptr) { + std::cout << "model allocation failed" << std::endl; + return mindspore::kLiteMemoryFailed; + } + mindspore::SoftMaxCrossEntropyCfg softmax_ce_cfg; + softmax_ce_cfg.reduction = "none"; + auto netWithLoss = mindspore::NN::GraphWithLoss(graph_, mindspore::NN::SoftmaxCrossEntropy(softmax_ce_cfg)); + mindspore::AdamConfig AdamCfg; + AdamCfg.beta1_ = kBetta1; + AdamCfg.beta2_ = kBetta2; + AdamCfg.eps_ = 1e-8; + AdamCfg.learning_rate_ = 1e-2; + auto optimizer = mindspore::NN::Adam(net->trainable_params(), AdamCfg); + status = model_->Build(mindspore::GraphCell(*netWithLoss), optimizer, {input, label}, context, cfg); + if (status != mindspore::kSuccess) { + std::cout << "Error " << status << " during build of model " << ms_file_ << std::endl; + return status; + } + delete graph_; + graph_ = nullptr; + auto inputs = model_->GetInputs(); + if (inputs.size() < 1) { + return mindspore::kLiteError; + } + auto nhwc_input_dims = inputs.at(0).Shape(); + batch_size_ = nhwc_input_dims.at(0); + h_ = nhwc_input_dims.at(1); + w_ = nhwc_input_dims.at(kNCHWCDim); + return mindspore::kSuccess; +} + +int NetRunner::CompareOutput(const std::vector &outputs) { + std::cout << "================ Comparing Forward Output data 
================" << std::endl; + float total_bias = 0; + int total_size = 0; + bool has_error = false; + int i = 1; + for (auto &tensor : outputs) { + std::cout << "output is tensor " << tensor.Name() << "\n"; + auto output = tensor.Data(); + size_t size; + std::string output_file = flags_->data_file_ + std::to_string(i) + ".bin"; + auto bin_buf = std::unique_ptr(mindspore::lite::NetTrain::ReadFileBuf(output_file.c_str(), &size)); + if (bin_buf == nullptr) { + MS_LOG(ERROR) << "ReadFile return nullptr"; + std::cout << "ReadFile return nullptr" << std::endl; + return mindspore::kLiteNullptr; + } + if (size != tensor.DataSize()) { + MS_LOG(ERROR) << "Output buffer and output file differ by size. Tensor size: " << tensor.DataSize() + << ", read size: " << size; + std::cout << "Output buffer and output file differ by size. Tensor size: " << tensor.DataSize() + << ", read size: " << size << std::endl; + return mindspore::kLiteError; + } + float bias = mindspore::lite::NetTrain::CompareData(bin_buf.get(), tensor.ElementNum(), + reinterpret_cast(output.get())); + if (bias >= 0) { + total_bias += bias; + total_size++; + } else { + has_error = true; + break; + } + i++; + } + + if (!has_error) { + float mean_bias; + if (total_size != 0) { + mean_bias = total_bias / total_size * kPrintTimes; + } else { + mean_bias = 0; + } + + std::cout << "Mean bias of all nodes/tensors: " << mean_bias << "%" + << " threshold is:" << this->flags_->accuracy_threshold_ << std::endl; + std::cout << "=======================================================" << std::endl << std::endl; + + if (mean_bias > this->flags_->accuracy_threshold_) { + MS_LOG(INFO) << "Mean bias of all nodes/tensors is too big: " << mean_bias << "%"; + std::cout << "Mean bias of all nodes/tensors is too big: " << mean_bias << "%" << std::endl; + return mindspore::kLiteError; + } else { + return mindspore::kSuccess; + } + } else { + MS_LOG(ERROR) << "Error in CompareData"; + std::cout << "Error in CompareData" << std::endl; 
+ std::cout << "=======================================================" << std::endl << std::endl; + return mindspore::kSuccess; + } +} + +void NetRunner::CheckSum(const mindspore::MSTensor &tensor, std::string node_type, int id, std::string in_out) { + constexpr int kPrintLen = 4; + int tensor_size = tensor.ElementNum(); + const void *data = tensor.Data().get(); + const float *fdata = reinterpret_cast(data); + mindspore::DataType type = tensor.DataType(); + std::cout << node_type << " " << in_out << id << std::endl; + std::cout << "tensor name: " << tensor.Name() << std::endl; + if ((tensor_size) == 0 || (data == nullptr)) { + std::cout << "Empty tensor" << std::endl; + return; + } + switch (type) { + case mindspore::DataType::kNumberTypeFloat32: + std::cout << "sum=" << mindspore::lite::TensorSum(data, tensor_size) << std::endl; + std::cout << "data: "; + for (int i = 0; i <= kPrintLen && i < tensor_size; i++) { + std::cout << static_cast(fdata[i]) << ", "; + } + std::cout << std::endl; + break; + case mindspore::DataType::kNumberTypeInt32: + std::cout << "sum=" << mindspore::lite::TensorSum(data, tensor_size) << std::endl; + break; + default: + std::cout << "unsupported type:" << static_cast(type) << std::endl; + break; + } +} + +int NetRunner::InitCallbackParameter() { + // after callback + after_call_back_ = [&](const std::vector &after_inputs, + const std::vector &after_outputs, + const mindspore::MSCallBackParam &call_param) { + if (after_inputs.empty()) { + MS_LOG(INFO) << "The num of after inputs is empty"; + } + if (after_outputs.empty()) { + MS_LOG(INFO) << "The num of after outputs is empty"; + } + if (flags_->layer_checksum_) { + for (size_t i = 0; i < after_inputs.size(); i++) { + CheckSum(after_inputs.at(i), call_param.node_type, i, "in"); + } + for (size_t i = 0; i < after_outputs.size(); i++) { + CheckSum(after_outputs.at(i), call_param.node_type, i, "out"); + } + std::cout << std::endl; + } + return true; + }; + return false; +} + +int 
NetRunner::RunOnce() { + auto inputs = model_->GetInputs(); + std::vector output; + auto status = LoadInput(&inputs); + if (status != mindspore::kSuccess) { + std::cout << "cannot load data"; + return status; + } + model_->SetTrainMode(true); + model_->RunStep(nullptr, nullptr); + model_->SetTrainMode(false); + model_->Predict(inputs, &output, nullptr, nullptr); + return CompareOutput(output); +} + +int NetRunner::LoadInput(std::vector *ms_inputs) { + auto status = ReadInputFile(ms_inputs); + if (status != mindspore::kSuccess) { + std::cout << "Read Input File error, " << status << std::endl; + MS_LOG(ERROR) << "Read Input File error, " << status; + return status; + } + return mindspore::kSuccess; +} + +int NetRunner::ReadInputFile(std::vector *ms_inputs) { + if (ms_inputs->empty()) { + std::cout << "no inputs to input" << std::endl; + return mindspore::kLiteError; + } + for (size_t i = 0; i < ms_inputs->size(); i++) { + auto cur_tensor = ms_inputs->at(i); + if (cur_tensor == nullptr) { + std::cout << "empty tensor " << i << std::endl; + MS_LOG(ERROR) << "empty tensor " << i; + } + size_t size; + std::string file_name = flags_->in_data_file_ + std::to_string(i + 1) + ".bin"; + auto bin_buf = std::unique_ptr(mindspore::lite::NetTrain::ReadFileBuf(file_name.c_str(), &size)); + if (bin_buf == nullptr) { + MS_LOG(ERROR) << "ReadFile return nullptr"; + std::cout << "ReadFile return nullptr" << std::endl; + return mindspore::kLiteNullptr; + } + auto tensor_data_size = cur_tensor.DataSize(); + if (size != tensor_data_size) { + std::cout << "Input binary file size error, required: " << tensor_data_size << ", in fact: " << size + << " ,file_name: " << file_name.c_str() << std::endl; + MS_LOG(ERROR) << "Input binary file size error, required: " << tensor_data_size << ", in fact: " << size + << " ,file_name: " << file_name.c_str(); + return mindspore::kLiteError; + } + auto input_data = cur_tensor.MutableData(); + memcpy(input_data, bin_buf.get(), tensor_data_size); + } + 
return mindspore::kSuccess; +} + +int NetRunner::Main() { + ms_file_ = flags_->model_file_; + InitCallbackParameter(); + auto status = InitAndFigureInputs(); + if (status != mindspore::kSuccess) { + std::cout << "failed to initialize network" << std::endl; + return status.StatusCode(); + } + return RunOnce(); +} + +int CallBack(mindspore::lite::NetTrainFlags *flags) { + NetRunner nr(flags); + return nr.Main(); +} + +int init = mindspore::lite::NetTrain::SetNr(CallBack); diff --git a/mindspore/lite/tools/benchmark_train/net_runner.h b/mindspore/lite/tools/benchmark_train/net_runner.h new file mode 100644 index 00000000000..243b94ef8db --- /dev/null +++ b/mindspore/lite/tools/benchmark_train/net_runner.h @@ -0,0 +1,81 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_LITE_TOOLS_BENCHMARK_TRAIN_NET_RUNNER_H_ +#define MINDSPORE_LITE_TOOLS_BENCHMARK_TRAIN_NET_RUNNER_H_ + +#include +#include +#include +#include +#include +#include +#include "include/api/model.h" +#include "include/api/graph.h" +#include "include/api/types.h" +#include "include/api/status.h" +#include "include/api/metrics/accuracy.h" +#include "include/dataset/datasets.h" + +using mindspore::AccuracyMetrics; +using mindspore::dataset::Dataset; + +namespace mindspore::lite { +class NetTrainFlags; +} + +class NetRunner { + public: + int Main(); + explicit NetRunner(mindspore::lite::NetTrainFlags *flags) : flags_(flags) {} + bool ReadArgs(int argc, int8_t *argv[]); + virtual ~NetRunner(); + + private: + void Usage(); + mindspore::Status InitAndFigureInputs(); + void CheckSum(const mindspore::MSTensor &tensor, std::string node_type, int id, std::string in_out); + int InitCallbackParameter(); + int TrainLoop(); + float CalculateAccuracy(int max_tests = 0); + float GetLoss() const; + int RunOnce(); + int CompareOutput(const std::vector &outputs); + int LoadInput(std::vector *ms_inputs); + int ReadInputFile(std::vector *ms_inputs); + + mindspore::Model *model_ = nullptr; + mindspore::Graph *graph_ = nullptr; + + std::shared_ptr train_ds_; + std::shared_ptr test_ds_; + std::shared_ptr acc_metrics_; + + std::string ms_file_ = ""; + std::string data_dir_ = ""; + unsigned int epochs_ = 10; + bool verbose_ = false; + bool enable_fp16_ = false; + int virtual_batch_ = -1; + int save_checkpoint_ = 0; + int batch_size_ = 32; + int h_ = 32; + int w_ = 32; + mindspore::lite::NetTrainFlags *flags_{nullptr}; + mindspore::MSKernelCallBack after_call_back_; +}; + +#endif // MINDSPORE_LITE_TOOLS_BENCHMARK_TRAIN_NET_RUNNER_H_ diff --git a/mindspore/lite/tools/benchmark_train/net_train.cc b/mindspore/lite/tools/benchmark_train/net_train.cc index fa2a9c9e858..f974d6107a6 100644 --- a/mindspore/lite/tools/benchmark_train/net_train.cc +++ 
b/mindspore/lite/tools/benchmark_train/net_train.cc @@ -24,6 +24,7 @@ #ifdef ENABLE_NEON #include #endif +#include "tools/benchmark_train/net_runner.h" #include "src/common/common.h" #include "include/ms_tensor.h" #include "include/context.h" @@ -49,14 +50,20 @@ constexpr int kCPUBindFlag2 = 2; constexpr int kCPUBindFlag1 = 1; static const int kTHOUSAND = 1000; -namespace { -float *ReadFileBuf(const char *file, size_t *size) { - if (file == nullptr) { +std::function NetTrain::nr_cb_ = nullptr; + +int NetTrain::SetNr(std::function param) { + nr_cb_ = param; + return 0; +} + +float *NetTrain::ReadFileBuf(const std::string file, size_t *size) { + if (file.empty()) { MS_LOG(ERROR) << "file is nullptr"; return nullptr; } MS_ASSERT(size != nullptr); - std::string real_path = RealPath(file); + std::string real_path = RealPath(file.c_str()); std::ifstream ifs(real_path); if (!ifs.good()) { MS_LOG(ERROR) << "file: " << real_path << " is not exist"; @@ -83,7 +90,6 @@ float *ReadFileBuf(const char *file, size_t *size) { return buf.release(); } -} // namespace int NetTrain::GenerateRandomData(mindspore::tensor::MSTensor *tensor) { auto input_data = tensor->MutableData(); @@ -832,7 +838,9 @@ int RunNetTrain(int argc, const char **argv) { std::cerr << flags.Usage() << std::endl; return RET_OK; } - + if (flags.unified_api_) { + return NetTrain::RunNr(&flags); + } NetTrain net_trainer(&flags); auto status = net_trainer.Init(); if (status != RET_OK) { diff --git a/mindspore/lite/tools/benchmark_train/net_train.h b/mindspore/lite/tools/benchmark_train/net_train.h index 8eb49bdceff..02d05dee8c0 100644 --- a/mindspore/lite/tools/benchmark_train/net_train.h +++ b/mindspore/lite/tools/benchmark_train/net_train.h @@ -53,8 +53,8 @@ constexpr float relativeTolerance = 1e-5; constexpr float absoluteTolerance = 1e-8; template -float TensorSum(void *data, int size) { - T *typed_data = reinterpret_cast(data); +float TensorSum(const void *data, int size) { + const T *typed_data = 
reinterpret_cast(data); float sum = 0.f; for (int i = 0; i < size; i++) { sum += static_cast(typed_data[i]); @@ -85,6 +85,7 @@ class MS_API NetTrainFlags : public virtual FlagParser { AddFlag(&NetTrainFlags::virtual_batch_, "virtualBatch", "use virtual batch", false); AddFlag(&NetTrainFlags::resize_dims_in_, "inputShapes", "Shape of input data, the format should be NHWC. e.g. 1,32,32,32:1,1,32,32,1", ""); + AddFlag(&NetTrainFlags::unified_api_, "unifiedApi", "do unified api test", false); } ~NetTrainFlags() override = default; @@ -117,6 +118,7 @@ class MS_API NetTrainFlags : public virtual FlagParser { std::vector> resize_dims_; std::string loss_name_ = ""; std::string inference_file_ = ""; + bool unified_api_ = false; }; class MS_API NetTrain { @@ -126,49 +128,19 @@ class MS_API NetTrain { int Init(); int RunNetTrain(); - - private: - // call GenerateInputData or ReadInputFile to init inputTensors - int LoadInput(Vector *ms_inputs); - void CheckSum(mindspore::tensor::MSTensor *tensor, std::string node_type, int id, std::string in_out); - // call GenerateRandomData to fill inputTensors - int GenerateInputData(std::vector *ms_inputs); - - int GenerateRandomData(mindspore::tensor::MSTensor *tensor); - - int ReadInputFile(std::vector *ms_inputs); - int CreateAndRunNetwork(const std::string &filename, const std::string &bb_filename, int train_session, int epochs, - bool check_accuracy = true); - - std::unique_ptr CreateAndRunNetworkForInference(const std::string &filename, - const Context &context); - - std::unique_ptr CreateAndRunNetworkForTrain(const std::string &filename, - const std::string &bb_filename, - const Context &context, const TrainCfg &train_cfg, - int epochs); - - int InitCallbackParameter(); - - int PrintResult(const std::vector &title, const std::map> &result); - - template - void PrintInputData(tensor::MSTensor *input) { - MS_ASSERT(input != nullptr); - static int i = 0; - auto inData = reinterpret_cast(input->MutableData()); - size_t tensorSize = 
input->ElementsNum(); - size_t len = (tensorSize < 20) ? tensorSize : 20; - std::cout << "InData" << i++ << ": "; - for (size_t j = 0; j < len; j++) { - std::cout << inData[j] << " "; + static float *ReadFileBuf(const std::string file, size_t *size); + static int SetNr(std::function param); + static int RunNr(NetTrainFlags *flags) { + if (nr_cb_ != nullptr) { + return nr_cb_(flags); } - std::cout << std::endl; + MS_LOG(WARNING) << "unified api was not tested"; + std::cout << "unified api was not tested"; + return RET_OK; } - // tensorData need to be converter first template - float CompareData(const float *refOutput, int size, T *msTensorData) { + static float CompareData(const float *refOutput, int size, const T *msTensorData) { size_t errorCount = 0; float meanError = 0; std::cout << "Out tensor size is: " << size << std::endl; @@ -219,6 +191,45 @@ class MS_API NetTrain { return meanError; } + private: + // call GenerateInputData or ReadInputFile to init inputTensors + int LoadInput(Vector *ms_inputs); + void CheckSum(mindspore::tensor::MSTensor *tensor, std::string node_type, int id, std::string in_out); + // call GenerateRandomData to fill inputTensors + int GenerateInputData(std::vector *ms_inputs); + + int GenerateRandomData(mindspore::tensor::MSTensor *tensor); + + int ReadInputFile(std::vector *ms_inputs); + int CreateAndRunNetwork(const std::string &filename, const std::string &bb_filename, int train_session, int epochs, + bool check_accuracy = true); + + std::unique_ptr CreateAndRunNetworkForInference(const std::string &filename, + const Context &context); + + std::unique_ptr CreateAndRunNetworkForTrain(const std::string &filename, + const std::string &bb_filename, + const Context &context, const TrainCfg &train_cfg, + int epochs); + + int InitCallbackParameter(); + + int PrintResult(const std::vector &title, const std::map> &result); + + template + void PrintInputData(tensor::MSTensor *input) { + MS_ASSERT(input != nullptr); + static int i = 0; + auto 
inData = reinterpret_cast(input->MutableData()); + size_t tensorSize = input->ElementsNum(); + size_t len = (tensorSize < 20) ? tensorSize : 20; + std::cout << "InData" << i++ << ": "; + for (size_t j = 0; j < len; j++) { + std::cout << inData[j] << " "; + } + std::cout << std::endl; + } + int MarkPerformance(const std::unique_ptr &session); int MarkAccuracy(const std::unique_ptr &session, bool enforce_accuracy = true); int CompareOutput(const session::LiteSession &lite_session); @@ -242,8 +253,8 @@ class MS_API NetTrain { } } #endif - NetTrainFlags *flags_; - + NetTrainFlags *flags_{nullptr}; + static std::function nr_cb_; // callback parameters uint64_t op_begin_ = 0; int op_call_times_total_ = 0; diff --git a/mindspore/lite/tools/converter/converter_flags.cc b/mindspore/lite/tools/converter/converter_flags.cc index f182e4ed966..a5158e0c341 100644 --- a/mindspore/lite/tools/converter/converter_flags.cc +++ b/mindspore/lite/tools/converter/converter_flags.cc @@ -101,6 +101,7 @@ Flags::Flags() { "Whether to export MindIR pb. " "true | false", "false"); + AddFlag(&Flags::noFusionStr, "NoFusion", "Avoid fusion optimization true|false", "false"); } int Flags::InitInputOutputDataType() { @@ -359,6 +360,18 @@ int Flags::InitPreInference() { return RET_OK; } +int Flags::InitNoFusion() { + if (this->noFusionStr == "true") { + this->disableFusion = true; + } else if (this->noFusionStr == "false") { + this->disableFusion = false; + } else { + std::cerr << "INPUT ILLEGAL: NoFusion must be true|false " << std::endl; + return RET_INPUT_PARAM_INVALID; + } + return RET_OK; +} + int Flags::InitExportMindIR() { if (this->exportMindIR == "true") { this->export_mindir = true; @@ -495,12 +508,16 @@ int Flags::Init(int argc, const char **argv) { std::cerr << "Init encrypt failed." << std::endl; return RET_INPUT_PARAM_INVALID; } - ret = InitPreInference(); if (ret != RET_OK) { std::cerr << "Init pre inference failed." 
<< std::endl; return RET_INPUT_PARAM_INVALID; } + ret = InitNoFusion(); + if (ret != RET_OK) { + std::cerr << "Init no fusion failed." << std::endl; + return RET_INPUT_PARAM_INVALID; + } ret = InitExportMindIR(); if (ret != RET_OK) { @@ -509,6 +526,7 @@ int Flags::Init(int argc, const char **argv) { } return RET_OK; } + Flags::~Flags() { dec_key.clear(); encKeyStr.clear(); @@ -591,7 +609,7 @@ std::string GetStrFromConfigFile(const std::string &file, const std::string &tar } #ifdef _WIN32 - char *real_path = _fullpath(resolved_path.get(), file.c_str(), kPathLengthUpperLimit); + auto *real_path = _fullpath(resolved_path.get(), file.c_str(), kPathLengthUpperLimit); #else char *real_path = realpath(file.c_str(), resolved_path.get()); #endif diff --git a/mindspore/lite/tools/converter/converter_flags.h b/mindspore/lite/tools/converter/converter_flags.h index 9da5bc3a62b..8651ff6ad22 100644 --- a/mindspore/lite/tools/converter/converter_flags.h +++ b/mindspore/lite/tools/converter/converter_flags.h @@ -14,8 +14,8 @@ * limitations under the License. 
*/ -#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_CONVERTER_FLAGS_H -#define MINDSPORE_LITE_TOOLS_CONVERTER_CONVERTER_FLAGS_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_CONVERTER_FLAGS_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_CONVERTER_FLAGS_H_ #include #include @@ -73,6 +73,8 @@ class Flags : public virtual mindspore::lite::FlagParser { int InitSaveFP16(); + int InitNoFusion(); + void InitAclDefaultOption(); int InitExportMindIR(); @@ -90,6 +92,7 @@ class Flags : public virtual mindspore::lite::FlagParser { TypeId outputDataType; std::string saveFP16Str = "off"; bool saveFP16 = false; + std::string noFusionStr = "false"; std::string inputDataTypeStr; std::string outputDataTypeStr; ParallelSplitConfig parallel_split_config_{}; @@ -134,4 +137,4 @@ std::string GetStrFromConfigFile(const std::string &file, const std::string &tar } // namespace converter } // namespace mindspore -#endif +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_CONVERTER_FLAGS_H_