diff --git a/build.sh b/build.sh index 907e1c0e594..25ba93071c5 100755 --- a/build.sh +++ b/build.sh @@ -109,7 +109,7 @@ checkopts() ENABLE_GPU="off" # Process the options - while getopts 'drvj:c:t:hsb:a:g:p:ie:m:l:I:LRP:Q:D:zM:V:K:swB:En' opt + while getopts 'drvj:c:t:hsb:a:g:p:ie:m:l:I:LRP:Q:D:zM:V:K:swB:EnT:' opt do OPTARG=$(echo ${OPTARG} | tr '[A-Z]' '[a-z]') case "${opt}" in @@ -282,6 +282,11 @@ checkopts() ENABLE_IBVERBS="on" echo "enable IBVERBS for parameter server" ;; + T) + check_on_off $OPTARG T + SUPPORT_TRAIN=$OPTARG + echo "support train on device " + ;; *) echo "Unknown option ${opt}!" usage diff --git a/mindspore/lite/src/CMakeLists.txt b/mindspore/lite/src/CMakeLists.txt index 8b96e731ab0..c9d1ef18587 100644 --- a/mindspore/lite/src/CMakeLists.txt +++ b/mindspore/lite/src/CMakeLists.txt @@ -23,6 +23,7 @@ endif() if (SUPPORT_TRAIN) set(ANF_SRC + ${ANF_SRC} # ${CCSRC_DIR}/common/trans.cc # ${CCSRC_DIR}/utils/lite/base_ref_utils.cc # ${CCSRC_DIR}/runtime/kernel/kernel_compiler/kernel_build_info.cc @@ -40,14 +41,17 @@ if (SUPPORT_TRAIN) set(LITE_SRC ${LITE_SRC} ${ANF_SRC} - ${PASS_SRC} - ${CMAKE_CURRENT_SOURCE_DIR}/common/anf_importer/anf_importer.cc - ${CMAKE_CURRENT_SOURCE_DIR}/common/anf_importer/import_from_meta_graph.cc - ${CMAKE_CURRENT_SOURCE_DIR}/ir/primitive_value.cc - ${CMAKE_CURRENT_SOURCE_DIR}/train/lite_kernel_runtime.cc - ${CMAKE_CURRENT_SOURCE_DIR}/train/train_session.cc - ${CMAKE_CURRENT_SOURCE_DIR}/train/model_impl.cc + # ${PASS_SRC} + # ${CMAKE_CURRENT_SOURCE_DIR}/common/anf_importer/anf_importer.cc + # ${CMAKE_CURRENT_SOURCE_DIR}/common/anf_importer/import_from_meta_graph.cc + # ${CMAKE_CURRENT_SOURCE_DIR}/ir/primitive_value.cc + # ${CMAKE_CURRENT_SOURCE_DIR}/train/lite_kernel_runtime.cc + # ${CMAKE_CURRENT_SOURCE_DIR}/train/train_session.cc + # ${CMAKE_CURRENT_SOURCE_DIR}/train/model_impl.cc + ${CMAKE_CURRENT_SOURCE_DIR}/lite_session.cc # temporary + ${CMAKE_CURRENT_SOURCE_DIR}/model_impl.cc # temporary ) + else () set(LITE_SRC 
${LITE_SRC} diff --git a/mindspore/lite/src/common/anf_importer/import_from_meta_graph.cc b/mindspore/lite/src/common/anf_importer/import_from_meta_graph.cc index 6ec0ba8c545..b0d96fa3ee3 100644 --- a/mindspore/lite/src/common/anf_importer/import_from_meta_graph.cc +++ b/mindspore/lite/src/common/anf_importer/import_from_meta_graph.cc @@ -27,96 +27,143 @@ namespace mindspore::lite { void AnfImporterFromMetaGraph::ConverterConstTensor() { - MS_EXCEPTION_IF_NULL(model); - auto *meta_graph = model->GetMetaGraph(); + MS_EXCEPTION_IF_NULL(model_); + auto *meta_graph = model_->GetMetaGraph(); MS_EXCEPTION_IF_NULL(meta_graph); - for (size_t i = 0; i < meta_graph->allTensors()->size(); i++) { + num_of_tensors_ = meta_graph->allTensors()->size(); + for (size_t i = 0; i < num_of_tensors_; i++) { auto *tensor = meta_graph->allTensors()->GetAs(i); MS_EXCEPTION_IF_NULL(tensor); - if (tensor->nodeType() != schema::NodeType_ValueNode) { + if ((tensor->nodeType() != schema::NodeType_ValueNode) && (tensor->nodeType() != schema::NodeType_Parameter)) { continue; } MS_ASSERT(tensor->dims() != nullptr); - auto parameter = model->add_parameter(); + auto parameter = model_->add_parameter(); std::vector shape; for (size_t j = 0; j < tensor->dims()->size(); ++j) { shape.push_back(tensor->dims()->data()[j]); } - auto type_id = static_cast(tensor->dataType()); + auto type_id = static_cast(tensor->dataType()); // todo: check error auto type_ptr = TypeIdToType(type_id); - auto abstract_tensor = std::make_shared(type_ptr, shape); - parameter->set_abstract(abstract_tensor); + auto abstractBase = std::make_shared(type_ptr, shape); + // XXX TODO copy format + parameter->set_abstract(abstractBase); + parameter->set_name(std::string("Parameter")); - ParamValueLitePtr param_value = std::make_shared(); - MS_EXCEPTION_IF_NULL(param_value); - param_value->set_tensor_shape(shape); - param_value->set_tensor_type(type_id); - if (tensor->data() != nullptr) { - auto size = tensor->data()->size(); - char 
*tensor_data = new char[size](); - std::memcpy(tensor_data, tensor->data()->data(), size); - MS_EXCEPTION_IF_NULL(tensor_data); - param_value->set_tensor_addr(tensor_data); - param_value->set_tensor_size(size); + if (tensor->nodeType() == schema::NodeType_ValueNode) { + ParamValueLitePtr param_value = std::make_shared(); + MS_EXCEPTION_IF_NULL(param_value); + param_value->set_tensor_shape(shape); + param_value->set_tensor_type(type_id); + if (tensor->data() != nullptr) { + auto size = tensor->data()->size(); + char *tensor_data = new char[size](); + std::memcpy(tensor_data, tensor->data()->data(), size); + MS_EXCEPTION_IF_NULL(tensor_data); + param_value->set_tensor_addr(tensor_data); + param_value->set_tensor_size(size); + } + parameter->set_default_param(param_value); } - parameter->set_default_param(param_value); AddNode(i, parameter); + model_->AddAnfNode(i, parameter); } } int AnfImporterFromMetaGraph::ConverterCNode() { - MS_EXCEPTION_IF_NULL(model); - auto *meta_graph = model->GetMetaGraph(); + MS_EXCEPTION_IF_NULL(model_); + auto *meta_graph = model_->GetMetaGraph(); MS_EXCEPTION_IF_NULL(meta_graph); - auto cNodes = meta_graph->nodes(); - for (size_t i = 0; i < cNodes->size(); i++) { - auto cNode = cNodes->GetAs(i); - MS_EXCEPTION_IF_NULL(cNode); - auto tensor_id = cNode->outputIndex()->data()[0]; - if (GetNode(tensor_id)) { - continue; - } - auto prim = std::make_shared(model->GetOp(cNode->name()->str())); + // Crate CNode -- Order of inputs is as follows + // First input should be the Primitive + // Then we have CNodes that contribute to this CNode + // Finally we Have the parameters + + // first itteration -- create CNode with primitive, create originator map + for (size_t i = 0; i < meta_graph->nodes()->size(); i++) { + auto cNode = meta_graph->nodes()->GetAs(i); + MS_EXCEPTION_IF_NULL(cNode); + auto prim = std::make_shared(model_->GetOp(cNode->name()->str())); if (prim == nullptr) { MS_LOG(ERROR) << "th tensorDef in subGraphDef is nullptr"; return 
RET_ERROR; } auto value_node = NewValueNode(prim); - AddNode(tensor_id, value_node); - + // auto prim_name = std::string("PrimitivePy: ") + std::string(cNode->name()->c_str()); + // value_node->set_fullname_with_scope(prim_name); std::vector op_inputs = {value_node}; + + auto cnode = model_->NewCNode(op_inputs); + auto node_name = std::string(cNode->name()->c_str()) + std::to_string(i); + cnode->set_fullname_with_scope(node_name); + AddNode(num_of_tensors_ + i, cnode); + + for (size_t j = 0; j < cNode->outputIndex()->size(); j++) { + int tensor_id = cNode->outputIndex()->data()[j]; + originator_[tensor_id] = cnode; + } + } + // second itteration -- fill in input CNodes and Parameters + // populate map + for (size_t i = 0; i < meta_graph->nodes()->size(); i++) { + std::vector input; + std::vector output; + int tensor_id; + auto cNode = meta_graph->nodes()->GetAs(i); + MS_EXCEPTION_IF_NULL(cNode); + auto cnode = std::dynamic_pointer_cast(GetNode(num_of_tensors_ + i)); + + for (size_t j = 0; j < cNode->outputIndex()->size(); j++) { + tensor_id = cNode->outputIndex()->data()[j]; + output.push_back(tensor_id); + } + MS_EXCEPTION_IF_NULL(cNode->inputIndex()); for (size_t j = 0; j < cNode->inputIndex()->size(); j++) { - auto node = GetNode(*(cNode->inputIndex()->GetAs(j))); - if (nullptr == node) { - MS_LOG(ERROR) << "Can't find input node."; - return RET_ERROR; + tensor_id = cNode->inputIndex()->data()[j]; + input.push_back(tensor_id); + auto *tensor = meta_graph->allTensors()->GetAs(tensor_id); + MS_EXCEPTION_IF_NULL(tensor); + if ((tensor->nodeType() == schema::NodeType_Parameter) && (originator_[tensor_id] != nullptr)) { + cnode->add_input(originator_[tensor_id]); } - // todo: CheckInputNodeType, the first node should be op; - op_inputs.push_back(node); } - auto cnode = model->NewCNode(op_inputs); - auto node_name = std::string(cNode->name()->c_str()); - cnode->set_fullname_with_scope(node_name); - AddNode(tensor_id, cnode); + // finally add all the Parameters (which 
are ValueNodes) + for (size_t j = 0; j < cNode->inputIndex()->size(); j++) { + tensor_id = cNode->inputIndex()->data()[j]; + auto *tensor = meta_graph->allTensors()->GetAs(tensor_id); + MS_EXCEPTION_IF_NULL(tensor); + if ((tensor->nodeType() == schema::NodeType_ValueNode) && (GetNode(tensor_id) != nullptr)) { + cnode->add_input(GetNode(tensor_id)); + } + } + + model_->AddCNodeInputOutput(cnode->fullname_with_scope(), input, output); } + return RET_OK; } void AnfImporterFromMetaGraph::AddReturnCNode() { - MS_EXCEPTION_IF_NULL(model); - auto *meta_graph = model->GetMetaGraph(); + MS_EXCEPTION_IF_NULL(model_); + auto *meta_graph = model_->GetMetaGraph(); MS_EXCEPTION_IF_NULL(meta_graph); + std::vector input; + std::vector output; std::vector op_inputs; auto value_node = NewValueNode(prim::kPrimReturn); + // value_node->set_fullname_with_scope("Primitive"); op_inputs.push_back(value_node); - auto tensor_id = meta_graph->outputIndex()->data()[0]; - op_inputs.push_back(GetNode(tensor_id)); - auto cnode = model->NewCNode(op_inputs); + for (int i = 0; i < meta_graph->outputIndex()->size(); i++) { + auto prev_cnode = originator_[meta_graph->outputIndex()->data()[i]]; + if (prev_cnode != nullptr) op_inputs.push_back(prev_cnode); + input.push_back(meta_graph->outputIndex()->data()[i]); + } + auto cnode = model_->NewCNode(op_inputs); cnode->set_fullname_with_scope("return"); - model->set_return(cnode); + model_->set_return(cnode); + model_->AddCNodeInputOutput(cnode->fullname_with_scope(), input, output); } -FuncGraphPtr AnfImporterFromMetaGraph::GetResult() { return this->model; } +FuncGraphPtr AnfImporterFromMetaGraph::GetResult() { return this->model_; } } // namespace mindspore::lite - diff --git a/mindspore/lite/src/common/anf_importer/import_from_meta_graph.h b/mindspore/lite/src/common/anf_importer/import_from_meta_graph.h index fd34930f1cd..b8389f42d10 100644 --- a/mindspore/lite/src/common/anf_importer/import_from_meta_graph.h +++ 
b/mindspore/lite/src/common/anf_importer/import_from_meta_graph.h @@ -18,6 +18,7 @@ #define MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_META_GRAPH_H_ #include +#include #include "src/train/model_impl.h" #include "schema/model_generated.h" #include "src/common/anf_importer/anf_importer.h" @@ -25,7 +26,7 @@ namespace mindspore::lite { class AnfImporterFromMetaGraph : public AnfImporter { public: - explicit AnfImporterFromMetaGraph(std::shared_ptr model) : model(model) {} + explicit AnfImporterFromMetaGraph(std::shared_ptr model) : model_(model) {} ~AnfImporterFromMetaGraph() override = default; @@ -39,9 +40,10 @@ class AnfImporterFromMetaGraph : public AnfImporter { void AddReturnCNode() override; private: - std::shared_ptr model = nullptr; + std::shared_ptr model_ = nullptr; + std::map originator_; + int num_of_tensors_ = 0; }; } // namespace mindspore::lite #endif // MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_META_GRAPH_H_ - diff --git a/mindspore/lite/src/lite_kernel.h b/mindspore/lite/src/lite_kernel.h index 040068bf6ca..572fa0e9700 100644 --- a/mindspore/lite/src/lite_kernel.h +++ b/mindspore/lite/src/lite_kernel.h @@ -60,7 +60,7 @@ class LiteKernel { explicit LiteKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs, const lite::Context *ctx, const lite::Primitive *primitive) - : opParameter(parameter), inputs_(inputs), outputs_(outputs), train_mode(false), primitive_(primitive), + : opParameter(parameter), inputs_(inputs), outputs_(outputs), primitive_(primitive), context_(ctx) { this->in_kernel_.clear(); this->out_kernel_.clear(); @@ -136,7 +136,7 @@ class LiteKernel { std::vector outputs_; std::vector in_kernel_; std::vector out_kernel_; - bool train_mode; + bool train_mode = false; bool need_reinit = false; }; diff --git a/mindspore/lite/src/model.cc b/mindspore/lite/src/model.cc index 0fecef68e44..a1176216ce4 100644 --- a/mindspore/lite/src/model.cc +++ b/mindspore/lite/src/model.cc @@ -14,11 +14,11 @@ * limitations 
under the License. */ -#ifdef SUPPORT_TRAIN -#include "src/train/model_impl.h" -#else +// #ifdef SUPPORT_TRAIN +// #include "src/train/model_impl.h" +// #else #include "src/model_impl.h" -#endif +// #endif #include "include/model.h" #include "utils/log_adapter.h" diff --git a/mindspore/lite/src/runtime/kernel/arm/CMakeLists.txt b/mindspore/lite/src/runtime/kernel/arm/CMakeLists.txt index 0fec5a929c0..ae23f6a736e 100644 --- a/mindspore/lite/src/runtime/kernel/arm/CMakeLists.txt +++ b/mindspore/lite/src/runtime/kernel/arm/CMakeLists.txt @@ -10,6 +10,13 @@ file(GLOB KERNEL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/int8/*.cc ) +if (SUPPORT_TRAIN) +file (GLOB TRAIN_KERNEL_SRC + ${CMAKE_CURRENT_SOURCE_DIR}/fp32_grad/*.cc + ${CMAKE_CURRENT_SOURCE_DIR}/nnacl/fp32_grad/*.cc + ) +endif() + if (PLATFORM_ARM64) # assembly file(GLOB ASSEMBLY_SRC nnacl/assembly/arm64/*.s @@ -27,5 +34,5 @@ if (PLATFORM_ARM32) set(KERNEL_SRC ${KERNEL_SRC} ${ASSEMBLY_SRC}) endif() -add_library(cpu_kernel_mid_ OBJECT ${KERNEL_SRC}) +add_library(cpu_kernel_mid_ OBJECT ${KERNEL_SRC} ${TRAIN_KERNEL_SRC}) add_subdirectory(nnacl) diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/activation_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/activation_grad.cc similarity index 97% rename from mindspore/lite/src/runtime/kernel/arm/fp32/activation_grad.cc rename to mindspore/lite/src/runtime/kernel/arm/fp32_grad/activation_grad.cc index df1b0f03925..7ed31910cdc 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/activation_grad.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/activation_grad.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -#include "src/runtime/kernel/arm/fp32/activation_grad.h" +#include "src/runtime/kernel/arm/fp32_grad/activation_grad.h" #include "schema/model_generated.h" #include "src/kernel_registry.h" #include "src/runtime/runtime_api.h" @@ -102,6 +102,8 @@ kernel::LiteKernel *CpuActivationGradFp32KernelCreator(const std::vectorname_ << ", type: " << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; } return kernel; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/activation_grad.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/activation_grad.h similarity index 87% rename from mindspore/lite/src/runtime/kernel/arm/fp32/activation_grad.h rename to mindspore/lite/src/runtime/kernel/arm/fp32_grad/activation_grad.h index c57713b9a02..56ddf0f5fc8 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/activation_grad.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/activation_grad.h @@ -14,8 +14,8 @@ * limitations under the License. 
*/ -#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ACTIVATION_GRAD_H_ -#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ACTIVATION_GRAD_H_ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ACTIVATION_GRAD_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ACTIVATION_GRAD_H_ #include #include "src/lite_kernel.h" @@ -48,4 +48,4 @@ class ActivationGradCPUKernel : public LiteKernel { }; } // namespace mindspore::kernel -#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ACTIVATION_GRAD_H_ +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ACTIVATION_GRAD_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/arithmetic_grad.cc similarity index 98% rename from mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_grad.cc rename to mindspore/lite/src/runtime/kernel/arm/fp32_grad/arithmetic_grad.cc index 4c3fddbcbfb..0639433ef4a 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_grad.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/arithmetic_grad.cc @@ -16,9 +16,9 @@ #include "schema/model_generated.h" #include "src/kernel_registry.h" -#include "src/runtime/kernel/arm/nnacl/fp32/reduce_grad.h" -#include "src/runtime/kernel/arm/fp32/arithmetic_grad.h" -#include "src/runtime/kernel/arm/nnacl/fp32/arithmetic_grad.h" +#include "src/runtime/kernel/arm/nnacl/fp32_grad/reduce_grad.h" +#include "src/runtime/kernel/arm/nnacl/fp32_grad/arithmetic_grad.h" +#include "src/runtime/kernel/arm/fp32_grad/arithmetic_grad.h" #include "include/errorcode.h" using mindspore::kernel::KERNEL_ARCH::kCPU; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_grad.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/arithmetic_grad.h similarity index 94% rename from mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_grad.h rename to mindspore/lite/src/runtime/kernel/arm/fp32_grad/arithmetic_grad.h index 3dcd5811ec5..0e919c7b7fb 100644 --- 
a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_grad.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/arithmetic_grad.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_GRAD_H_ -#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_GRAD_H_ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ARITHMETIC_GRAD_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ARITHMETIC_GRAD_H_ #include #include "src/lite_kernel.h" @@ -88,4 +88,4 @@ class ArithmeticGradCPUKernel : public LiteKernel { }; } // namespace mindspore::kernel -#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_GRAD_H_ +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ARITHMETIC_GRAD_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/bias_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bias_grad.cc similarity index 98% rename from mindspore/lite/src/runtime/kernel/arm/fp32/bias_grad.cc rename to mindspore/lite/src/runtime/kernel/arm/fp32_grad/bias_grad.cc index 0bc583343a9..e5e53b09747 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/bias_grad.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bias_grad.cc @@ -15,7 +15,7 @@ */ #include -#include "src/runtime/kernel/arm/fp32/bias_grad.h" +#include "src/runtime/kernel/arm/fp32_grad/bias_grad.h" #include "schema/model_generated.h" #include "src/kernel_registry.h" #include "include/errorcode.h" diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/bias_grad.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bias_grad.h similarity index 100% rename from mindspore/lite/src/runtime/kernel/arm/fp32/bias_grad.h rename to mindspore/lite/src/runtime/kernel/arm/fp32_grad/bias_grad.h diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/bngrad_input.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bn_grad.cc similarity index 95% rename from mindspore/lite/src/runtime/kernel/arm/fp32/bngrad_input.cc rename to 
mindspore/lite/src/runtime/kernel/arm/fp32_grad/bn_grad.cc index 79bd775e70a..752a8440e93 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/bngrad_input.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bn_grad.cc @@ -18,8 +18,8 @@ #include #include "schema/model_generated.h" #include "src/kernel_factory.h" -#include "src/runtime/kernel/arm/fp32/bngrad_input.h" -#include "src/runtime//kernel/arm/nnacl/batch_norm.h" +#include "src/runtime/kernel/arm/fp32_grad/bn_grad.h" +#include "src/runtime/kernel/arm/nnacl/fp32_grad/batch_norm.h" #include "include/errorcode.h" using mindspore::kernel::KERNEL_ARCH::kCPU; @@ -54,10 +54,6 @@ int BNGradInputCPUKernel::Init() { int BNGradInputCPUKernel::ReSize() { return RET_OK; } -/* -according to https://wiseodd.github.io/techblog/2016/07/04/batchnorm -*/ - int BNGradInputCPUKernel::Run() { // std::cout << "run succ" << std::endl; auto *input_x = inputs_.at(0); @@ -107,6 +103,8 @@ kernel::LiteKernel *CpuBNGradInputFp32KernelCreator(const std::vectorname_ << ", type: " << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; } return kernel; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/bngrad_input.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bn_grad.h similarity index 85% rename from mindspore/lite/src/runtime/kernel/arm/fp32/bngrad_input.h rename to mindspore/lite/src/runtime/kernel/arm/fp32_grad/bn_grad.h index 182257d5a7e..6476ceddbb4 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/bngrad_input.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bn_grad.h @@ -14,8 +14,8 @@ * limitations under the License. 
*/ -#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_BNGRAD_INPUT_H_ -#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_BNGRAD_INPUT_H_ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_BNGRAD_INPUT_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_BNGRAD_INPUT_H_ #include #include "src/lite_kernel.h" @@ -39,4 +39,4 @@ class BNGradInputCPUKernel : public LiteKernel { int workspace_size; }; } // namespace mindspore::kernel -#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_BNGRAD_INPUT_H_ +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_BNGRAD_INPUT_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_grad_filter.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_filter.cc similarity index 96% rename from mindspore/lite/src/runtime/kernel/arm/fp32/convolution_grad_filter.cc rename to mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_filter.cc index 20e224a1292..844062e3243 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_grad_filter.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_filter.cc @@ -14,11 +14,11 @@ * limitations under the License. 
*/ -#include "src/runtime/kernel/arm/fp32/convolution_grad_filter.h" +#include "src/runtime/kernel/arm/fp32_grad/convolution_grad_filter.h" #include "src/kernel_registry.h" #include "src/runtime/kernel/arm/nnacl/pack.h" -#include "src/runtime/kernel/arm/nnacl/pack_ext.h" -#include "src/runtime/kernel/arm/nnacl/fp32/gemm.h" +#include "src/runtime/kernel/arm/nnacl/fp32_grad/pack_ext.h" +#include "src/runtime/kernel/arm/nnacl/fp32_grad/gemm.h" #include "include/errorcode.h" using mindspore::kernel::KERNEL_ARCH::kCPU; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_grad_filter.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_filter.h similarity index 84% rename from mindspore/lite/src/runtime/kernel/arm/fp32/convolution_grad_filter.h rename to mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_filter.h index 20ce826c02f..7a9354be7e0 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_grad_filter.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_filter.h @@ -14,8 +14,8 @@ * limitations under the License. 
*/ -#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_GRAD_FILTER_H_ -#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_GRAD_FILTER_H_ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_CONVOLUTION_GRAD_FILTER_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_CONVOLUTION_GRAD_FILTER_H_ #include #include "src/lite_kernel.h" @@ -39,4 +39,4 @@ class ConvolutionGradFilterCPUKernel : public LiteKernel { }; } // namespace mindspore::kernel -#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_GRAD_FILTER_H_ +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_CONVOLUTION_GRAD_FILTER_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_grad_input.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_input.cc similarity index 96% rename from mindspore/lite/src/runtime/kernel/arm/fp32/convolution_grad_input.cc rename to mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_input.cc index bd9248137b2..8563565bf65 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_grad_input.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_input.cc @@ -14,11 +14,11 @@ * limitations under the License. 
*/ -#include "src/runtime/kernel/arm/fp32/convolution_grad_input.h" +#include "src/runtime/kernel/arm/fp32_grad/convolution_grad_input.h" #include "src/kernel_registry.h" #include "src/runtime/kernel/arm/nnacl/pack.h" -#include "src/runtime/kernel/arm/nnacl/pack_ext.h" -#include "src/runtime/kernel/arm/nnacl/fp32/gemm.h" +#include "src/runtime/kernel/arm/nnacl/fp32_grad/pack_ext.h" +#include "src/runtime/kernel/arm/nnacl/fp32_grad/gemm.h" #include "include/errorcode.h" using mindspore::kernel::KERNEL_ARCH::kCPU; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_grad_input.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_input.h similarity index 84% rename from mindspore/lite/src/runtime/kernel/arm/fp32/convolution_grad_input.h rename to mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_input.h index c1297fef775..9653fe06ad3 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_grad_input.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_input.h @@ -14,8 +14,8 @@ * limitations under the License. 
*/ -#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_GRAD_INPUT_H_ -#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_GRAD_INPUT_H_ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_CONVOLUTION_GRAD_INPUT_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_CONVOLUTION_GRAD_INPUT_H_ #include #include "src/lite_kernel.h" @@ -39,4 +39,4 @@ class ConvolutionGradInputCPUKernel : public LiteKernel { }; } // namespace mindspore::kernel -#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_GRAD_INPUT_H_ +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_CONVOLUTION_GRAD_INPUT_H diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/opt_momentum.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/opt_momentum.cc similarity index 98% rename from mindspore/lite/src/runtime/kernel/arm/fp32/opt_momentum.cc rename to mindspore/lite/src/runtime/kernel/arm/fp32_grad/opt_momentum.cc index ab70d88bbfe..98f2e4143cd 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/opt_momentum.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/opt_momentum.cc @@ -17,7 +17,7 @@ #include "schema/model_generated.h" #include "src/kernel_registry.h" -#include "src/runtime/kernel/arm/fp32/opt_momentum.h" +#include "src/runtime/kernel/arm/fp32_grad/opt_momentum.h" #include "include/errorcode.h" using mindspore::kernel::KERNEL_ARCH::kCPU; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/opt_momentum.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/opt_momentum.h similarity index 100% rename from mindspore/lite/src/runtime/kernel/arm/fp32/opt_momentum.h rename to mindspore/lite/src/runtime/kernel/arm/fp32_grad/opt_momentum.h diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/pooling_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/pooling_grad.cc similarity index 98% rename from mindspore/lite/src/runtime/kernel/arm/fp32/pooling_grad.cc rename to mindspore/lite/src/runtime/kernel/arm/fp32_grad/pooling_grad.cc 
index a7dd2da9ba3..c98b8125154 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/pooling_grad.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/pooling_grad.cc @@ -14,11 +14,11 @@ * limitations under the License. */ -#include "src/runtime/kernel/arm/fp32/pooling_grad.h" +#include "src/runtime/kernel/arm/fp32_grad/pooling_grad.h" #include "schema/model_generated.h" #include "src/kernel_registry.h" #include "src/runtime/kernel/arm/nnacl/fp32/pooling.h" -#include "src/runtime/kernel/arm/nnacl/fp32/pooling_grad.h" +#include "src/runtime/kernel/arm/nnacl/fp32_grad/pooling_grad.h" #include "include/errorcode.h" using mindspore::kernel::KERNEL_ARCH::kCPU; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/pooling_grad.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/pooling_grad.h similarity index 88% rename from mindspore/lite/src/runtime/kernel/arm/fp32/pooling_grad.h rename to mindspore/lite/src/runtime/kernel/arm/fp32_grad/pooling_grad.h index f093062c1f7..32c20f0abd1 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/pooling_grad.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/pooling_grad.h @@ -14,8 +14,8 @@ * limitations under the License. 
*/ -#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_POOLING_GRAD_H_ -#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_POOLING_GRAD_H_ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_POOLING_GRAD_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_POOLING_GRAD_H_ #include #include "src/lite_kernel.h" @@ -48,4 +48,4 @@ class PoolingGradCPUKernel : public LiteKernel { }; } // namespace mindspore::kernel -#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_POOLING_GRAD_H_ +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_POOLING_GRAD_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/power_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/power_grad.cc similarity index 97% rename from mindspore/lite/src/runtime/kernel/arm/fp32/power_grad.cc rename to mindspore/lite/src/runtime/kernel/arm/fp32_grad/power_grad.cc index 2523e3c5953..57209ab1263 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/power_grad.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/power_grad.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "src/runtime/kernel/arm/fp32/power_grad.h" +#include "src/runtime/kernel/arm/fp32_grad/power_grad.h" #include "schema/model_generated.h" #include "src/kernel_registry.h" #include "include/errorcode.h" diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/power_grad.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/power_grad.h similarity index 87% rename from mindspore/lite/src/runtime/kernel/arm/fp32/power_grad.h rename to mindspore/lite/src/runtime/kernel/arm/fp32_grad/power_grad.h index 5361d992c44..737de8c2a0a 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/power_grad.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/power_grad.h @@ -14,8 +14,8 @@ * limitations under the License. 
*/ -#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_POWER_GRAD_H_ -#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_POWER_GRAD_H_ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_POWER_GRAD_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_POWER_GRAD_H_ #include #include "src/lite_kernel.h" @@ -47,4 +47,4 @@ class PowerGradCPUKernel : public LiteKernel { }; } // namespace mindspore::kernel -#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_POWER_GRAD_H_ +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_POWER_GRAD_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/sparse_softmax_cross_entropy_with_logits.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.cc similarity index 93% rename from mindspore/lite/src/runtime/kernel/arm/fp32/sparse_softmax_cross_entropy_with_logits.cc rename to mindspore/lite/src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.cc index e343b026a5b..c9d4706bdb9 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/sparse_softmax_cross_entropy_with_logits.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.cc @@ -14,13 +14,12 @@ * limitations under the License. 
*/ -#include "src/runtime/kernel/arm/fp32/sparse_softmax_cross_entropy_with_logits.h" -#include "src/runtime/kernel/arm/nnacl/fp32/softmax.h" -#include "schema/model_generated.h" #include "src/kernel_registry.h" +#include "src/runtime/kernel/arm/nnacl/softmax_parameter.h" +#include "src/runtime/kernel/arm/nnacl/fp32/softmax.h" +#include "src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.h" #include "include/errorcode.h" -using mindspore::kernel::KERNEL_ARCH::kCPU; using mindspore::lite::KernelRegistrar; using mindspore::lite::RET_ERROR; using mindspore::lite::RET_OK; @@ -73,7 +72,7 @@ void SparseSoftmaxCrossEntropyWithLogitsCPUKernel::GradPostExecute(const int *la int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Run() { auto ins = reinterpret_cast(inputs_.at(0)->Data()); auto labels = reinterpret_cast(inputs_.at(1)->Data()); - auto out = reinterpret_cast(outputs_.at(0)->Data()); + auto out = reinterpret_cast(outputs_.at(1)->Data()); float *grads = NULL; if (is_train()) { // outputs_.size() > 1) grads = reinterpret_cast(outputs_.at(0)->Data()); @@ -90,10 +89,11 @@ int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Run() { SoftmaxParameter sm_params; sm_params.n_dim_ = param->n_dim_; sm_params.element_size_ = data_size; - sm_params.axis_ = 1; + sm_params.axis_ = 0; for (int i = 0; i < 4; i++) // softmax has only 4 params in shape sm_params.input_shape_[i] = param->input_shape_[i]; - float sum_data[sm_params.input_shape_[sm_params.axis_]]; + float sum_data[sm_params.input_shape_[sm_params.axis_]] = {0}; + std::fill(sum_data, sum_data + sm_params.input_shape_[sm_params.axis_], 0); Softmax(ins, losses, sum_data, &sm_params); if (is_train()) { diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/sparse_softmax_cross_entropy_with_logits.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.h similarity index 92% rename from 
mindspore/lite/src/runtime/kernel/arm/fp32/sparse_softmax_cross_entropy_with_logits.h rename to mindspore/lite/src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.h index db9a0d12707..4447d293ba8 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/sparse_softmax_cross_entropy_with_logits.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.h @@ -20,7 +20,7 @@ #include #include "src/lite_kernel.h" #include "ir/anf.h" -#include "src/runtime/kernel/arm/nnacl/fp32/softmax_grad.h" +#include "src/runtime/kernel/arm/nnacl/fp32_grad/softmax_grad.h" #include "src/runtime/kernel/arm/nnacl/fp32/arithmetic.h" namespace mindspore::kernel { @@ -30,8 +30,7 @@ class SparseSoftmaxCrossEntropyWithLogitsCPUKernel : public LiteKernel { explicit SparseSoftmaxCrossEntropyWithLogitsCPUKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs, - const lite::Context *ctx, - const lite::Primitive *primitive) + const lite::Context *ctx, const lite::Primitive *primitive) : LiteKernel(parameter, inputs, outputs, ctx, primitive) { param = reinterpret_cast(parameter); } diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32_grad/activation_grad.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32_grad/activation_grad.h new file mode 100644 index 00000000000..da9d2850f69 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32_grad/activation_grad.h @@ -0,0 +1,88 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_GRAD_ACTIVATION_GRAD_H_
+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_GRAD_ACTIVATION_GRAD_H_
+
+#include
+#include "src/runtime/kernel/arm/opclib/op_base.h"
+#include "src/runtime/kernel/arm/opclib/fp32/arithmetic.h"
+#include "src/runtime/kernel/arm/opclib/errorcode.h"
+
+struct ActivationGradParameter {
+  OpParameter op_parameter{};
+  int type_;
+  float alpha_{0.01};  // presumably the slope handed to LReluGrad -- confirm against the kernel that reads it
+};
+
+inline int ReluGrad(float *src0, float *src1, int length, float *dst) {  // dst = src0 * (src1 > 0 ? 1 : 0)
+  for (int i = 0; i < length; ++i) {
+    dst[i] = src1[i] > 0 ? 1.0f : 0.0f;
+  }
+  ElementMul(src0, dst, dst, length);
+  return OPCLIB_OK;
+}
+
+inline int Relu6Grad(float *src0, float *src1, int length, float *dst) {  // mask is 1 only for 0 <= src1 <= 6
+  for (int i = 0; i < length; ++i) {
+    if (src1[i] < 0) {
+      dst[i] = 0;
+    } else {
+      dst[i] = src1[i] > 6.0f ? 0.0f : 1.0f;
+    }
+  }
+  ElementMul(src0, dst, dst, length);
+  return OPCLIB_OK;
+}
+
+inline int LReluGrad(float *src0, float *src1, int length, float *dst, float alpha) {  // slope alpha for src1 <= 0
+  for (int i = 0; i < length; ++i) {
+    dst[i] = src1[i] > 0.0f ? 1.0f : alpha;
+  }
+  ElementMul(src0, dst, dst, length);
+  return OPCLIB_OK;
+}
+
+inline int SigmoidGrad(float *src0, float *src1, int length, float *dst) {  // dst = src0 * src1 * (1 - src1)
+  for (int i = 0; i < length; ++i) {
+    dst[i] = src0[i] * (src1[i] * (1.0f - src1[i]));
+  }
+  return OPCLIB_OK;
+}
+
+inline int TanhGrad(float *src0, float *src1, int length, float *dst) {  // dst = src0 * (1 - src1^2)
+  for (int i = 0; i < length; ++i) {
+    dst[i] = (1.0f - (src1[i] * src1[i])) * src0[i];
+  }
+  return OPCLIB_OK;
+}
+
+inline int HSwishGrad(float *src0, float *src1, int length, float *dst) {  // slope 0 / (2x+3)/6 / 1 around +-3
+  for (int i = 0; i < length; ++i) {
+    float tmp = (src1[i] > 3.0f ? 1.0f : (src1[i] < -3.0f ? 0.0f : (2.0f * src1[i] + 3.0f) / 6.0f));
+    dst[i] = tmp * src0[i];
+  }
+  return OPCLIB_OK;
+}
+
+inline int HSigmoidGrad(float *src0, float *src1, int length, float *dst) {  // slope 1/6 inside [-3, 3], else 0
+  for (int i = 0; i < length; ++i) {
+    float tmp = (src1[i] > 3.0f || src1[i] < -3.0f) ? 0.0f : 1.0f / 6.0f;  // flat (zero slope) once saturated
+    dst[i] = tmp * src0[i];
+  }
+  return OPCLIB_OK;
+}
+
+#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_GRAD_ACTIVATION_GRAD_H_
diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/arithmetic_grad.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32_grad/arithmetic_grad.cc
similarity index 93%
rename from mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/arithmetic_grad.cc
rename to mindspore/lite/src/runtime/kernel/arm/nnacl/fp32_grad/arithmetic_grad.cc
index f13dc824f83..6727bf0e641 100644
--- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/arithmetic_grad.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32_grad/arithmetic_grad.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 
-#include "src/runtime/kernel/arm/nnacl/fp32/arithmetic_grad.h"
+#include "src/runtime/kernel/arm/nnacl/fp32_grad/arithmetic_grad.h"
 
 void ElementDivNegSquare(const float *nom, const float *denom, float *output, int element_size) {
   for (int i = 0; i < element_size; i++) {
diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/arithmetic_grad.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32_grad/arithmetic_grad.h
similarity index 84%
rename from mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/arithmetic_grad.h
rename to mindspore/lite/src/runtime/kernel/arm/nnacl/fp32_grad/arithmetic_grad.h
index 9994b4d66d0..0b70402a923 100644
--- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/arithmetic_grad.h
+++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32_grad/arithmetic_grad.h
@@ -13,8 +13,8 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
*/ -#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_ARITHMETIC_GRAD_H_ -#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_ARITHMETIC_GRAD_H_ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_GRAD_ARITHMETIC_GRAD_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_GRAD_ARITHMETIC_GRAD_H_ void ElementDivNegSquare(const float *nom, const float *denom, float *output, int element_size); void ElementMulAndDivNegSquare(const float *a, const float *b, const float *denom, float *output, int element_size); diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/batch_norm.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32_grad/batch_norm.cc similarity index 98% rename from mindspore/lite/src/runtime/kernel/arm/nnacl/batch_norm.cc rename to mindspore/lite/src/runtime/kernel/arm/nnacl/fp32_grad/batch_norm.cc index bbe4dc1f3c5..4e863c9a00d 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/batch_norm.cc +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32_grad/batch_norm.cc @@ -15,7 +15,7 @@ */ #include #include -#include "src/runtime/kernel/arm/nnacl/batch_norm.h" +#include "src/runtime/kernel/arm/nnacl/fp32_grad/batch_norm.h" static void sumSpatialBatch(const float *in, int size, int ch, float *out) { std::fill(out, out + ch, 0.f); diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/batch_norm.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32_grad/batch_norm.h similarity index 89% rename from mindspore/lite/src/runtime/kernel/arm/nnacl/batch_norm.h rename to mindspore/lite/src/runtime/kernel/arm/nnacl/fp32_grad/batch_norm.h index 0d9e8b74bff..b7a0a3b6913 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/batch_norm.h +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32_grad/batch_norm.h @@ -14,8 +14,8 @@ * limitations under the License. 
*/ -#ifndef MINDSPORE_LITE_SRC_BACKEND_ARM_BATCH_NORM_H_ -#define MINDSPORE_LITE_SRC_BACKEND_ARM_BATCH_NORM_H_ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_BATCH_NORM_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_BATCH_NORM_H_ struct bnParameter { int batch; @@ -36,4 +36,5 @@ void meanAdd(const float *x, const float *mean, const float *variance_delta, int void NormalizeDelta(const float *x, const float *mean, const float *variance, const float *mean_delta, const float *variance_delta, int batch, int filters, int spatial, float eps, float *delta); -#endif +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_BATCH_NORM_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/gemm.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32_grad/gemm.cc similarity index 98% rename from mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/gemm.cc rename to mindspore/lite/src/runtime/kernel/arm/nnacl/fp32_grad/gemm.cc index 83705a8fbdd..d62d0c6f239 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/gemm.cc +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32_grad/gemm.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "src/runtime/kernel/arm/nnacl/fp32/gemm.h" +#include "src/runtime/kernel/arm/nnacl/fp32_grad/gemm.h" static void gemm_nn(int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_B, int ldb, float *mat_c, int ldc) { diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/gemm.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32_grad/gemm.h similarity index 78% rename from mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/gemm.h rename to mindspore/lite/src/runtime/kernel/arm/nnacl/fp32_grad/gemm.h index b3e30d09da7..69901bcf66b 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/gemm.h +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32_grad/gemm.h @@ -14,10 +14,10 @@ * limitations under the License. 
*/ -#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_GEMM_H_ -#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_GEMM_H_ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_GRAD_GEMM_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_GRAD_GEMM_H_ void gemm(int transpose_a, int transpose_b, int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_b, int ldb, float beta, float *mat_c, int ldc); -#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_GEMM_H_ +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_GRAD_GEMM_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/pack_ext.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32_grad/pack_ext.cc similarity index 99% rename from mindspore/lite/src/runtime/kernel/arm/nnacl/pack_ext.cc rename to mindspore/lite/src/runtime/kernel/arm/nnacl/fp32_grad/pack_ext.cc index 58a52963dd4..1ccaf52f9a5 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/pack_ext.cc +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32_grad/pack_ext.cc @@ -15,7 +15,7 @@ */ #include -#include "src/runtime/kernel/arm/nnacl/pack_ext.h" +#include "src/runtime/kernel/arm/nnacl/fp32_grad/pack_ext.h" static int is_a_ge_zero_and_a_lt_b(int a, int b) { return (unsigned)(a) < (unsigned)(b); } diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/pack_ext.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32_grad/pack_ext.h similarity index 100% rename from mindspore/lite/src/runtime/kernel/arm/nnacl/pack_ext.h rename to mindspore/lite/src/runtime/kernel/arm/nnacl/fp32_grad/pack_ext.h diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/pooling_grad.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32_grad/pooling_grad.cc similarity index 98% rename from mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/pooling_grad.cc rename to mindspore/lite/src/runtime/kernel/arm/nnacl/fp32_grad/pooling_grad.cc index 7c37fd38bc7..5511b822165 100644 --- 
a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/pooling_grad.cc +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32_grad/pooling_grad.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include -#include "src/runtime/kernel/arm/nnacl/fp32/pooling_grad.h" +#include "src/runtime/kernel/arm/nnacl/fp32_grad/pooling_grad.h" void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param) { int stride_w = pooling_param->stride_w_; diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/pooling_grad.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32_grad/pooling_grad.h similarity index 78% rename from mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/pooling_grad.h rename to mindspore/lite/src/runtime/kernel/arm/nnacl/fp32_grad/pooling_grad.h index 0f6049afd46..1fabdfd9625 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/pooling_grad.h +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32_grad/pooling_grad.h @@ -14,12 +14,12 @@ * limitations under the License. 
*/ -#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_POOLING_GRAD_H_ -#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_POOLING_GRAD_H_ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_GRAD_POOLING_GRAD_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_GRAD_POOLING_GRAD_H_ #include "src/runtime/kernel/arm/nnacl/fp32/pooling.h" void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param); void MaxPoolingGrad(const float *dy, const int *indices_ptr, float *output_ptr, PoolingParameter *pooling_param); -#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_POOLING_GRAD_H_ +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_GRAD_POOLING_GRAD_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/reduce_grad.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32_grad/reduce_grad.cc similarity index 98% rename from mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/reduce_grad.cc rename to mindspore/lite/src/runtime/kernel/arm/nnacl/fp32_grad/reduce_grad.cc index 40801c3f356..b90c16a30c4 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/reduce_grad.cc +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32_grad/reduce_grad.cc @@ -15,7 +15,7 @@ */ #include #include -#include "mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/reduce_grad.h" +#include "mindspore/lite/src/runtime/kernel/arm/nnacl/fp32_grad/reduce_grad.h" static inline bool NextIndex(const int num_dims, const int *dims, int *current) { int carry = 1; diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/reduce_grad.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32_grad/reduce_grad.h similarity index 100% rename from mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/reduce_grad.h rename to mindspore/lite/src/runtime/kernel/arm/nnacl/fp32_grad/reduce_grad.h diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/softmax_grad.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32_grad/softmax_grad.h 
similarity index 100% rename from mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/softmax_grad.h rename to mindspore/lite/src/runtime/kernel/arm/nnacl/fp32_grad/softmax_grad.h diff --git a/mindspore/lite/src/train/base_ref_utils.cc b/mindspore/lite/src/train/base_ref_utils.cc index 5df16a95520..61a39d38cb0 100644 --- a/mindspore/lite/src/train/base_ref_utils.cc +++ b/mindspore/lite/src/train/base_ref_utils.cc @@ -57,4 +57,3 @@ std::vector>> TransformVectorRefTo return multiTensor; } } // namespace mindspore - diff --git a/mindspore/lite/src/train/base_ref_utils.h b/mindspore/lite/src/train/base_ref_utils.h index 63370efeb93..2d4620ead64 100644 --- a/mindspore/lite/src/train/base_ref_utils.h +++ b/mindspore/lite/src/train/base_ref_utils.h @@ -16,16 +16,15 @@ #include #include -#include "base/base_ref.h" +#include "utils/base_ref.h" #include "include/ms_tensor.h" -#ifndef MINDSPORE_CCSRC_UTILS_BASE_REF_UTILS_H -#define MINDSPORE_CCSRC_UTILS_BASE_REF_UTILS_H +#ifndef MINDSPORE_LITE_SRC_TRAIN_BASE_REF_UTILS_H_ +#define MINDSPORE_LITE_SRC_TRAIN_BASE_REF_UTILS_H_ namespace mindspore { std::vector> TransformBaseRefToMSTensor(const BaseRef &base_ref); std::vector>> TransformVectorRefToMultiTensor( const VectorRef &vector_ref); } // namespace mindspore -#endif // MINDSPORE_CCSRC_UTILS_BASE_REF_UTILS_H - +#endif // MINDSPORE_LITE_SRC_TRAIN_BASE_REF_UTILS_H_ diff --git a/mindspore/lite/src/train/lite_kernel_runtime.cc b/mindspore/lite/src/train/lite_kernel_runtime.cc index e1eb95d4420..656e8b3bb7d 100644 --- a/mindspore/lite/src/train/lite_kernel_runtime.cc +++ b/mindspore/lite/src/train/lite_kernel_runtime.cc @@ -14,7 +14,8 @@ * limitations under the License. 
*/ -#include "mindspore/lite/src/train/lite_kernel_runtime.h" +#include "src/train/lite_kernel_runtime.h" +#include "backend/session/anf_runtime_algorithm.h" namespace mindspore::lite { std::vector LiteInferKernelRuntime::GetGraphInputs(const std::vector &execution_order) { std::vector graph_inputs; @@ -34,7 +35,8 @@ std::vector LiteInferKernelRuntime::GetGraphInputs(const std::vector &inputs, VectorRef *outputs) { + const std::vector &inputs, + std::vector *outputs) { MS_EXCEPTION_IF_NULL(graph); auto execution_order = graph->execution_order(); auto graph_inputs = GetGraphInputs(execution_order); @@ -56,15 +58,17 @@ void LiteInferKernelRuntime::BindInputOutput(const session::KernelGraph *graph, auto liteKernel = dynamic_cast(AnfAlgo::GetKernelMod(return_input)); auto output_tensors = liteKernel->GetOutputs(); for (auto output_tensor : output_tensors) { - tensor::TensorPtr output_tensor_ptr(output_tensor); - outputs->push_back(output_tensor_ptr); + // tensor::TensorPtr output_tensor_ptr(output_tensor); + outputs->push_back(output_tensor); } } } } -bool LiteInferKernelRuntime::Run(session::KernelGraph *graph) { +bool LiteInferKernelRuntime::Run(session::KernelGraph *graph, const std::vector &inputs, + std::vector *outputs) { MS_EXCEPTION_IF_NULL(graph); + BindInputOutput(graph, inputs, *outputs); std::vector kernels; auto nodes = graph->execution_order(); for (const auto &node : nodes) { @@ -76,8 +80,7 @@ bool LiteInferKernelRuntime::Run(session::KernelGraph *graph) { } kernel::LiteKernelUtil::TopologicalSortKernels(kernels); Executor executor; - auto ret = executor.Run(kernels); + auto ret = executor.Run(inputs, *outputs, kernels); return 0 == ret; } } // namespace mindspore::lite - diff --git a/mindspore/lite/src/train/lite_kernel_runtime.h b/mindspore/lite/src/train/lite_kernel_runtime.h index 27b4ec867bd..c5ae2d04d97 100644 --- a/mindspore/lite/src/train/lite_kernel_runtime.h +++ b/mindspore/lite/src/train/lite_kernel_runtime.h @@ -23,35 +23,28 @@ #include 
#include "src/runtime/allocator.h" #include "src/executor.h" -#include "runtime/device/kernel_runtime.h" +// #include "runtime/device/kernel_runtime.h" #include "runtime/device/device_address.h" #include "src/lite_kernel.h" #include "backend/session/kernel_graph.h" namespace mindspore::lite { -class LiteInferKernelRuntime : public device::KernelRuntime { +class LiteInferKernelRuntime { public: LiteInferKernelRuntime() = default; - ~LiteInferKernelRuntime() override = default; + ~LiteInferKernelRuntime() = default; - bool Init() override { return true; } - - void BindInputOutput(const session::KernelGraph *graph, const std::vector &inputs, - VectorRef *outputs); - - bool Run(session::KernelGraph *graph); + bool Run(session::KernelGraph *graph, const std::vector &inputs, + std::vector *outputs); void AssignKernelAddress(session::KernelGraph *graph) {} protected: + void BindInputOutput(const session::KernelGraph *graph, const std::vector &inputs, + std::vector *outputs); + std::vector GetGraphInputs(const std::vector &execution_order); - bool SyncStream() override { return true; }; - device::DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, - TypeId type_id) override { - return nullptr; - }; }; } // namespace mindspore::lite #endif // MINDSPORE_LITE_SRC_TRAIN_LITE_KERNEL_RUNTIME_H_ - diff --git a/mindspore/lite/src/train/model_impl.cc b/mindspore/lite/src/train/model_impl.cc index 30d60f77094..84794dfaaad 100644 --- a/mindspore/lite/src/train/model_impl.cc +++ b/mindspore/lite/src/train/model_impl.cc @@ -16,11 +16,34 @@ #include #include "src/train/model_impl.h" -#include "schema/model_generated.h" #include "ir/func_graph.h" +#include "schema/model_generated.h" +#include "src/common/anf_importer/import_from_meta_graph.h" namespace mindspore::lite::train { +std::shared_ptr ModelImpl::Import(const char *model_buf, size_t size) { + MS_EXCEPTION_IF_NULL(model_buf); + flatbuffers::Verifier verify((const uint8_t *)model_buf, 
size); + if (!schema::VerifyMetaGraphBuffer(verify)) { + MS_LOG(ERROR) << "The buffer is invalid and fail to create graph."; + return nullptr; + } + // todo hangangqiang remove when copy primitive done + auto *inner_buf = new char[size]; + memcpy(inner_buf, model_buf, size); + auto meta_graph = schema::GetMetaGraph(inner_buf); + auto func_graph_model = std::make_shared(meta_graph); + auto ret = func_graph_model->BuildOps(); + if (0 != ret) { + MS_LOG(ERROR) << "BuildOps failed"; + return nullptr; + } + AnfImporterFromMetaGraph anfImporter(func_graph_model); + anfImporter.Import(); + return func_graph_model; +} + const lite::Primitive *ModelImpl::GetOp(const std::string &name) const { auto iter = ops.find(name); if (iter == ops.end()) { @@ -98,6 +121,8 @@ lite::Primitive *ModelImpl::CopyPrimitive(const schema::Primitive *srcPrim) { return new lite::Nchw2Nhwc(const_cast(srcPrim)); case schema::PrimitiveType_Nhwc2Nchw: return new lite::Nhwc2Nchw(const_cast(srcPrim)); + case schema::PrimitiveType_MatMul: + return new lite::MatMul(const_cast(srcPrim)); default: break; } @@ -115,5 +140,6 @@ int ModelImpl::BuildOps() { auto srcPrim = cNode->primitive(); this->ops[name] = CopyPrimitive(srcPrim); } + return 0; } } // namespace mindspore::lite::train diff --git a/mindspore/lite/src/train/model_impl.h b/mindspore/lite/src/train/model_impl.h index 496fed2ac3c..d35956a855d 100644 --- a/mindspore/lite/src/train/model_impl.h +++ b/mindspore/lite/src/train/model_impl.h @@ -15,11 +15,12 @@ */ #ifndef MINDSPORE_LITE_SRC_TRAIN_MODEL_IMPL_H_ -#define MINDSPORE_LITE_SRC_TRAIN_MODEL_IMPL_H +#define MINDSPORE_LITE_SRC_TRAIN_MODEL_IMPL_H_ #include #include #include +#include #include "schema/model_generated.h" #include "src/ops/ops.h" #include "ir/func_graph.h" @@ -28,7 +29,7 @@ namespace mindspore::lite { namespace train { class ModelImpl : public FuncGraph { public: - static std::shared_ptr Import(const char *model_buf, size_t size); + static std::shared_ptr Import(const char 
*model_buf, size_t size); // { return NULL; }; ModelImpl() = default; explicit ModelImpl(const schema::MetaGraph *graph) : meta_graph(graph) {} ~ModelImpl() override = default; @@ -37,16 +38,27 @@ class ModelImpl : public FuncGraph { void FreeMetaGraph(); int BuildOps(); + void AddCNodeInputOutput(std::string name, const std::vector &input, const std::vector &output) { + std::vector *tuple = new std::vector[2]; + tuple[0] = input; + tuple[1] = output; + connectivity_[name] = tuple; + } + std::vector *GetCNodeInputOutputIndices(std::string name) { return connectivity_[name]; } + void AddAnfNode(int id, AnfNodePtr anf_ptr) { tensors_[id] = anf_ptr; } + AnfNodePtr GetAnfNode(int id) { return tensors_[id]; } + protected: lite::Primitive *CopyPrimitive(const schema::Primitive *srcPrim); protected: const schema::MetaGraph *meta_graph = nullptr; + std::map tensors_; + std::map *> connectivity_; std::map ops; }; } // namespace train using ModelImpl = mindspore::lite::train::ModelImpl; } // namespace mindspore::lite -#endif // MINDSPORE_LITE_INCLUDE_MODEL_H - +#endif // MINDSPORE_LITE_SRC_TRAIN_MODEL_IMPL_H_ diff --git a/mindspore/lite/src/train/train_anf_session.cc b/mindspore/lite/src/train/train_anf_session.cc new file mode 100644 index 00000000000..9e7fb5b506c --- /dev/null +++ b/mindspore/lite/src/train/train_anf_session.cc @@ -0,0 +1,253 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include
+#include "src/train/train_anf_session.h"
+#include "include/context.h"
+#include "mindspore/ccsrc/runtime/device/kernel_info.h"
+#include "mindspore/lite/src/train/train_session.h"
+#include "mindspore/lite/src/kernel_factory.h"
+#include "mindspore/lite/src/param_value_lite.h"
+#include "common/utils.h"
+#include "mindspore/lite/src/ops/ops.h"
+#include "ir/anf.h"
+#include "mindspore/lite/src/ir/tensor.h"
+#include "abstract/abstract_value.h"
+#include "backend/session/anf_runtime_algorithm.h"
+#include "src/ir/primitive_value.h"
+#include "src/train/model_impl.h"
+
+namespace mindspore {
+namespace session {
+static std::vector GetAnfNodeOutDims(const AnfNodePtr &anfNodePtr) {
+  auto nodeAbstract = anfNodePtr->abstract();
+  if (nodeAbstract != nullptr) {
+    auto shape = nodeAbstract->GetShapeTrack();
+    if (!shape->isa()) {
+      MS_LOG(EXCEPTION) << "Not a Shape";
+      return {};
+    }
+    auto dims = dyn_cast(shape)->shape();
+    return dims;
+  } else {
+    MS_LOG(WARNING) << "abstract is nullptr, return empty dims";
+    return {};
+  }
+}
+
+static schema::Format GetAnfNodeFormat(const AnfNodePtr &anfNodePtr) {
+  auto nodeAbstract = anfNodePtr->abstract();
+  if (nodeAbstract != nullptr) {
+    return schema::Format_NHWC;  // XXX TODO -- extract Format from AnfNode
+  } else {
+    MS_LOG(WARNING) << "abstract is nullptr, return schema::Format_NHWC";
+    return schema::Format_NHWC;
+  }
+}
+
+static TypeId GetAnfNodeOutTypeId(const AnfNodePtr &anfNodePtr) {
+  auto nodeAbstract = anfNodePtr->abstract();
+  if (nodeAbstract != nullptr) {
+    return TypeId::kNumberTypeFloat32;  // XXX TODO nodeAbstract->GetTypeTrack()->generic_type_id();
+  } else {
+    MS_LOG(WARNING) << "abstract is nullptr, return kTypeUnknown";
+    return TypeId::kTypeUnknown;
+  }
+}
+
+void TrainANFSession::Init(lite::Context *context) {
+  MS_EXCEPTION_IF_NULL(context);
+  this->context_ = std::make_shared(context->thread_num_, context->allocator, context->device_ctx_);
+}
+
+lite::tensor::Tensor *TrainANFSession::GetTensorForAnfNode(const AnfNodePtr anf_node) {
+  lite::tensor::Tensor *out_tensor = tensors_[anf_node];  // map lookup; null until a tensor is cached for the node
+  if (out_tensor == nullptr) {
+    out_tensor = new lite::tensor::Tensor(GetAnfNodeOutTypeId(anf_node),
+                                          GetAnfNodeOutDims(anf_node));  //, schema::NodeType_Parameter);
+    tensors_[anf_node] = out_tensor;
+  }
+  return out_tensor;
+}
+
+int TrainANFSession::BuildKernelInputAndOutputFromFuncGraph(const KernelGraphPtr &kernel_graph) {
+  auto return_node = kernel_graph->get_return();
+  auto node_list = TopoSort(return_node);
+  auto model_imp = std::dynamic_pointer_cast(func_graph_);
+  for (auto &node : node_list) {
+    if (!node->isa()) {
+      continue;
+    }
+    KernelRelation kernel_relation;
+    auto cnode = node->cast();
+    kernel_relation.node_full_name = cnode->fullname_with_scope();
+    kernel_relation.cnode = cnode;
+    std::vector *cnode_io_indices = model_imp->GetCNodeInputOutputIndices(cnode->fullname_with_scope());
+    if (cnode_io_indices == nullptr) {
+      MS_LOG(WARNING) << "No IO vectors for " << cnode->fullname_with_scope();
+    } else {
+      for (size_t i = 0; i < cnode_io_indices[1].size(); i++) {  // element [1] holds the output indices
+        AnfNodePtr anf_node = model_imp->GetAnfNode(cnode_io_indices[1].data()[i]);
+        kernel_relation.output_tensor.push_back(GetTensorForAnfNode(anf_node));
+      }
+    }
+    lite::tensor::Tensor *tensor_ptr = nullptr;
+    for (size_t index = 1; index < cnode->inputs().size(); ++index) {
+      if (cnode->input(index)->isa()) {
+        auto input_cnode = cnode->input(index)->cast();
+        const auto &input_kernel_relation = kernel_relation_infos_[input_cnode->fullname_with_scope()];
+        // todo: multi-output kernels such as split are not supported; only the first output is used
+        tensor_ptr = input_kernel_relation.output_tensor.front();
+      } else if (cnode->input(index)->isa()) {
+        auto input_parameter = cnode->input(index)->cast();
+        auto para = input_parameter->default_param();
+        auto param_value = std::dynamic_pointer_cast(para);
+        // auto dims = param_value->tensor_shape();
+        // tensor_ptr = new lite::tensor::Tensor(param_value->tensor_type(), dims); // schema::NodeType_ValueNode);
+        tensor_ptr = GetTensorForAnfNode(cnode->input(index));
+        if ((param_value != nullptr) && (param_value->tensor_size() != 0)) {
+          tensor_ptr->SetData(param_value->tensor_addr());
+        }
+      } else if (cnode->input(index)->isa()) {
+        auto input_valuenode = cnode->input(index)->cast();
+        // tensor_ptr = new lite::tensor::Tensor(GetAnfNodeOutTypeId(input_valuenode),
+        //                                      GetAnfNodeOutDims(input_valuenode)); // schema::NodeType_Parameter);
+        tensor_ptr = GetTensorForAnfNode(input_valuenode);
+        // todo(yankai)
+      } else {
+        MS_ASSERT(false);
+      }
+      kernel_relation.input_tensor.push_back(tensor_ptr);
+    }
+    kernel_relation_infos_[cnode->fullname_with_scope()] = kernel_relation;
+  }
+  return 0;
+}
+
+GraphId TrainANFSession::graph_sum_ = 0;  // id handed to the next KernelGraph created by NewKernelGraph
+
+KernelGraphPtr TrainANFSession::NewKernelGraph() {
+  auto graph = std::make_shared();
+  graph->set_graph_id(graph_sum_);
+  graphs_[graph_sum_++] = graph;
+  return graph;
+}
+
+std::shared_ptr TrainANFSession::ConstructKernelGraph(const FuncGraphPtr &func_graph) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  auto graph = NewKernelGraph();
+  graph->set_return(func_graph->get_return());
+  auto node_list = TopoSort(func_graph->get_return());
+  std::vector cnode_order;
+  for (const auto &node : node_list) {
+    MS_EXCEPTION_IF_NULL(node);
+    if (node->isa()) {
+      auto cn_node = node->cast();
+      cnode_order.push_back(cn_node);
+    }
+  }
+  graph->set_execution_order(cnode_order);
+  return graph;
+}
+GraphId TrainANFSession::CompileGraph(NotNull func_graph) {
+  auto graph = ConstructKernelGraph(func_graph);
+  func_graph_ = func_graph;
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_LOG(INFO) << "Set kernel info";
+  SetKernelInfo(graph.get());
+
+  (void)BuildKernelInputAndOutputFromFuncGraph(graph);
+  MS_LOG(INFO) << "Build kernel";
+  auto ret = BuildKernel(graph.get());
+  if (0 != ret) {
+    MS_LOG(EXCEPTION) << "BuildKernel failed";
+  }
+
+  // return the graph id to backend
+  auto graph_id = graph->graph_id();
+  graphs_[graph_id] = graph;
+  MS_LOG(INFO) << "Compile graph " << graph_id << " success";
+  return graph_id;
+}
+
+void TrainANFSession::RunGraph(const GraphId &graph_id, const std::vector &inputs,
+                               std::vector *outputs) {
+  auto &kernel_graph = graphs_[graph_id];
+  MS_EXCEPTION_IF_NULL(kernel_graph);
+  MS_LOG(INFO) << "Bind input output address";
+  // runtime_.BindInputOutput(kernel_graph.get(), inputs, outputs); -- will be bound in Run
+  // auto execution_order = kernel_graph->execution_order();
+  // Todo : hangangqiang
+  // Reorder(&execution_order);
+  // kernel_graph->set_execution_order(execution_order);
+  MS_LOG(INFO) << "Run graph start";
+  auto ret = runtime_.Run(kernel_graph.get(), (std::vector &)inputs, outputs);  // Run takes outputs by pointer
+  if (!ret) {
+    MS_LOG(EXCEPTION) << "Run graph failed";
+  }
+  MS_LOG(INFO) << "Run graph end";
+}
+
+void TrainANFSession::SetKernelInfo(const KernelGraph *kernel_graph) {
+  MS_EXCEPTION_IF_NULL(kernel_graph);
+  auto &kernel_nodes = kernel_graph->execution_order();
+  for (const auto &kernel_node : kernel_nodes) {
+    MS_EXCEPTION_IF_NULL(kernel_node);
+    auto kernel_info = std::make_shared();
+    kernel_node->set_kernel_info(kernel_info);
+  }
+}
+
+int TrainANFSession::BuildKernel(const KernelGraph *kernel_graph) {
+  MS_EXCEPTION_IF_NULL(kernel_graph);
+  for (auto iter = kernel_relation_infos_.begin(); iter != kernel_relation_infos_.end(); ++iter) {
+    std::string kernel_name = iter->first;
+    KernelRelation anf_register = iter->second;
+    MS_EXCEPTION_IF_NULL(anf_register.cnode);
+    if (IsPrimitiveCNode(anf_register.cnode, prim::kPrimReturn)) {
+      continue;
+    }
+    auto value_node_prim = anf_register.cnode->input(0);
+    MS_EXCEPTION_IF_NULL(value_node_prim);
+    auto prim = GetValueNode>(value_node_prim);
+    MS_EXCEPTION_IF_NULL(prim);
+    auto node_primitive = (lite::Primitive *)(prim->GetPrimitive());
+    MS_EXCEPTION_IF_NULL(node_primitive);
+    auto ret = node_primitive->InferShape(anf_register.input_tensor, anf_register.output_tensor);
+    if (0 != ret) {
+      MS_LOG(ERROR) << "InferShape failed, node : " << kernel_name;
+      return ret;
+    }
+    kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, node_primitive->Type()};
+
+    auto *kernel = lite::KernelFactory::GetInstance()->GetKernel(anf_register.input_tensor, anf_register.output_tensor,
+                                                                 node_primitive, context_.get(), desc);
+    if (nullptr == kernel) {
+      MS_LOG(ERROR) << "Create kernel return nullptr, name: " << kernel_name;
+      return -1;
+    }
+    std::shared_ptr kernel_mod(kernel);
+    kernel_mod->set_name(anf_register.cnode->fullname_with_scope());
+
+    // kernel->train();
+    auto kernel_info = dynamic_cast(anf_register.cnode->kernel_info());
+    MS_EXCEPTION_IF_NULL(kernel_info);
+    kernel_info->set_kernel_mod(kernel_mod);  // XXX TODO -- only derived class KernelInfo has this method
+  }
+  return 0;
+}
+}  // namespace session
+}  // namespace mindspore
diff --git a/mindspore/lite/src/train/train_anf_session.h b/mindspore/lite/src/train/train_anf_session.h
new file mode 100644
index 00000000000..fc495fd49e6
--- /dev/null
+++ b/mindspore/lite/src/train/train_anf_session.h
@@ -0,0 +1,78 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_LITE_SRC_TRAIN_TRAIN_ANF_SESSION_H_
+#define MINDSPORE_LITE_SRC_TRAIN_TRAIN_ANF_SESSION_H_
+#include
+#include
+#include
+#include
+#include
+#include "include/context.h"
+#include "backend/session/session_basic.h"
+#include "backend/session/kernel_graph.h"
+#include "mindspore/lite/src/train/lite_kernel_runtime.h"
+// #include "backend/session/session_factory.h"
+namespace mindspore {
+namespace lite::tensor {
+class Tensor;
+}
+namespace session {
+// Per-CNode record: the node itself plus the lite tensors bound to its inputs and outputs.
+struct KernelRelation {
+  std::string node_full_name;
+  std::vector input_tensor;
+  std::vector output_tensor;
+  CNodePtr cnode;
+};
+
+// Builds a KernelGraph of lite kernels from a FuncGraph (CompileGraph) and runs it via LiteInferKernelRuntime.
+class TrainANFSession {
+ public:
+  explicit TrainANFSession(lite::Context *context) { Init(context); }
+  ~TrainANFSession() = default;
+
+  GraphId CompileGraph(NotNull func_graph);
+
+  void RunGraph(const GraphId &graph_id, const std::vector &inputs,
+                std::vector *outputs);
+
+  // void RunGraph(const GraphId &graph_id, const std::vector &inputs, VectorRef *outputs) override
+  // {};
+ protected:
+  void Init(lite::Context *context);
+  std::shared_ptr context_ = nullptr;
+  std::unordered_map> graphs_;
+  static GraphId graph_sum_;
+  KernelGraphPtr NewKernelGraph();
+
+ private:
+  // GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override;
+  // GraphId CompileGraph(const char *model_buf, size_t size);
+  std::shared_ptr ConstructKernelGraph(const FuncGraphPtr &func_graph);
+  int BuildKernelInputAndOutputFromFuncGraph(const KernelGraphPtr &kernel_graph);
+
+  lite::tensor::Tensor *GetTensorForAnfNode(const AnfNodePtr anf_node);
+
+  void SetKernelInfo(const KernelGraph *kernel_graph);
+  int BuildKernel(const KernelGraph *kernel_graph);
+  lite::LiteInferKernelRuntime runtime_;
+  std::map kernel_relation_infos_;
+  FuncGraphPtr func_graph_ = nullptr;
+  std::map tensors_;
+};
+}  // namespace session
+}  // namespace mindspore
+#endif  // MINDSPORE_LITE_SRC_TRAIN_TRAIN_ANF_SESSION_H_
diff --git a/mindspore/lite/src/train/train_session.cc
b/mindspore/lite/src/train/train_session.cc index c92473e177e..d9d8ddc346c 100644 --- a/mindspore/lite/src/train/train_session.cc +++ b/mindspore/lite/src/train/train_session.cc @@ -15,6 +15,8 @@ */ #include +#include "include/context.h" +#include "mindspore/ccsrc/runtime/device/kernel_info.h" #include "mindspore/lite/src/train/train_session.h" #include "mindspore/lite/src/kernel_factory.h" #include "mindspore/lite/src/param_value_lite.h" @@ -25,6 +27,7 @@ #include "abstract/abstract_value.h" #include "backend/session/anf_runtime_algorithm.h" #include "src/ir/primitive_value.h" +#include "src/train/model_impl.h" namespace mindspore { namespace session { @@ -57,16 +60,32 @@ static schema::Format GetAnfNodeFormat(const AnfNodePtr &anfNodePtr) { static TypeId GetAnfNodeOutTypeId(const AnfNodePtr &anfNodePtr) { auto nodeAbstract = anfNodePtr->abstract(); if (nodeAbstract != nullptr) { - return nodeAbstract->GetTypeTrack()->type_id(); + return TypeId::kNumberTypeFloat32; // XXX TODO nodeAbstract->GetTypeTrack()->generic_type_id(); } else { MS_LOG(WARNING) << "abstract is nullptr, return kTypeUnknown"; return TypeId::kTypeUnknown; } } +void TrainSession::Init(lite::Context *context) { + MS_EXCEPTION_IF_NULL(context); + this->context_ = std::make_shared(context->thread_num_, context->allocator, context->device_ctx_); +} + +lite::tensor::Tensor *TrainSession::GetTensorForAnfNode(const AnfNodePtr anf_node) { + lite::tensor::Tensor *out_tensor = tensors_[anf_node]; + if (out_tensor == NULL) { + out_tensor = new lite::tensor::Tensor(GetAnfNodeOutTypeId(anf_node), + GetAnfNodeOutDims(anf_node)); //, schema::NodeType_Parameter); + tensors_[anf_node] = out_tensor; + } + return out_tensor; +} + int TrainSession::BuildKernelInputAndOutputFromFuncGraph(const KernelGraphPtr &kernel_graph) { auto return_node = kernel_graph->get_return(); auto node_list = TopoSort(return_node); + auto model_imp = std::dynamic_pointer_cast(func_graph_); for (auto &node : node_list) { if (!node->isa()) 
{ continue; @@ -75,11 +94,16 @@ int TrainSession::BuildKernelInputAndOutputFromFuncGraph(const KernelGraphPtr &k auto cnode = node->cast(); kernel_relation.node_full_name = cnode->fullname_with_scope(); kernel_relation.cnode = cnode; - auto *out_tensor = - new tensor::Tensor(GetAnfNodeOutTypeId(cnode), GetAnfNodeOutDims(cnode), GetAnfNodeFormat(cnode), - schema::NodeType_Parameter); - kernel_relation.output_tensor.push_back(out_tensor); - tensor::Tensor *tensor_ptr = nullptr; + std::vector *cnode_io_indices = model_imp->GetCNodeInputOutputIndices(cnode->fullname_with_scope()); + if (cnode_io_indices == NULL) { + MS_LOG(WARNING) << "No IO vectors for " << cnode->fullname_with_scope(); + } else { + for (int i = 0; i < cnode_io_indices[1].size(); i++) { + AnfNodePtr anf_node = model_imp->GetAnfNode(cnode_io_indices[1].data()[i]); + kernel_relation.output_tensor.push_back(GetTensorForAnfNode(anf_node)); + } + } + lite::tensor::Tensor *tensor_ptr = nullptr; for (size_t index = 1; index < cnode->inputs().size(); ++index) { if (cnode->input(index)->isa()) { auto input_cnode = cnode->input(index)->cast(); @@ -90,17 +114,17 @@ int TrainSession::BuildKernelInputAndOutputFromFuncGraph(const KernelGraphPtr &k auto input_parameter = cnode->input(index)->cast(); auto para = input_parameter->default_param(); auto param_value = std::dynamic_pointer_cast(para); - auto dims = param_value->tensor_shape(); - tensor_ptr = new tensor::Tensor(param_value->tensor_type(), dims, schema::Format_NHWC, - schema::NodeType_ValueNode); // XXX TODO -- extract Format from AnfNode - if (param_value->tensor_size() != 0) { + // auto dims = param_value->tensor_shape(); + // tensor_ptr = new lite::tensor::Tensor(param_value->tensor_type(), dims); // schema::NodeType_ValueNode); + tensor_ptr = GetTensorForAnfNode(cnode->input(index)); + if ((param_value != nullptr) && (param_value->tensor_size() != 0)) { tensor_ptr->SetData(param_value->tensor_addr()); } } else if (cnode->input(index)->isa()) { auto 
input_valuenode = cnode->input(index)->cast(); - tensor_ptr = new tensor::Tensor(GetAnfNodeOutTypeId(input_valuenode), GetAnfNodeOutDims(input_valuenode), - schema::Format_NHWC, - schema::NodeType_Parameter); // XXX TODO -- extract Format from AnfNode + // tensor_ptr = new lite::tensor::Tensor(GetAnfNodeOutTypeId(input_valuenode), + // GetAnfNodeOutDims(input_valuenode)); // schema::NodeType_Parameter); + tensor_ptr = GetTensorForAnfNode(input_valuenode); // todo(yankai) } else { MS_ASSERT(false); @@ -111,7 +135,7 @@ int TrainSession::BuildKernelInputAndOutputFromFuncGraph(const KernelGraphPtr &k } return 0; } - +#if 0 GraphId TrainSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { auto graph_id = graph_sum_; auto graph = SessionBasic::ConstructKernelGraph(lst, outputs); @@ -124,6 +148,17 @@ GraphId TrainSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrLi } GraphId TrainSession::CompileGraph(const char *model_buf, size_t size) { return 0; } +#else +GraphId TrainSession::graph_sum_ = 0; + +KernelGraphPtr TrainSession::NewKernelGraph() { + auto graph = std::make_shared(); + graph->set_graph_id(graph_sum_); + graphs_[graph_sum_++] = graph; + return graph; +} + +#endif std::shared_ptr TrainSession::ConstructKernelGraph(const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(func_graph); @@ -141,14 +176,14 @@ std::shared_ptr TrainSession::ConstructKernelGraph(const FuncGraphP graph->set_execution_order(cnode_order); return graph; } - GraphId TrainSession::CompileGraph(NotNull func_graph) { auto graph = ConstructKernelGraph(func_graph); + func_graph_ = func_graph; MS_EXCEPTION_IF_NULL(graph); MS_LOG(INFO) << "Set kernel info"; SetKernelInfo(graph.get()); - (void) BuildKernelInputAndOutputFromFuncGraph(graph); + (void)BuildKernelInputAndOutputFromFuncGraph(graph); MS_LOG(INFO) << "Build kernel"; auto ret = BuildKernel(graph.get()); if (0 != ret) { @@ -162,18 +197,18 @@ GraphId TrainSession::CompileGraph(NotNull func_graph) { 
return graph_id; } -void TrainSession::RunGraph(const GraphId &graph_id, const std::vector &inputs, - std::vector &outputs) { +void TrainSession::RunGraph(const GraphId &graph_id, const std::vector &inputs, + std::vector *outputs) { auto &kernel_graph = graphs_[graph_id]; MS_EXCEPTION_IF_NULL(kernel_graph); MS_LOG(INFO) << "Bind input output address"; - runtime_.BindInputOutput(kernel_graph.get(), inputs, outputs); + // runtime_.BindInputOutput(kernel_graph.get(), inputs, outputs); -- will be bound in Run // auto execution_order = kernel_graph->execution_order(); // Todo : hangangqiang // Reorder(&execution_order); // kernel_graph->set_execution_order(execution_order); MS_LOG(INFO) << "Run graph start"; - auto ret = runtime_.Run(kernel_graph.get(), (std::vector &) inputs, outputs); + auto ret = runtime_.Run(kernel_graph.get(), (std::vector &)inputs, outputs); if (!ret) { MS_LOG(EXCEPTION) << "Run graph failed"; } @@ -199,34 +234,34 @@ int TrainSession::BuildKernel(const KernelGraph *kernel_graph) { if (IsPrimitiveCNode(anf_register.cnode, prim::kPrimReturn)) { continue; } - lite::Context context; - context.deviceCtx.type = lite::DeviceType::DT_CPU; auto value_node_prim = anf_register.cnode->input(0); MS_EXCEPTION_IF_NULL(value_node_prim); auto prim = GetValueNode>(value_node_prim); MS_EXCEPTION_IF_NULL(prim); - auto node_primitive = (lite::Primitive *) (prim->GetPrimitive()); + auto node_primitive = (lite::Primitive *)(prim->GetPrimitive()); MS_EXCEPTION_IF_NULL(node_primitive); auto ret = node_primitive->InferShape(anf_register.input_tensor, anf_register.output_tensor); if (0 != ret) { MS_LOG(ERROR) << "InferShape failed, node : " << kernel_name; return ret; } - kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, node_primitive->Type()}; + kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, node_primitive->Type()}; auto *kernel = lite::KernelFactory::GetInstance()->GetKernel(anf_register.input_tensor, anf_register.output_tensor, - node_primitive, 
&context, desc); + node_primitive, context_.get(), desc); if (nullptr == kernel) { MS_LOG(ERROR) << "Create kernel return nullptr, name: " << kernel_name; return -1; } - kernel->train(); - auto *kernel_info = anf_register.cnode->kernel_info(); std::shared_ptr kernel_mod(kernel); - kernel_info->set_kernel_mod(kernel_mod); + kernel_mod->set_name(anf_register.cnode->fullname_with_scope()); + + // kernel->train(); + auto kernel_info = dynamic_cast(anf_register.cnode->kernel_info()); + MS_EXCEPTION_IF_NULL(kernel_info); + kernel_info->set_kernel_mod(kernel_mod); // XXX TODO -- only derived class KernelInfo has this method } return 0; } } // namespace session } // namespace mindspore - diff --git a/mindspore/lite/src/train/train_session.h b/mindspore/lite/src/train/train_session.h index d9b026d55f7..ff712ffd83b 100644 --- a/mindspore/lite/src/train/train_session.h +++ b/mindspore/lite/src/train/train_session.h @@ -19,47 +19,58 @@ #include #include #include +#include +#include "include/context.h" #include "backend/session/session_basic.h" #include "backend/session/kernel_graph.h" #include "mindspore/lite/src/train/lite_kernel_runtime.h" -#include "backend/session/session_factory.h" +// #include "backend/session/session_factory.h" namespace mindspore { namespace lite::tensor { class Tensor; } namespace session { struct KernelRelation { - std::string node_full_name; - std::vector input_tensor; - std::vector output_tensor; - CNodePtr cnode; + std::string node_full_name; + std::vector input_tensor; + std::vector output_tensor; + CNodePtr cnode; }; -class TrainSession : public SessionBasic { +class TrainSession { public: - TrainSession() : SessionBasic() {} - ~TrainSession() override = default; - void Init(uint32_t device_id) override { - SessionBasic::Init(device_id); - context_ = std::make_shared(kCPUDevice, device_id); - } + explicit TrainSession(lite::Context * context) { Init(context); } + ~TrainSession() = default; - GraphId CompileGraph(NotNull func_graph) override; + 
GraphId CompileGraph(NotNull func_graph); - void RunGraph(const GraphId &graph_id, const std::vector &inputs, VectorRef *outputs) override; + void RunGraph(const GraphId &graph_id, const std::vector &inputs, + std::vector *outputs); + + // void RunGraph(const GraphId &graph_id, const std::vector &inputs, VectorRef *outputs) override + // {}; + protected: + void Init(lite::Context *context); + std::shared_ptr context_ = nullptr; + std::unordered_map> graphs_; + static GraphId graph_sum_; + KernelGraphPtr NewKernelGraph(); private: - GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; - GraphId CompileGraph(const char *model_buf, size_t size); + // GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; + // GraphId CompileGraph(const char *model_buf, size_t size); std::shared_ptr ConstructKernelGraph(const FuncGraphPtr &func_graph); int BuildKernelInputAndOutputFromFuncGraph(const KernelGraphPtr &kernel_graph); + + lite::tensor::Tensor *GetTensorForAnfNode(const AnfNodePtr anf_node); + void SetKernelInfo(const KernelGraph *kernel_graph); int BuildKernel(const KernelGraph *kernel_graph); lite::LiteInferKernelRuntime runtime_; std::map kernel_relation_infos_; + FuncGraphPtr func_graph_ = NULL; + std::map tensors_; }; -MS_REG_SESSION(kCPUDevice, TrainSession); } // namespace session } // namespace mindspore #endif // MINDSPORE_LITE_SRC_TRAIN_TRAIN_SESSION_H_ - diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index 868b7ed6a58..2ff4e9f76a2 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -87,6 +87,16 @@ file(GLOB KERNEL_OP_SRC ${LITE_DIR}/src/runtime/kernel/arm/nnacl/int8/*.cc ${LITE_DIR}/src/runtime/kernel/arm/nnacl/quantization/*.cc ) + +file(GLOB KERNEL_OP_TRAIN_SRC + ${LITE_DIR}/src/runtime/kernel/arm/nnacl/fp32_grad/*.cc + ${LITE_DIR}/src/runtime/kernel/arm/fp32_grad/*.cc +) + +if (SUPPORT_TRAIN) + list(APPEND 
KERNEL_OP_SRC ${KERNEL_OP_TRAIN_SRC}) +endif() + if (PLATFORM_ARM64) # assembly file(GLOB TEST_ASSEMBLY_SRC ${LITE_DIR}/src/runtime/kernel/arm/nnacl/assembly/arm64/*.s @@ -245,12 +255,13 @@ if (SUPPORT_TRAIN) # ${SRC_DIR}/device/kernel_info.cc # ${SRC_DIR}/device/kernel_runtime.cc # ${SRC_DIR}/device/lite/kernel_runtime_extends.cc - ${LITE_DIR}/src/common/anf_importer/anf_importer.cc - ${LITE_DIR}/src/common/anf_importer/import_from_meta_graph.cc - ${LITE_DIR}/src/ir/primitive_value.cc - ${LITE_DIR}/src/train/lite_kernel_runtime.cc - ${LITE_DIR}/src/train/train_session.cc - ${LITE_DIR}/src/train/model_impl.cc + # ${LITE_DIR}/src/common/anf_importer/anf_importer.cc + # ${LITE_DIR}/src/common/anf_importer/import_from_meta_graph.cc + # ${LITE_DIR}/src/ir/primitive_value.cc + # ${LITE_DIR}/src/train/lite_kernel_runtime.cc + # ${LITE_DIR}/src/train/train_session.cc + # ${LITE_DIR}/src/train/model_impl.cc + ${LITE_DIR}/src/lite_session.cc # temporary ) else() set(TEST_LITE_SRC @@ -265,6 +276,10 @@ file(GLOB_RECURSE TEST_CASE_KERNEL_SRC ${TEST_DIR}/ut/src/runtime/kernel/arm/int8/*.cc ) +file(GLOB_RECURSE TEST_CASE_KERNEL_TRAIN_SRC + ${TEST_DIR}/ut/src/runtime/kernel/arm/fp32_grad/*.cc +) + set(TEST_SRC ${TEST_LITE_SRC} ${TEST_MINDDATA_SRC} @@ -278,7 +293,9 @@ set(TEST_SRC if (SUPPORT_TRAIN) set(TEST_SRC ${TEST_SRC} - ${TEST_DIR}/ut/src/train_test.cc + ${TEST_CASE_KERNEL_TRAIN_SRC} + # ${TEST_DIR}/ut/src/train_test.cc + ${TEST_DIR}/ut/src/infer_test.cc # temporary ) else() set(TEST_SRC diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/activation_grad_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/activation_grad_fp32_tests.cc similarity index 99% rename from mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/activation_grad_fp32_tests.cc rename to mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/activation_grad_fp32_tests.cc index 1badd29a267..e1fd748bb32 100644 --- 
a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/activation_grad_fp32_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/activation_grad_fp32_tests.cc @@ -13,7 +13,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #include #include #include @@ -25,7 +24,7 @@ #include "mindspore/lite/src/kernel_registry.h" #include "mindspore/lite/src/ir/tensor.h" #include "mindspore/lite/src/lite_kernel.h" -#include "mindspore/lite/src/runtime/kernel/arm/fp32/activation_grad.h" +#include "mindspore/lite/src/runtime/kernel/arm/fp32_grad/activation_grad.h" namespace mindspore { class TestActGradFp32 : public mindspore::Common { diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/arithmetic_grad_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/arithmetic_grad_fp32_tests.cc similarity index 99% rename from mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/arithmetic_grad_fp32_tests.cc rename to mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/arithmetic_grad_fp32_tests.cc index 50036732851..7ca2daf26ac 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/arithmetic_grad_fp32_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/arithmetic_grad_fp32_tests.cc @@ -21,7 +21,7 @@ #include "src/common/file_utils.h" #include "src/common/file_utils_ext.h" #include "mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/reduce.h" -#include "mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_grad.h" +#include "mindspore/lite/src/runtime/kernel/arm/fp32_grad/arithmetic_grad.h" #include "mindspore/lite/src/kernel_registry.h" namespace mindspore { diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/bias_grad_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/bias_grad_fp32_tests.cc similarity index 97% rename from mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/bias_grad_fp32_tests.cc rename to 
mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/bias_grad_fp32_tests.cc index 0a68d654dee..8146950ab60 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/bias_grad_fp32_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/bias_grad_fp32_tests.cc @@ -18,7 +18,7 @@ #include "utils/log_adapter.h" #include "common/common_test.h" #include "src/common/file_utils.h" -#include "mindspore/lite/src/runtime/kernel/arm/fp32/bias_grad.h" +#include "mindspore/lite/src/runtime/kernel/arm/fp32_grad/bias_grad.h" #include "mindspore/lite/src/kernel_registry.h" namespace mindspore { diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/convolution_grad_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/convolution_grad_fp32_tests.cc similarity index 99% rename from mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/convolution_grad_fp32_tests.cc rename to mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/convolution_grad_fp32_tests.cc index 721486e6a63..1d114587ec9 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/convolution_grad_fp32_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/convolution_grad_fp32_tests.cc @@ -21,8 +21,8 @@ #include "common/common_test.h" #include "src/common/file_utils.h" #include "src/common/file_utils_ext.h" -#include "mindspore/lite/src/runtime/kernel/arm/fp32/convolution_grad_filter.h" -#include "mindspore/lite/src/runtime/kernel/arm/fp32/convolution_grad_input.h" +#include "mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_filter.h" +#include "mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_input.h" #include "mindspore/lite/src/runtime/kernel/arm/nnacl/conv_parameter.h" #include "mindspore/lite/src/kernel_registry.h" diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/pooling_grad_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/pooling_grad_fp32_tests.cc similarity index 98% 
rename from mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/pooling_grad_fp32_tests.cc rename to mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/pooling_grad_fp32_tests.cc index 682f34c6a44..5194cab5acc 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/pooling_grad_fp32_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/pooling_grad_fp32_tests.cc @@ -22,8 +22,8 @@ #include "mindspore/lite/src/kernel_registry.h" #include "src/common/utils.h" #include "src/common/file_utils.h" -#include "src/runtime/kernel/arm/fp32/pooling_grad.h" -#include "src/runtime/kernel/arm/nnacl/fp32/pooling_grad.h" +#include "src/runtime/kernel/arm/fp32_grad/pooling_grad.h" +#include "src/runtime/kernel/arm/nnacl/fp32_grad/pooling_grad.h" namespace mindspore { class TestPoolingGradFp32 : public mindspore::Common { diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/softmax_crossentropy_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/softmax_crossentropy_fp32_tests.cc new file mode 100644 index 00000000000..c35bac8fa48 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/softmax_crossentropy_fp32_tests.cc @@ -0,0 +1,92 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include +#include +#include "utils/log_adapter.h" +#include "common/common_test.h" +#include "src/common/file_utils.h" +#include "src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.h" +#include "src/kernel_registry.h" + +namespace mindspore { + +class TestSoftmaxCrossEntropyFp32 : public mindspore::Common { + public: + TestSoftmaxCrossEntropyFp32() {} +}; + +TEST_F(TestSoftmaxCrossEntropyFp32, SoftmaxCrossEntropyFp32) { + // prepare stage + SoftmaxCrossEntropyParameter *sce_param = new SoftmaxCrossEntropyParameter(); + size_t input_size; + + std::string input_path = "./test_data/operators/sce_fp32_1_y_6_4.bin"; + auto input_data = reinterpret_cast(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); + std::vector dim_y({6, 4}); + lite::tensor::Tensor y_tensor(TypeId::kNumberTypeFloat32, dim_y); + y_tensor.SetData(input_data); + + std::string label_path = "./test_data/operators/sce_fp32_1_l_6.bin"; + auto ll_labels = reinterpret_cast(mindspore::lite::ReadFile(label_path.c_str(), &input_size)); + auto labels = new int[6]; + for (int i = 0; i < 6; i++) labels[i] = static_cast(ll_labels[i]); + + std::vector dim_l({6}); + lite::tensor::Tensor l_tensor(TypeId::kNumberTypeInt32, dim_l); + l_tensor.SetData(labels); + + std::vector inputs = {&y_tensor, &l_tensor}; + + auto loss = new float[1]; + std::vector dim_dw({1}); + lite::tensor::Tensor loss_tensor(TypeId::kNumberTypeFloat32, dim_dw); + loss_tensor.SetData(loss); + auto grad = new float[24]; + lite::tensor::Tensor grad_tensor(TypeId::kNumberTypeFloat32, dim_y); + grad_tensor.SetData(grad); + std::vector outputs = {&grad_tensor, &loss_tensor}; + + kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_SoftmaxCrossEntropy}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + auto kernel_obj = creator(inputs, outputs, reinterpret_cast(sce_param), NULL, desc, nullptr); + kernel_obj->Run(); + + printf("==================total 
loss=================\n"); + std::cout << loss[0] << " ," << std::endl; + + printf("==================Testing Grad===============\n"); + + std::string output_path = "./test_data/operators/sce_fp32_1_loss_1.bin"; + lite::CompareOutput(loss, output_path); + + ((mindspore::kernel::SparseSoftmaxCrossEntropyWithLogitsCPUKernel *)kernel_obj)->train(); + kernel_obj->Run(); + + printf("==================output data=================\n"); + for (int i = 0; i < 12; i++) { + std::cout << grad[i] << " ,"; + } + std::cout << std::endl; + std::string grad_path = "./test_data/operators/sce_fp32_1_dy_6_4.bin"; + lite::CompareOutput(grad, grad_path); + + delete sce_param; + l_tensor.SetData(NULL); + y_tensor.SetData(NULL); + MS_LOG(INFO) << "SoftmaxCrossEntropyFp32 passed"; +} + +} // namespace mindspore