move grad ops to separate folders

This commit is contained in:
yoni baehr 2020-08-09 14:32:04 +03:00
parent b7ebe2be4b
commit 43d7d3af55
58 changed files with 901 additions and 233 deletions

View File

@ -109,7 +109,7 @@ checkopts()
ENABLE_GPU="off" ENABLE_GPU="off"
# Process the options # Process the options
while getopts 'drvj:c:t:hsb:a:g:p:ie:m:l:I:LRP:Q:D:zM:V:K:swB:En' opt while getopts 'drvj:c:t:hsb:a:g:p:ie:m:l:I:LRP:Q:D:zM:V:K:swB:EnT:' opt
do do
OPTARG=$(echo ${OPTARG} | tr '[A-Z]' '[a-z]') OPTARG=$(echo ${OPTARG} | tr '[A-Z]' '[a-z]')
case "${opt}" in case "${opt}" in
@ -282,6 +282,11 @@ checkopts()
ENABLE_IBVERBS="on" ENABLE_IBVERBS="on"
echo "enable IBVERBS for parameter server" echo "enable IBVERBS for parameter server"
;; ;;
T)
check_on_off $OPTARG T
SUPPORT_TRAIN=$OPTARG
echo "support train on device "
;;
*) *)
echo "Unknown option ${opt}!" echo "Unknown option ${opt}!"
usage usage
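
Usage note (editorial, hedged): this hunk adds a -T on/off switch that sets SUPPORT_TRAIN from the build script. Assuming the standard entry point is build.sh, an invocation would look like:

    bash build.sh -I arm64 -T on   # -T toggles on-device training; -I arm64 is illustrative only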

View File

@ -23,6 +23,7 @@ endif()
if (SUPPORT_TRAIN) if (SUPPORT_TRAIN)
set(ANF_SRC set(ANF_SRC
${ANF_SRC}
# ${CCSRC_DIR}/common/trans.cc # ${CCSRC_DIR}/common/trans.cc
# ${CCSRC_DIR}/utils/lite/base_ref_utils.cc # ${CCSRC_DIR}/utils/lite/base_ref_utils.cc
# ${CCSRC_DIR}/runtime/kernel/kernel_compiler/kernel_build_info.cc # ${CCSRC_DIR}/runtime/kernel/kernel_compiler/kernel_build_info.cc
@ -40,14 +41,17 @@ if (SUPPORT_TRAIN)
set(LITE_SRC set(LITE_SRC
${LITE_SRC} ${LITE_SRC}
${ANF_SRC} ${ANF_SRC}
${PASS_SRC} # ${PASS_SRC}
${CMAKE_CURRENT_SOURCE_DIR}/common/anf_importer/anf_importer.cc # ${CMAKE_CURRENT_SOURCE_DIR}/common/anf_importer/anf_importer.cc
${CMAKE_CURRENT_SOURCE_DIR}/common/anf_importer/import_from_meta_graph.cc # ${CMAKE_CURRENT_SOURCE_DIR}/common/anf_importer/import_from_meta_graph.cc
${CMAKE_CURRENT_SOURCE_DIR}/ir/primitive_value.cc # ${CMAKE_CURRENT_SOURCE_DIR}/ir/primitive_value.cc
${CMAKE_CURRENT_SOURCE_DIR}/train/lite_kernel_runtime.cc # ${CMAKE_CURRENT_SOURCE_DIR}/train/lite_kernel_runtime.cc
${CMAKE_CURRENT_SOURCE_DIR}/train/train_session.cc # ${CMAKE_CURRENT_SOURCE_DIR}/train/train_session.cc
${CMAKE_CURRENT_SOURCE_DIR}/train/model_impl.cc # ${CMAKE_CURRENT_SOURCE_DIR}/train/model_impl.cc
${CMAKE_CURRENT_SOURCE_DIR}/lite_session.cc # temporary
${CMAKE_CURRENT_SOURCE_DIR}/model_impl.cc # temporary
) )
else () else ()
set(LITE_SRC set(LITE_SRC
${LITE_SRC} ${LITE_SRC}

View File

@ -27,26 +27,30 @@
namespace mindspore::lite { namespace mindspore::lite {
void AnfImporterFromMetaGraph::ConverterConstTensor() { void AnfImporterFromMetaGraph::ConverterConstTensor() {
MS_EXCEPTION_IF_NULL(model); MS_EXCEPTION_IF_NULL(model_);
auto *meta_graph = model->GetMetaGraph(); auto *meta_graph = model_->GetMetaGraph();
MS_EXCEPTION_IF_NULL(meta_graph); MS_EXCEPTION_IF_NULL(meta_graph);
for (size_t i = 0; i < meta_graph->allTensors()->size(); i++) { num_of_tensors_ = meta_graph->allTensors()->size();
for (size_t i = 0; i < num_of_tensors_; i++) {
auto *tensor = meta_graph->allTensors()->GetAs<schema::Tensor>(i); auto *tensor = meta_graph->allTensors()->GetAs<schema::Tensor>(i);
MS_EXCEPTION_IF_NULL(tensor); MS_EXCEPTION_IF_NULL(tensor);
if (tensor->nodeType() != schema::NodeType_ValueNode) { if ((tensor->nodeType() != schema::NodeType_ValueNode) && (tensor->nodeType() != schema::NodeType_Parameter)) {
continue; continue;
} }
MS_ASSERT(tensor->dims() != nullptr); MS_ASSERT(tensor->dims() != nullptr);
auto parameter = model->add_parameter(); auto parameter = model_->add_parameter();
std::vector<int> shape; std::vector<int> shape;
for (size_t j = 0; j < tensor->dims()->size(); ++j) { for (size_t j = 0; j < tensor->dims()->size(); ++j) {
shape.push_back(tensor->dims()->data()[j]); shape.push_back(tensor->dims()->data()[j]);
} }
auto type_id = static_cast<TypeId>(tensor->dataType()); auto type_id = static_cast<TypeId>(tensor->dataType()); // todo: check error
auto type_ptr = TypeIdToType(type_id); auto type_ptr = TypeIdToType(type_id);
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape); auto abstractBase = std::make_shared<abstract::AbstractTensor>(type_ptr, shape);
parameter->set_abstract(abstract_tensor); // XXX TODO copy format
parameter->set_abstract(abstractBase);
parameter->set_name(std::string("Parameter"));
if (tensor->nodeType() == schema::NodeType_ValueNode) {
ParamValueLitePtr param_value = std::make_shared<ParamValueLite>(); ParamValueLitePtr param_value = std::make_shared<ParamValueLite>();
MS_EXCEPTION_IF_NULL(param_value); MS_EXCEPTION_IF_NULL(param_value);
param_value->set_tensor_shape(shape); param_value->set_tensor_shape(shape);
@ -60,63 +64,106 @@ void AnfImporterFromMetaGraph::ConverterConstTensor() {
param_value->set_tensor_size(size); param_value->set_tensor_size(size);
} }
parameter->set_default_param(param_value); parameter->set_default_param(param_value);
}
AddNode(i, parameter); AddNode(i, parameter);
model_->AddAnfNode(i, parameter);
} }
} }
int AnfImporterFromMetaGraph::ConverterCNode() { int AnfImporterFromMetaGraph::ConverterCNode() {
MS_EXCEPTION_IF_NULL(model); MS_EXCEPTION_IF_NULL(model_);
auto *meta_graph = model->GetMetaGraph(); auto *meta_graph = model_->GetMetaGraph();
MS_EXCEPTION_IF_NULL(meta_graph); MS_EXCEPTION_IF_NULL(meta_graph);
auto cNodes = meta_graph->nodes();
for (size_t i = 0; i < cNodes->size(); i++) {
auto cNode = cNodes->GetAs<schema::CNode>(i);
MS_EXCEPTION_IF_NULL(cNode);
auto tensor_id = cNode->outputIndex()->data()[0];
if (GetNode(tensor_id)) {
continue;
}
auto prim = std::make_shared<PrimitiveValue>(model->GetOp(cNode->name()->str())); // Create CNode -- order of inputs is as follows:
// First input should be the Primitive
// Then we have CNodes that contribute to this CNode
// Finally we have the parameters
// first iteration -- create CNode with primitive, create originator map
for (size_t i = 0; i < meta_graph->nodes()->size(); i++) {
auto cNode = meta_graph->nodes()->GetAs<schema::CNode>(i);
MS_EXCEPTION_IF_NULL(cNode);
auto prim = std::make_shared<PrimitiveValue>(model_->GetOp(cNode->name()->str()));
if (prim == nullptr) { if (prim == nullptr) {
MS_LOG(ERROR) << "th tensorDef in subGraphDef is nullptr"; MS_LOG(ERROR) << "th tensorDef in subGraphDef is nullptr";
return RET_ERROR; return RET_ERROR;
} }
auto value_node = NewValueNode(prim); auto value_node = NewValueNode(prim);
AddNode(tensor_id, value_node); // auto prim_name = std::string("PrimitivePy: ") + std::string(cNode->name()->c_str());
// value_node->set_fullname_with_scope(prim_name);
std::vector<AnfNodePtr> op_inputs = {value_node}; std::vector<AnfNodePtr> op_inputs = {value_node};
auto cnode = model_->NewCNode(op_inputs);
auto node_name = std::string(cNode->name()->c_str()) + std::to_string(i);
cnode->set_fullname_with_scope(node_name);
AddNode(num_of_tensors_ + i, cnode);
for (size_t j = 0; j < cNode->outputIndex()->size(); j++) {
int tensor_id = cNode->outputIndex()->data()[j];
originator_[tensor_id] = cnode;
}
}
// second iteration -- fill in input CNodes and Parameters
// populate map
for (size_t i = 0; i < meta_graph->nodes()->size(); i++) {
std::vector<int> input;
std::vector<int> output;
int tensor_id;
auto cNode = meta_graph->nodes()->GetAs<schema::CNode>(i);
MS_EXCEPTION_IF_NULL(cNode);
auto cnode = std::dynamic_pointer_cast<CNode>(GetNode(num_of_tensors_ + i));
for (size_t j = 0; j < cNode->outputIndex()->size(); j++) {
tensor_id = cNode->outputIndex()->data()[j];
output.push_back(tensor_id);
}
MS_EXCEPTION_IF_NULL(cNode->inputIndex()); MS_EXCEPTION_IF_NULL(cNode->inputIndex());
for (size_t j = 0; j < cNode->inputIndex()->size(); j++) { for (size_t j = 0; j < cNode->inputIndex()->size(); j++) {
auto node = GetNode(*(cNode->inputIndex()->GetAs<uint32_t>(j))); tensor_id = cNode->inputIndex()->data()[j];
if (nullptr == node) { input.push_back(tensor_id);
MS_LOG(ERROR) << "Can't find input node."; auto *tensor = meta_graph->allTensors()->GetAs<schema::Tensor>(tensor_id);
return RET_ERROR; MS_EXCEPTION_IF_NULL(tensor);
if ((tensor->nodeType() == schema::NodeType_Parameter) && (originator_[tensor_id] != nullptr)) {
cnode->add_input(originator_[tensor_id]);
} }
// todo: CheckInputNodeType, the first node should be op;
op_inputs.push_back(node);
} }
auto cnode = model->NewCNode(op_inputs); // finally add all the Parameters (which are ValueNodes)
auto node_name = std::string(cNode->name()->c_str()); for (size_t j = 0; j < cNode->inputIndex()->size(); j++) {
cnode->set_fullname_with_scope(node_name); tensor_id = cNode->inputIndex()->data()[j];
AddNode(tensor_id, cnode); auto *tensor = meta_graph->allTensors()->GetAs<schema::Tensor>(tensor_id);
MS_EXCEPTION_IF_NULL(tensor);
if ((tensor->nodeType() == schema::NodeType_ValueNode) && (GetNode(tensor_id) != nullptr)) {
cnode->add_input(GetNode(tensor_id));
} }
}
model_->AddCNodeInputOutput(cnode->fullname_with_scope(), input, output);
}
return RET_OK; return RET_OK;
} }
void AnfImporterFromMetaGraph::AddReturnCNode() { void AnfImporterFromMetaGraph::AddReturnCNode() {
MS_EXCEPTION_IF_NULL(model); MS_EXCEPTION_IF_NULL(model_);
auto *meta_graph = model->GetMetaGraph(); auto *meta_graph = model_->GetMetaGraph();
MS_EXCEPTION_IF_NULL(meta_graph); MS_EXCEPTION_IF_NULL(meta_graph);
std::vector<int> input;
std::vector<int> output;
std::vector<AnfNodePtr> op_inputs; std::vector<AnfNodePtr> op_inputs;
auto value_node = NewValueNode(prim::kPrimReturn); auto value_node = NewValueNode(prim::kPrimReturn);
// value_node->set_fullname_with_scope("Primitive");
op_inputs.push_back(value_node); op_inputs.push_back(value_node);
auto tensor_id = meta_graph->outputIndex()->data()[0]; for (int i = 0; i < meta_graph->outputIndex()->size(); i++) {
op_inputs.push_back(GetNode(tensor_id)); auto prev_cnode = originator_[meta_graph->outputIndex()->data()[i]];
auto cnode = model->NewCNode(op_inputs); if (prev_cnode != nullptr) op_inputs.push_back(prev_cnode);
input.push_back(meta_graph->outputIndex()->data()[i]);
}
auto cnode = model_->NewCNode(op_inputs);
cnode->set_fullname_with_scope("return"); cnode->set_fullname_with_scope("return");
model->set_return(cnode); model_->set_return(cnode);
model_->AddCNodeInputOutput(cnode->fullname_with_scope(), input, output);
} }
FuncGraphPtr AnfImporterFromMetaGraph::GetResult() { return this->model; } FuncGraphPtr AnfImporterFromMetaGraph::GetResult() { return this->model_; }
} // namespace mindspore::lite } // namespace mindspore::lite
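
For orientation, the importer above assembles each CNode's input list in three groups; a hedged schematic of the resulting layout (grouping taken from the diff, indices invented):

    // cnode->inputs() after both iterations, schematically:
    // [0]       ValueNode<PrimitiveValue>          (the op itself)
    // [1..k]    upstream CNodes, via originator_   (second iteration, first input loop)
    // [k+1..n]  Parameters built from ValueNode tensors (second iteration, second input loop)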

View File

@ -18,6 +18,7 @@
#define MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_META_GRAPH_H_ #define MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_META_GRAPH_H_
#include <memory> #include <memory>
#include <map>
#include "src/train/model_impl.h" #include "src/train/model_impl.h"
#include "schema/model_generated.h" #include "schema/model_generated.h"
#include "src/common/anf_importer/anf_importer.h" #include "src/common/anf_importer/anf_importer.h"
@ -25,7 +26,7 @@
namespace mindspore::lite { namespace mindspore::lite {
class AnfImporterFromMetaGraph : public AnfImporter { class AnfImporterFromMetaGraph : public AnfImporter {
public: public:
explicit AnfImporterFromMetaGraph(std::shared_ptr<ModelImpl> model) : model(model) {} explicit AnfImporterFromMetaGraph(std::shared_ptr<ModelImpl> model) : model_(model) {}
~AnfImporterFromMetaGraph() override = default; ~AnfImporterFromMetaGraph() override = default;
@ -39,9 +40,10 @@ class AnfImporterFromMetaGraph : public AnfImporter {
void AddReturnCNode() override; void AddReturnCNode() override;
private: private:
std::shared_ptr<ModelImpl> model = nullptr; std::shared_ptr<ModelImpl> model_ = nullptr;
std::map<int, AnfNodePtr> originator_;
int num_of_tensors_ = 0;
}; };
} // namespace mindspore::lite } // namespace mindspore::lite
#endif // MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_META_GRAPH_H_ #endif // MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_META_GRAPH_H_

View File

@ -60,7 +60,7 @@ class LiteKernel {
explicit LiteKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs, explicit LiteKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx, const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
const lite::Primitive *primitive) const lite::Primitive *primitive)
: opParameter(parameter), inputs_(inputs), outputs_(outputs), train_mode(false), primitive_(primitive), : opParameter(parameter), inputs_(inputs), outputs_(outputs), primitive_(primitive),
context_(ctx) { context_(ctx) {
this->in_kernel_.clear(); this->in_kernel_.clear();
this->out_kernel_.clear(); this->out_kernel_.clear();
@ -136,7 +136,7 @@ class LiteKernel {
std::vector<lite::tensor::Tensor *> outputs_; std::vector<lite::tensor::Tensor *> outputs_;
std::vector<LiteKernel *> in_kernel_; std::vector<LiteKernel *> in_kernel_;
std::vector<LiteKernel *> out_kernel_; std::vector<LiteKernel *> out_kernel_;
bool train_mode; bool train_mode = false;
bool need_reinit = false; bool need_reinit = false;
}; };

View File

@ -14,11 +14,11 @@
* limitations under the License. * limitations under the License.
*/ */
#ifdef SUPPORT_TRAIN // #ifdef SUPPORT_TRAIN
#include "src/train/model_impl.h" // #include "src/train/model_impl.h"
#else // #else
#include "src/model_impl.h" #include "src/model_impl.h"
#endif // #endif
#include "include/model.h" #include "include/model.h"
#include "utils/log_adapter.h" #include "utils/log_adapter.h"

View File

@ -10,6 +10,13 @@ file(GLOB KERNEL_SRC
${CMAKE_CURRENT_SOURCE_DIR}/int8/*.cc ${CMAKE_CURRENT_SOURCE_DIR}/int8/*.cc
) )
if (SUPPORT_TRAIN)
file (GLOB TRAIN_KERNEL_SRC
${CMAKE_CURRENT_SOURCE_DIR}/fp32_grad/*.cc
${CMAKE_CURRENT_SOURCE_DIR}/nnacl/fp32_grad/*.cc
)
endif()
if (PLATFORM_ARM64) if (PLATFORM_ARM64)
# assembly # assembly
file(GLOB ASSEMBLY_SRC nnacl/assembly/arm64/*.s file(GLOB ASSEMBLY_SRC nnacl/assembly/arm64/*.s
@ -27,5 +34,5 @@ if (PLATFORM_ARM32)
set(KERNEL_SRC ${KERNEL_SRC} ${ASSEMBLY_SRC}) set(KERNEL_SRC ${KERNEL_SRC} ${ASSEMBLY_SRC})
endif() endif()
add_library(cpu_kernel_mid_ OBJECT ${KERNEL_SRC}) add_library(cpu_kernel_mid_ OBJECT ${KERNEL_SRC} ${TRAIN_KERNEL_SRC})
add_subdirectory(nnacl) add_subdirectory(nnacl)
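
A note on the pattern above: when SUPPORT_TRAIN is off, TRAIN_KERNEL_SRC is never defined, and an undefined variable expands to nothing in CMake, so the single add_library line serves both configurations. The flag itself would presumably be forwarded from the build script (hypothetical spelling):

    cmake -DSUPPORT_TRAIN=on ..   # what build.sh -T on would likely pass through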

View File

@ -14,7 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */
#include "src/runtime/kernel/arm/fp32/activation_grad.h" #include "src/runtime/kernel/arm/fp32_grad/activation_grad.h"
#include "schema/model_generated.h" #include "schema/model_generated.h"
#include "src/kernel_registry.h" #include "src/kernel_registry.h"
#include "src/runtime/runtime_api.h" #include "src/runtime/runtime_api.h"
@ -102,6 +102,8 @@ kernel::LiteKernel *CpuActivationGradFp32KernelCreator(const std::vector<lite::t
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "InferShape kernel failed, name: " << opParameter->name_ << ", type: " MS_LOG(ERROR) << "InferShape kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
delete kernel;
return nullptr;
} }
return kernel; return kernel;
} }
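
The added delete/return pair plugs a leak on the InferShape error path; the same cleanup recurs in the other grad kernel creators below. A generic sketch of the idiom with a smart pointer (not the project's code, names invented):

    std::unique_ptr<kernel::LiteKernel> guard(kernel);  // hypothetical owning wrapper
    if (kernel->Init() != RET_OK) {
      return nullptr;              // guard deletes kernel automatically
    }
    return guard.release();        // success: hand ownership back to the caller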

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ACTIVATION_GRAD_H_ #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ACTIVATION_GRAD_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ACTIVATION_GRAD_H_ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ACTIVATION_GRAD_H_
#include <vector> #include <vector>
#include "src/lite_kernel.h" #include "src/lite_kernel.h"
@ -48,4 +48,4 @@ class ActivationGradCPUKernel : public LiteKernel {
}; };
} // namespace mindspore::kernel } // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ACTIVATION_GRAD_H_ #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ACTIVATION_GRAD_H_

View File

@ -16,9 +16,9 @@
#include "schema/model_generated.h" #include "schema/model_generated.h"
#include "src/kernel_registry.h" #include "src/kernel_registry.h"
#include "src/runtime/kernel/arm/nnacl/fp32/reduce_grad.h" #include "src/runtime/kernel/arm/nnacl/fp32_grad/reduce_grad.h"
#include "src/runtime/kernel/arm/fp32/arithmetic_grad.h" #include "src/runtime/kernel/arm/nnacl/fp32_grad/arithmetic_grad.h"
#include "src/runtime/kernel/arm/nnacl/fp32/arithmetic_grad.h" #include "src/runtime/kernel/arm/fp32_grad/arithmetic_grad.h"
#include "include/errorcode.h" #include "include/errorcode.h"
using mindspore::kernel::KERNEL_ARCH::kCPU; using mindspore::kernel::KERNEL_ARCH::kCPU;

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_GRAD_H_ #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ARITHMETIC_GRAD_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_GRAD_H_ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ARITHMETIC_GRAD_H_
#include <vector> #include <vector>
#include "src/lite_kernel.h" #include "src/lite_kernel.h"
@ -88,4 +88,4 @@ class ArithmeticGradCPUKernel : public LiteKernel {
}; };
} // namespace mindspore::kernel } // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_GRAD_H_ #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ARITHMETIC_GRAD_H_

View File

@ -15,7 +15,7 @@
*/ */
#include <vector> #include <vector>
#include "src/runtime/kernel/arm/fp32/bias_grad.h" #include "src/runtime/kernel/arm/fp32_grad/bias_grad.h"
#include "schema/model_generated.h" #include "schema/model_generated.h"
#include "src/kernel_registry.h" #include "src/kernel_registry.h"
#include "include/errorcode.h" #include "include/errorcode.h"

View File

@ -18,8 +18,8 @@
#include <vector> #include <vector>
#include "schema/model_generated.h" #include "schema/model_generated.h"
#include "src/kernel_factory.h" #include "src/kernel_factory.h"
#include "src/runtime/kernel/arm/fp32/bngrad_input.h" #include "src/runtime/kernel/arm/fp32_grad/bn_grad.h"
#include "src/runtime//kernel/arm/nnacl/batch_norm.h" #include "src/runtime/kernel/arm/nnacl/fp32_grad/batch_norm.h"
#include "include/errorcode.h" #include "include/errorcode.h"
using mindspore::kernel::KERNEL_ARCH::kCPU; using mindspore::kernel::KERNEL_ARCH::kCPU;
@ -54,10 +54,6 @@ int BNGradInputCPUKernel::Init() {
int BNGradInputCPUKernel::ReSize() { return RET_OK; } int BNGradInputCPUKernel::ReSize() { return RET_OK; }
/*
according to https://wiseodd.github.io/techblog/2016/07/04/batchnorm
*/
int BNGradInputCPUKernel::Run() { int BNGradInputCPUKernel::Run() {
// std::cout << "run succ" << std::endl; // std::cout << "run succ" << std::endl;
auto *input_x = inputs_.at(0); auto *input_x = inputs_.at(0);
@ -107,6 +103,8 @@ kernel::LiteKernel *CpuBNGradInputFp32KernelCreator(const std::vector<lite::tens
if (RET_OK != ret) { if (RET_OK != ret) {
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
delete kernel;
return nullptr;
} }
return kernel; return kernel;
} }

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_BNGRAD_INPUT_H_ #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_BNGRAD_INPUT_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_BNGRAD_INPUT_H_ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_BNGRAD_INPUT_H_
#include <vector> #include <vector>
#include "src/lite_kernel.h" #include "src/lite_kernel.h"
@ -39,4 +39,4 @@ class BNGradInputCPUKernel : public LiteKernel {
int workspace_size; int workspace_size;
}; };
} // namespace mindspore::kernel } // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_BNGRAD_INPUT_H_ #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_BNGRAD_INPUT_H_

View File

@ -14,11 +14,11 @@
* limitations under the License. * limitations under the License.
*/ */
#include "src/runtime/kernel/arm/fp32/convolution_grad_filter.h" #include "src/runtime/kernel/arm/fp32_grad/convolution_grad_filter.h"
#include "src/kernel_registry.h" #include "src/kernel_registry.h"
#include "src/runtime/kernel/arm/nnacl/pack.h" #include "src/runtime/kernel/arm/nnacl/pack.h"
#include "src/runtime/kernel/arm/nnacl/pack_ext.h" #include "src/runtime/kernel/arm/nnacl/fp32_grad/pack_ext.h"
#include "src/runtime/kernel/arm/nnacl/fp32/gemm.h" #include "src/runtime/kernel/arm/nnacl/fp32_grad/gemm.h"
#include "include/errorcode.h" #include "include/errorcode.h"
using mindspore::kernel::KERNEL_ARCH::kCPU; using mindspore::kernel::KERNEL_ARCH::kCPU;

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_GRAD_FILTER_H_ #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_CONVOLUTION_GRAD_FILTER_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_GRAD_FILTER_H_ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_CONVOLUTION_GRAD_FILTER_H_
#include <vector> #include <vector>
#include "src/lite_kernel.h" #include "src/lite_kernel.h"
@ -39,4 +39,4 @@ class ConvolutionGradFilterCPUKernel : public LiteKernel {
}; };
} // namespace mindspore::kernel } // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_GRAD_FILTER_H_ #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_CONVOLUTION_GRAD_FILTER_H_

View File

@ -14,11 +14,11 @@
* limitations under the License. * limitations under the License.
*/ */
#include "src/runtime/kernel/arm/fp32/convolution_grad_input.h" #include "src/runtime/kernel/arm/fp32_grad/convolution_grad_input.h"
#include "src/kernel_registry.h" #include "src/kernel_registry.h"
#include "src/runtime/kernel/arm/nnacl/pack.h" #include "src/runtime/kernel/arm/nnacl/pack.h"
#include "src/runtime/kernel/arm/nnacl/pack_ext.h" #include "src/runtime/kernel/arm/nnacl/fp32_grad/pack_ext.h"
#include "src/runtime/kernel/arm/nnacl/fp32/gemm.h" #include "src/runtime/kernel/arm/nnacl/fp32_grad/gemm.h"
#include "include/errorcode.h" #include "include/errorcode.h"
using mindspore::kernel::KERNEL_ARCH::kCPU; using mindspore::kernel::KERNEL_ARCH::kCPU;

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_GRAD_INPUT_H_ #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_CONVOLUTION_GRAD_INPUT_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_GRAD_INPUT_H_ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_CONVOLUTION_GRAD_INPUT_H_
#include <vector> #include <vector>
#include "src/lite_kernel.h" #include "src/lite_kernel.h"
@ -39,4 +39,4 @@ class ConvolutionGradInputCPUKernel : public LiteKernel {
}; };
} // namespace mindspore::kernel } // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_GRAD_INPUT_H_ #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_CONVOLUTION_GRAD_INPUT_H_

View File

@ -17,7 +17,7 @@
#include "schema/model_generated.h" #include "schema/model_generated.h"
#include "src/kernel_registry.h" #include "src/kernel_registry.h"
#include "src/runtime/kernel/arm/fp32/opt_momentum.h" #include "src/runtime/kernel/arm/fp32_grad/opt_momentum.h"
#include "include/errorcode.h" #include "include/errorcode.h"
using mindspore::kernel::KERNEL_ARCH::kCPU; using mindspore::kernel::KERNEL_ARCH::kCPU;

View File

@ -14,11 +14,11 @@
* limitations under the License. * limitations under the License.
*/ */
#include "src/runtime/kernel/arm/fp32/pooling_grad.h" #include "src/runtime/kernel/arm/fp32_grad/pooling_grad.h"
#include "schema/model_generated.h" #include "schema/model_generated.h"
#include "src/kernel_registry.h" #include "src/kernel_registry.h"
#include "src/runtime/kernel/arm/nnacl/fp32/pooling.h" #include "src/runtime/kernel/arm/nnacl/fp32/pooling.h"
#include "src/runtime/kernel/arm/nnacl/fp32/pooling_grad.h" #include "src/runtime/kernel/arm/nnacl/fp32_grad/pooling_grad.h"
#include "include/errorcode.h" #include "include/errorcode.h"
using mindspore::kernel::KERNEL_ARCH::kCPU; using mindspore::kernel::KERNEL_ARCH::kCPU;

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_POOLING_GRAD_H_ #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_POOLING_GRAD_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_POOLING_GRAD_H_ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_POOLING_GRAD_H_
#include <vector> #include <vector>
#include "src/lite_kernel.h" #include "src/lite_kernel.h"
@ -48,4 +48,4 @@ class PoolingGradCPUKernel : public LiteKernel {
}; };
} // namespace mindspore::kernel } // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_POOLING_GRAD_H_ #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_POOLING_GRAD_H_

View File

@ -14,7 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */
#include "src/runtime/kernel/arm/fp32/power_grad.h" #include "src/runtime/kernel/arm/fp32_grad/power_grad.h"
#include "schema/model_generated.h" #include "schema/model_generated.h"
#include "src/kernel_registry.h" #include "src/kernel_registry.h"
#include "include/errorcode.h" #include "include/errorcode.h"

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_POWER_GRAD_H_ #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_POWER_GRAD_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_POWER_GRAD_H_ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_POWER_GRAD_H_
#include <vector> #include <vector>
#include "src/lite_kernel.h" #include "src/lite_kernel.h"
@ -47,4 +47,4 @@ class PowerGradCPUKernel : public LiteKernel {
}; };
} // namespace mindspore::kernel } // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_POWER_GRAD_H_ #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_POWER_GRAD_H_

View File

@ -14,13 +14,12 @@
* limitations under the License. * limitations under the License.
*/ */
#include "src/runtime/kernel/arm/fp32/sparse_softmax_cross_entropy_with_logits.h"
#include "src/runtime/kernel/arm/nnacl/fp32/softmax.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h" #include "src/kernel_registry.h"
#include "src/runtime/kernel/arm/nnacl/softmax_parameter.h"
#include "src/runtime/kernel/arm/nnacl/fp32/softmax.h"
#include "src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.h"
#include "include/errorcode.h" #include "include/errorcode.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar; using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR; using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK; using mindspore::lite::RET_OK;
@ -73,7 +72,7 @@ void SparseSoftmaxCrossEntropyWithLogitsCPUKernel::GradPostExecute(const int *la
int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Run() { int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Run() {
auto ins = reinterpret_cast<float *>(inputs_.at(0)->Data()); auto ins = reinterpret_cast<float *>(inputs_.at(0)->Data());
auto labels = reinterpret_cast<int *>(inputs_.at(1)->Data()); auto labels = reinterpret_cast<int *>(inputs_.at(1)->Data());
auto out = reinterpret_cast<float *>(outputs_.at(0)->Data()); auto out = reinterpret_cast<float *>(outputs_.at(1)->Data());
float *grads = NULL; float *grads = NULL;
if (is_train()) { // outputs_.size() > 1) if (is_train()) { // outputs_.size() > 1)
grads = reinterpret_cast<float *>(outputs_.at(0)->Data()); grads = reinterpret_cast<float *>(outputs_.at(0)->Data());
@ -90,10 +89,11 @@ int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Run() {
SoftmaxParameter sm_params; SoftmaxParameter sm_params;
sm_params.n_dim_ = param->n_dim_; sm_params.n_dim_ = param->n_dim_;
sm_params.element_size_ = data_size; sm_params.element_size_ = data_size;
sm_params.axis_ = 1; sm_params.axis_ = 0;
for (int i = 0; i < 4; i++) // softmax has only 4 params in shape for (int i = 0; i < 4; i++) // softmax has only 4 params in shape
sm_params.input_shape_[i] = param->input_shape_[i]; sm_params.input_shape_[i] = param->input_shape_[i];
float sum_data[sm_params.input_shape_[sm_params.axis_]]; float sum_data[sm_params.input_shape_[sm_params.axis_]] = {0};
std::fill(sum_data, sum_data + sm_params.input_shape_[sm_params.axis_], 0);
Softmax(ins, losses, sum_data, &sm_params); Softmax(ins, losses, sum_data, &sm_params);
if (is_train()) { if (is_train()) {
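
Two details worth noting in the hunk above: the axis switch (1 to 0) changes which dimension Softmax normalizes over, and sum_data is a variable-length array, a compiler extension in C++ that cannot be reliably initialized with = {0}; the added std::fill is what portably zeroes it. A heap-backed alternative would sidestep the VLA entirely (hedged sketch):

    std::vector<float> sum_data(sm_params.input_shape_[sm_params.axis_], 0.0f);
    Softmax(ins, losses, sum_data.data(), &sm_params);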

View File

@ -20,7 +20,7 @@
#include <vector> #include <vector>
#include "src/lite_kernel.h" #include "src/lite_kernel.h"
#include "ir/anf.h" #include "ir/anf.h"
#include "src/runtime/kernel/arm/nnacl/fp32/softmax_grad.h" #include "src/runtime/kernel/arm/nnacl/fp32_grad/softmax_grad.h"
#include "src/runtime/kernel/arm/nnacl/fp32/arithmetic.h" #include "src/runtime/kernel/arm/nnacl/fp32/arithmetic.h"
namespace mindspore::kernel { namespace mindspore::kernel {
@ -30,8 +30,7 @@ class SparseSoftmaxCrossEntropyWithLogitsCPUKernel : public LiteKernel {
explicit SparseSoftmaxCrossEntropyWithLogitsCPUKernel(OpParameter *parameter, explicit SparseSoftmaxCrossEntropyWithLogitsCPUKernel(OpParameter *parameter,
const std::vector<lite::tensor::Tensor *> &inputs, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const std::vector<lite::tensor::Tensor *> &outputs,
const lite::Context *ctx, const lite::Context *ctx, const lite::Primitive *primitive)
const lite::Primitive *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) { : LiteKernel(parameter, inputs, outputs, ctx, primitive) {
param = reinterpret_cast<SoftmaxCrossEntropyParameter *>(parameter); param = reinterpret_cast<SoftmaxCrossEntropyParameter *>(parameter);
} }

View File

@ -0,0 +1,88 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_GRAD_ACTIVATION_GRAD_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_GRAD_ACTIVATION_GRAD_H_
#include <math.h>
#include "src/runtime/kernel/arm/opclib/op_base.h"
#include "src/runtime/kernel/arm/opclib/fp32/arithmetic.h"
#include "src/runtime/kernel/arm/opclib/errorcode.h"
struct ActivationGradParameter {
OpParameter op_parameter{};
int type_;
float alpha_{0.01};
};
inline int ReluGrad(float *src0, float *src1, int length, float *dst) {
for (int i = 0; i < length; ++i) {
dst[i] = src1[i] > 0 ? 1.0f : 0.0f;
}
ElementMul(src0, dst, dst, length);
return OPCLIB_OK;
}
inline int Relu6Grad(float *src0, float *src1, int length, float *dst) {
for (int i = 0; i < length; ++i) {
if (src1[i] < 0) {
dst[i] = 0;
} else {
dst[i] = src1[i] > 6.0f ? 0.0f : 1.0f;
}
}
ElementMul(src0, dst, dst, length);
return OPCLIB_OK;
}
inline int LReluGrad(float *src0, float *src1, int length, float *dst, float alpha) {
for (int i = 0; i < length; ++i) {
dst[i] = src1[i] > 0.0f ? 1.0f : alpha;
}
ElementMul(src0, dst, dst, length);
return OPCLIB_OK;
}
inline int SigmoidGrad(float *src0, float *src1, int length, float *dst) {
for (int i = 0; i < length; ++i) {
dst[i] = src0[i] * (src1[i] * (1.0f - src1[i]));
}
return OPCLIB_OK;
}
inline int TanhGrad(float *src0, float *src1, int length, float *dst) {
for (int i = 0; i < length; ++i) {
dst[i] = (1.0f - (src1[i] * src1[i])) * src0[i];
}
return OPCLIB_OK;
}
inline int HSwishGrad(float *src0, float *src1, int length, float *dst) {
for (int i = 0; i < length; ++i) {
float tmp = (src1[i] > 3.0f ? 1.0f : (src1[i] < -3.0f ? 0.0f : (2.0f * src1[i] + 3.0f) / 6.0f));
dst[i] = tmp * src0[i];
}
return OPCLIB_OK;
}
inline int HSigmoidGrad(float *src0, float *src1, int length, float *dst) {
for (int i = 0; i < length; ++i) {
float tmp = (src1[i] > 3.0f ? 1.0f : (src1[i] < -3.0f ? 0.0f : 1.0f / 6.0f));
dst[i] = tmp * src0[i];
}
return OPCLIB_OK;
}
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_GRAD_ACTIVATION_GRAD_H_
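
A minimal usage sketch of the helpers above (buffer contents invented for illustration; src0 is the incoming gradient, src1 the saved activation value):

    float dout[4] = {1.0f, 1.0f, 1.0f, 1.0f};   // gradient flowing in
    float y[4]    = {-2.0f, 0.0f, 3.0f, 7.0f};  // recorded activation values
    float dx[4];
    ReluGrad(dout, y, 4, dx);    // dx = dout * (y > 0), via the mask then ElementMul
    Relu6Grad(dout, y, 4, dx);   // mask is 1 only where 0 <= y <= 6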

View File

@ -14,7 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */
#include "src/runtime/kernel/arm/nnacl/fp32/arithmetic_grad.h" #include "src/runtime/kernel/arm/nnacl/fp32_grad/arithmetic_grad.h"
void ElementDivNegSquare(const float *nom, const float *denom, float *output, int element_size) { void ElementDivNegSquare(const float *nom, const float *denom, float *output, int element_size) {
for (int i = 0; i < element_size; i++) { for (int i = 0; i < element_size; i++) {

View File

@ -13,8 +13,8 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_ARITHMETIC_GRAD_H_ #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_GRAD_ARITHMETIC_GRAD_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_ARITHMETIC_GRAD_H_ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_GRAD_ARITHMETIC_GRAD_H_
void ElementDivNegSquare(const float *nom, const float *denom, float *output, int element_size); void ElementDivNegSquare(const float *nom, const float *denom, float *output, int element_size);
void ElementMulAndDivNegSquare(const float *a, const float *b, const float *denom, float *output, int element_size); void ElementMulAndDivNegSquare(const float *a, const float *b, const float *denom, float *output, int element_size);

View File

@ -15,7 +15,7 @@
*/ */
#include <algorithm> #include <algorithm>
#include <cmath> #include <cmath>
#include "src/runtime/kernel/arm/nnacl/batch_norm.h" #include "src/runtime/kernel/arm/nnacl/fp32_grad/batch_norm.h"
static void sumSpatialBatch(const float *in, int size, int ch, float *out) { static void sumSpatialBatch(const float *in, int size, int ch, float *out) {
std::fill(out, out + ch, 0.f); std::fill(out, out + ch, 0.f);

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_LITE_SRC_BACKEND_ARM_BATCH_NORM_H_ #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_BATCH_NORM_H_
#define MINDSPORE_LITE_SRC_BACKEND_ARM_BATCH_NORM_H_ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_BATCH_NORM_H_
struct bnParameter { struct bnParameter {
int batch; int batch;
@ -36,4 +36,5 @@ void meanAdd(const float *x, const float *mean, const float *variance_delta, int
void NormalizeDelta(const float *x, const float *mean, const float *variance, const float *mean_delta, void NormalizeDelta(const float *x, const float *mean, const float *variance, const float *mean_delta,
const float *variance_delta, int batch, int filters, int spatial, float eps, float *delta); const float *variance_delta, int batch, int filters, int spatial, float eps, float *delta);
#endif #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_BATCH_NORM_H_

View File

@ -14,7 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */
#include "src/runtime/kernel/arm/nnacl/fp32/gemm.h" #include "src/runtime/kernel/arm/nnacl/fp32_grad/gemm.h"
static void gemm_nn(int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_B, int ldb, float *mat_c, static void gemm_nn(int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_B, int ldb, float *mat_c,
int ldc) { int ldc) {

View File

@ -14,10 +14,10 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_GEMM_H_ #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_GRAD_GEMM_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_GEMM_H_ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_GRAD_GEMM_H_
void gemm(int transpose_a, int transpose_b, int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_b, void gemm(int transpose_a, int transpose_b, int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_b,
int ldb, float beta, float *mat_c, int ldc); int ldb, float beta, float *mat_c, int ldc);
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_GEMM_H_ #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_GRAD_GEMM_H_

View File

@ -15,7 +15,7 @@
*/ */
#include <string.h> #include <string.h>
#include "src/runtime/kernel/arm/nnacl/pack_ext.h" #include "src/runtime/kernel/arm/nnacl/fp32_grad/pack_ext.h"
static int is_a_ge_zero_and_a_lt_b(int a, int b) { return (unsigned)(a) < (unsigned)(b); } static int is_a_ge_zero_and_a_lt_b(int a, int b) { return (unsigned)(a) < (unsigned)(b); }

View File

@ -14,7 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */
#include <cstdint> #include <cstdint>
#include "src/runtime/kernel/arm/nnacl/fp32/pooling_grad.h" #include "src/runtime/kernel/arm/nnacl/fp32_grad/pooling_grad.h"
void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param) { void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param) {
int stride_w = pooling_param->stride_w_; int stride_w = pooling_param->stride_w_;

View File

@ -14,12 +14,12 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_POOLING_GRAD_H_ #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_GRAD_POOLING_GRAD_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_POOLING_GRAD_H_ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_GRAD_POOLING_GRAD_H_
#include "src/runtime/kernel/arm/nnacl/fp32/pooling.h" #include "src/runtime/kernel/arm/nnacl/fp32/pooling.h"
void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param); void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param);
void MaxPoolingGrad(const float *dy, const int *indices_ptr, float *output_ptr, PoolingParameter *pooling_param); void MaxPoolingGrad(const float *dy, const int *indices_ptr, float *output_ptr, PoolingParameter *pooling_param);
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_POOLING_GRAD_H_ #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_GRAD_POOLING_GRAD_H_

View File

@ -15,7 +15,7 @@
*/ */
#include <cstddef> #include <cstddef>
#include <algorithm> #include <algorithm>
#include "mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/reduce_grad.h" #include "mindspore/lite/src/runtime/kernel/arm/nnacl/fp32_grad/reduce_grad.h"
static inline bool NextIndex(const int num_dims, const int *dims, int *current) { static inline bool NextIndex(const int num_dims, const int *dims, int *current) {
int carry = 1; int carry = 1;

View File

@ -57,4 +57,3 @@ std::vector<std::vector<std::shared_ptr<tensor::MSTensor>>> TransformVectorRefTo
return multiTensor; return multiTensor;
} }
} // namespace mindspore } // namespace mindspore

View File

@ -16,16 +16,15 @@
#include <vector> #include <vector>
#include <memory> #include <memory>
#include "base/base_ref.h" #include "utils/base_ref.h"
#include "include/ms_tensor.h" #include "include/ms_tensor.h"
#ifndef MINDSPORE_CCSRC_UTILS_BASE_REF_UTILS_H #ifndef MINDSPORE_LITE_SRC_TRAIN_BASE_REF_UTILS_H_
#define MINDSPORE_CCSRC_UTILS_BASE_REF_UTILS_H #define MINDSPORE_LITE_SRC_TRAIN_BASE_REF_UTILS_H_
namespace mindspore { namespace mindspore {
std::vector<std::shared_ptr<tensor::MSTensor>> TransformBaseRefToMSTensor(const BaseRef &base_ref); std::vector<std::shared_ptr<tensor::MSTensor>> TransformBaseRefToMSTensor(const BaseRef &base_ref);
std::vector<std::vector<std::shared_ptr<tensor::MSTensor>>> TransformVectorRefToMultiTensor( std::vector<std::vector<std::shared_ptr<tensor::MSTensor>>> TransformVectorRefToMultiTensor(
const VectorRef &vector_ref); const VectorRef &vector_ref);
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_UTILS_BASE_REF_UTILS_H #endif // MINDSPORE_LITE_SRC_TRAIN_BASE_REF_UTILS_H_

View File

@ -14,7 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#include "mindspore/lite/src/train/lite_kernel_runtime.h" #include "src/train/lite_kernel_runtime.h"
#include "backend/session/anf_runtime_algorithm.h"
namespace mindspore::lite { namespace mindspore::lite {
std::vector<CNodePtr> LiteInferKernelRuntime::GetGraphInputs(const std::vector<CNodePtr> &execution_order) { std::vector<CNodePtr> LiteInferKernelRuntime::GetGraphInputs(const std::vector<CNodePtr> &execution_order) {
std::vector<CNodePtr> graph_inputs; std::vector<CNodePtr> graph_inputs;
@ -34,7 +35,8 @@ std::vector<CNodePtr> LiteInferKernelRuntime::GetGraphInputs(const std::vector<C
} }
void LiteInferKernelRuntime::BindInputOutput(const session::KernelGraph *graph, void LiteInferKernelRuntime::BindInputOutput(const session::KernelGraph *graph,
const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) { const std::vector<tensor::Tensor *> &inputs,
std::vector<tensor::Tensor *> *outputs) {
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
auto execution_order = graph->execution_order(); auto execution_order = graph->execution_order();
auto graph_inputs = GetGraphInputs(execution_order); auto graph_inputs = GetGraphInputs(execution_order);
@ -56,15 +58,17 @@ void LiteInferKernelRuntime::BindInputOutput(const session::KernelGraph *graph,
auto liteKernel = dynamic_cast<kernel::LiteKernel *>(AnfAlgo::GetKernelMod(return_input)); auto liteKernel = dynamic_cast<kernel::LiteKernel *>(AnfAlgo::GetKernelMod(return_input));
auto output_tensors = liteKernel->GetOutputs(); auto output_tensors = liteKernel->GetOutputs();
for (auto output_tensor : output_tensors) { for (auto output_tensor : output_tensors) {
tensor::TensorPtr output_tensor_ptr(output_tensor); // tensor::TensorPtr output_tensor_ptr(output_tensor);
outputs->push_back(output_tensor_ptr); outputs->push_back(output_tensor);
} }
} }
} }
} }
bool LiteInferKernelRuntime::Run(session::KernelGraph *graph) { bool LiteInferKernelRuntime::Run(session::KernelGraph *graph, const std::vector<tensor::Tensor *> &inputs,
std::vector<tensor::Tensor *> *outputs) {
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
BindInputOutput(graph, inputs, *outputs);
std::vector<kernel::LiteKernel *> kernels; std::vector<kernel::LiteKernel *> kernels;
auto nodes = graph->execution_order(); auto nodes = graph->execution_order();
for (const auto &node : nodes) { for (const auto &node : nodes) {
@ -76,8 +80,7 @@ bool LiteInferKernelRuntime::Run(session::KernelGraph *graph) {
} }
kernel::LiteKernelUtil::TopologicalSortKernels(kernels); kernel::LiteKernelUtil::TopologicalSortKernels(kernels);
Executor executor; Executor executor;
auto ret = executor.Run(kernels); auto ret = executor.Run(inputs, *outputs, kernels);
return 0 == ret; return 0 == ret;
} }
} // namespace mindspore::lite } // namespace mindspore::lite

View File

@ -23,35 +23,28 @@
#include <unordered_map> #include <unordered_map>
#include "src/runtime/allocator.h" #include "src/runtime/allocator.h"
#include "src/executor.h" #include "src/executor.h"
#include "runtime/device/kernel_runtime.h" // #include "runtime/device/kernel_runtime.h"
#include "runtime/device/device_address.h" #include "runtime/device/device_address.h"
#include "src/lite_kernel.h" #include "src/lite_kernel.h"
#include "backend/session/kernel_graph.h" #include "backend/session/kernel_graph.h"
namespace mindspore::lite { namespace mindspore::lite {
class LiteInferKernelRuntime : public device::KernelRuntime { class LiteInferKernelRuntime {
public: public:
LiteInferKernelRuntime() = default; LiteInferKernelRuntime() = default;
~LiteInferKernelRuntime() override = default; ~LiteInferKernelRuntime() = default;
bool Init() override { return true; } bool Run(session::KernelGraph *graph, const std::vector<tensor::Tensor *> &inputs,
std::vector<tensor::Tensor *> *outputs);
void BindInputOutput(const session::KernelGraph *graph, const std::vector<tensor::TensorPtr> &inputs,
VectorRef *outputs);
bool Run(session::KernelGraph *graph);
void AssignKernelAddress(session::KernelGraph *graph) {} void AssignKernelAddress(session::KernelGraph *graph) {}
protected: protected:
void BindInputOutput(const session::KernelGraph *graph, const std::vector<tensor::Tensor *> &inputs,
std::vector<tensor::Tensor *> *outputs);
std::vector<CNodePtr> GetGraphInputs(const std::vector<CNodePtr> &execution_order); std::vector<CNodePtr> GetGraphInputs(const std::vector<CNodePtr> &execution_order);
bool SyncStream() override { return true; };
device::DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
TypeId type_id) override {
return nullptr;
};
}; };
} // namespace mindspore::lite } // namespace mindspore::lite
#endif // MINDSPORE_LITE_SRC_TRAIN_LITE_KERNEL_RUNTIME_H_ #endif // MINDSPORE_LITE_SRC_TRAIN_LITE_KERNEL_RUNTIME_H_

View File

@ -16,11 +16,34 @@
#include <string> #include <string>
#include "src/train/model_impl.h" #include "src/train/model_impl.h"
#include "schema/model_generated.h"
#include "ir/func_graph.h" #include "ir/func_graph.h"
#include "schema/model_generated.h"
#include "src/common/anf_importer/import_from_meta_graph.h"
namespace mindspore::lite::train { namespace mindspore::lite::train {
std::shared_ptr<ModelImpl> ModelImpl::Import(const char *model_buf, size_t size) {
MS_EXCEPTION_IF_NULL(model_buf);
flatbuffers::Verifier verify((const uint8_t *)model_buf, size);
if (!schema::VerifyMetaGraphBuffer(verify)) {
MS_LOG(ERROR) << "The buffer is invalid and fail to create graph.";
return nullptr;
}
// todo(hangangqiang): remove once primitive copy is done
auto *inner_buf = new char[size];
memcpy(inner_buf, model_buf, size);
auto meta_graph = schema::GetMetaGraph(inner_buf);
auto func_graph_model = std::make_shared<ModelImpl>(meta_graph);
auto ret = func_graph_model->BuildOps();
if (0 != ret) {
MS_LOG(ERROR) << "BuildOps failed";
return nullptr;
}
AnfImporterFromMetaGraph anfImporter(func_graph_model);
anfImporter.Import();
return func_graph_model;
}
const lite::Primitive *ModelImpl::GetOp(const std::string &name) const { const lite::Primitive *ModelImpl::GetOp(const std::string &name) const {
auto iter = ops.find(name); auto iter = ops.find(name);
if (iter == ops.end()) { if (iter == ops.end()) {
@ -98,6 +121,8 @@ lite::Primitive *ModelImpl::CopyPrimitive(const schema::Primitive *srcPrim) {
return new lite::Nchw2Nhwc(const_cast<schema::Primitive *>(srcPrim)); return new lite::Nchw2Nhwc(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_Nhwc2Nchw: case schema::PrimitiveType_Nhwc2Nchw:
return new lite::Nhwc2Nchw(const_cast<schema::Primitive *>(srcPrim)); return new lite::Nhwc2Nchw(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_MatMul:
return new lite::MatMul(const_cast<schema::Primitive *>(srcPrim));
default: default:
break; break;
} }
@ -115,5 +140,6 @@ int ModelImpl::BuildOps() {
auto srcPrim = cNode->primitive(); auto srcPrim = cNode->primitive();
this->ops[name] = CopyPrimitive(srcPrim); this->ops[name] = CopyPrimitive(srcPrim);
} }
return 0;
} }
} // namespace mindspore::lite::train } // namespace mindspore::lite::train
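
A hypothetical call site for the importer path added above (buffer loading elided; buf and size would come from reading the .ms flatbuffer into memory):

    auto model = mindspore::lite::train::ModelImpl::Import(buf, size);
    if (model == nullptr) {
      // the buffer failed flatbuffer verification, or BuildOps hit an unsupported primitive
    }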

View File

@ -15,11 +15,12 @@
*/ */
#ifndef MINDSPORE_LITE_SRC_TRAIN_MODEL_IMPL_H_ #ifndef MINDSPORE_LITE_SRC_TRAIN_MODEL_IMPL_H_
#define MINDSPORE_LITE_SRC_TRAIN_MODEL_IMPL_H #define MINDSPORE_LITE_SRC_TRAIN_MODEL_IMPL_H_
#include <string> #include <string>
#include <map> #include <map>
#include <memory> #include <memory>
#include <vector>
#include "schema/model_generated.h" #include "schema/model_generated.h"
#include "src/ops/ops.h" #include "src/ops/ops.h"
#include "ir/func_graph.h" #include "ir/func_graph.h"
@ -28,7 +29,7 @@ namespace mindspore::lite {
namespace train { namespace train {
class ModelImpl : public FuncGraph { class ModelImpl : public FuncGraph {
public: public:
static std::shared_ptr<ModelImpl> Import(const char *model_buf, size_t size); static std::shared_ptr<ModelImpl> Import(const char *model_buf, size_t size); // { return NULL; };
ModelImpl() = default; ModelImpl() = default;
explicit ModelImpl(const schema::MetaGraph *graph) : meta_graph(graph) {} explicit ModelImpl(const schema::MetaGraph *graph) : meta_graph(graph) {}
~ModelImpl() override = default; ~ModelImpl() override = default;
@ -37,16 +38,27 @@ class ModelImpl : public FuncGraph {
void FreeMetaGraph(); void FreeMetaGraph();
int BuildOps(); int BuildOps();
void AddCNodeInputOutput(std::string name, const std::vector<int> &input, const std::vector<int> &output) {
std::vector<int> *tuple = new std::vector<int>[2];
tuple[0] = input;
tuple[1] = output;
connectivity_[name] = tuple;
}
std::vector<int> *GetCNodeInputOutputIndices(std::string name) { return connectivity_[name]; }
void AddAnfNode(int id, AnfNodePtr anf_ptr) { tensors_[id] = anf_ptr; }
AnfNodePtr GetAnfNode(int id) { return tensors_[id]; }
protected: protected:
lite::Primitive *CopyPrimitive(const schema::Primitive *srcPrim); lite::Primitive *CopyPrimitive(const schema::Primitive *srcPrim);
protected: protected:
const schema::MetaGraph *meta_graph = nullptr; const schema::MetaGraph *meta_graph = nullptr;
std::map<int, AnfNodePtr> tensors_;
std::map<std::string, std::vector<int> *> connectivity_;
std::map<std::string, lite::Primitive *> ops; std::map<std::string, lite::Primitive *> ops;
}; };
} // namespace train } // namespace train
using ModelImpl = mindspore::lite::train::ModelImpl; using ModelImpl = mindspore::lite::train::ModelImpl;
} // namespace mindspore::lite } // namespace mindspore::lite
#endif // MINDSPORE_LITE_INCLUDE_MODEL_H #endif // MINDSPORE_LITE_SRC_TRAIN_MODEL_IMPL_H_
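
Design note on the new bookkeeping: AddCNodeInputOutput stores each name's indices as new std::vector<int>[2], a raw two-element array the map never frees. A pair would express the same thing without manual ownership (hedged alternative, not the commit's code):

    std::map<std::string, std::pair<std::vector<int>, std::vector<int>>> connectivity_;
    connectivity_[name] = {input, output};   // .first = input indices, .second = output indices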

View File

@ -0,0 +1,253 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <algorithm>
#include "src/train/train_anf_session.h"
#include "include/context.h"
#include "mindspore/ccsrc/runtime/device/kernel_info.h"
#include "mindspore/lite/src/train/train_session.h"
#include "mindspore/lite/src/kernel_factory.h"
#include "mindspore/lite/src/param_value_lite.h"
#include "common/utils.h"
#include "mindspore/lite/src/ops/ops.h"
#include "ir/anf.h"
#include "mindspore/lite/src/ir/tensor.h"
#include "abstract/abstract_value.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "src/ir/primitive_value.h"
#include "src/train/model_impl.h"
namespace mindspore {
namespace session {
static std::vector<int> GetAnfNodeOutDims(const AnfNodePtr &anfNodePtr) {
auto nodeAbstract = anfNodePtr->abstract();
if (nodeAbstract != nullptr) {
auto shape = nodeAbstract->GetShapeTrack();
if (!shape->isa<abstract::Shape>()) {
MS_LOG(EXCEPTION) << "Not a Shape";
return {};
}
auto dims = dyn_cast<abstract::Shape>(shape)->shape();
return dims;
} else {
MS_LOG(WARNING) << "abstract is nullptr, return empty dims";
return {};
}
}
static schema::Format GetAnfNodeFormat(const AnfNodePtr &anfNodePtr) {
auto nodeAbstract = anfNodePtr->abstract();
if (nodeAbstract != nullptr) {
return schema::Format_NHWC; // XXX TODO -- extract Format from AnfNode
} else {
MS_LOG(WARNING) << "abstract is nullptr, return schema::Format_NHWC";
return schema::Format_NHWC;
}
}
static TypeId GetAnfNodeOutTypeId(const AnfNodePtr &anfNodePtr) {
auto nodeAbstract = anfNodePtr->abstract();
if (nodeAbstract != nullptr) {
return TypeId::kNumberTypeFloat32; // XXX TODO nodeAbstract->GetTypeTrack()->generic_type_id();
} else {
MS_LOG(WARNING) << "abstract is nullptr, return kTypeUnknown";
return TypeId::kTypeUnknown;
}
}
void TrainANFSession::Init(lite::Context *context) {
MS_EXCEPTION_IF_NULL(context);
this->context_ = std::make_shared<lite::Context>(context->thread_num_, context->allocator, context->device_ctx_);
}
lite::tensor::Tensor *TrainANFSession::GetTensorForAnfNode(const AnfNodePtr anf_node) {
lite::tensor::Tensor *out_tensor = tensors_[anf_node];
if (out_tensor == NULL) {
out_tensor = new lite::tensor::Tensor(GetAnfNodeOutTypeId(anf_node),
GetAnfNodeOutDims(anf_node)); //, schema::NodeType_Parameter);
tensors_[anf_node] = out_tensor;
}
return out_tensor;
}
int TrainANFSession::BuildKernelInputAndOutputFromFuncGraph(const KernelGraphPtr &kernel_graph) {
auto return_node = kernel_graph->get_return();
auto node_list = TopoSort(return_node);
auto model_imp = std::dynamic_pointer_cast<lite::train::ModelImpl>(func_graph_);
for (auto &node : node_list) {
if (!node->isa<CNode>()) {
continue;
}
KernelRelation kernel_relation;
auto cnode = node->cast<CNodePtr>();
kernel_relation.node_full_name = cnode->fullname_with_scope();
kernel_relation.cnode = cnode;
std::vector<int> *cnode_io_indices = model_imp->GetCNodeInputOutputIndices(cnode->fullname_with_scope());
if (cnode_io_indices == NULL) {
MS_LOG(WARNING) << "No IO vectors for " << cnode->fullname_with_scope();
} else {
for (int i = 0; i < cnode_io_indices[1].size(); i++) {
AnfNodePtr anf_node = model_imp->GetAnfNode(cnode_io_indices[1].data()[i]);
kernel_relation.output_tensor.push_back(GetTensorForAnfNode(anf_node));
}
}
lite::tensor::Tensor *tensor_ptr = nullptr;
for (size_t index = 1; index < cnode->inputs().size(); ++index) {
if (cnode->input(index)->isa<CNode>()) {
auto input_cnode = cnode->input(index)->cast<CNodePtr>();
auto input_kernel_relation = kernel_relation_infos_[input_cnode->fullname_with_scope()];
// todo: multi-output kernels such as split are not supported yet
tensor_ptr = input_kernel_relation.output_tensor.front();
} else if (cnode->input(index)->isa<Parameter>()) {
auto input_parameter = cnode->input(index)->cast<ParameterPtr>();
auto para = input_parameter->default_param();
auto param_value = std::dynamic_pointer_cast<ParamValueLite>(para);
// auto dims = param_value->tensor_shape();
// tensor_ptr = new lite::tensor::Tensor(param_value->tensor_type(), dims); // schema::NodeType_ValueNode);
tensor_ptr = GetTensorForAnfNode(cnode->input(index));
if ((param_value != nullptr) && (param_value->tensor_size() != 0)) {
tensor_ptr->SetData(param_value->tensor_addr());
}
} else if (cnode->input(index)->isa<ValueNode>()) {
auto input_valuenode = cnode->input(index)->cast<ValueNodePtr>();
// tensor_ptr = new lite::tensor::Tensor(GetAnfNodeOutTypeId(input_valuenode),
// GetAnfNodeOutDims(input_valuenode)); // schema::NodeType_Parameter);
tensor_ptr = GetTensorForAnfNode(input_valuenode);
// todo(yankai)
} else {
MS_ASSERT(false);
}
kernel_relation.input_tensor.push_back(tensor_ptr);
}
kernel_relation_infos_[cnode->fullname_with_scope()] = kernel_relation;
}
return 0;
}
GraphId TrainANFSession::graph_sum_ = 0;
KernelGraphPtr TrainANFSession::NewKernelGraph() {
auto graph = std::make_shared<KernelGraph>();
graph->set_graph_id(graph_sum_);
graphs_[graph_sum_++] = graph;
return graph;
}
std::shared_ptr<KernelGraph> TrainANFSession::ConstructKernelGraph(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
auto graph = NewKernelGraph();
graph->set_return(func_graph->get_return());
auto node_list = TopoSort(func_graph->get_return());
std::vector<CNodePtr> cnode_order;
for (const auto &node : node_list) {
MS_EXCEPTION_IF_NULL(node);
if (node->isa<CNode>()) {
auto cn_node = node->cast<CNodePtr>();
cnode_order.push_back(cn_node);
}
}
graph->set_execution_order(cnode_order);
return graph;
}
GraphId TrainANFSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
  auto graph = ConstructKernelGraph(func_graph);
  func_graph_ = func_graph;
  MS_EXCEPTION_IF_NULL(graph);
  MS_LOG(INFO) << "Set kernel info";
  SetKernelInfo(graph.get());
  (void)BuildKernelInputAndOutputFromFuncGraph(graph);
  MS_LOG(INFO) << "Build kernel";
  auto ret = BuildKernel(graph.get());
  if (0 != ret) {
    MS_LOG(EXCEPTION) << "BuildKernel failed";
  }
  // return the graph id to backend
  auto graph_id = graph->graph_id();
  graphs_[graph_id] = graph;
  MS_LOG(INFO) << "Compile graph " << graph_id << " success";
  return graph_id;
}
void TrainANFSession::RunGraph(const GraphId &graph_id, const std::vector<lite::tensor::Tensor *> &inputs,
                               std::vector<lite::tensor::Tensor *> *outputs) {
  auto &kernel_graph = graphs_[graph_id];
  MS_EXCEPTION_IF_NULL(kernel_graph);
  MS_LOG(INFO) << "Bind input output address";
  // runtime_.BindInputOutput(kernel_graph.get(), inputs, outputs);  -- will be bound in Run
  // auto execution_order = kernel_graph->execution_order();
  // Todo : hangangqiang
  // Reorder(&execution_order);
  // kernel_graph->set_execution_order(execution_order);
  MS_LOG(INFO) << "Run graph start";
  auto ret = runtime_.Run(kernel_graph.get(), (std::vector<lite::tensor::Tensor *> &)inputs, *outputs);
  if (!ret) {
    MS_LOG(EXCEPTION) << "Run graph failed";
  }
  MS_LOG(INFO) << "Run graph end";
}
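// Attaches an empty device::KernelInfo to every CNode so that BuildKernel can
// later install the created kernel mod on it.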
void TrainANFSession::SetKernelInfo(const KernelGraph *kernel_graph) {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  auto &kernel_nodes = kernel_graph->execution_order();
  for (const auto &kernel_node : kernel_nodes) {
    MS_EXCEPTION_IF_NULL(kernel_node);
    auto kernel_info = std::make_shared<device::KernelInfo>();
    kernel_node->set_kernel_info(kernel_info);
  }
}
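// For every recorded KernelRelation: run InferShape on the lite primitive,
// create a CPU fp32 kernel through the KernelFactory and hang it on the
// node's KernelInfo.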
int TrainANFSession::BuildKernel(const KernelGraph *kernel_graph) {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  for (auto iter = kernel_relation_infos_.begin(); iter != kernel_relation_infos_.end(); ++iter) {
    std::string kernel_name = iter->first;
    KernelRelation anf_register = iter->second;
    MS_EXCEPTION_IF_NULL(anf_register.cnode);
    if (IsPrimitiveCNode(anf_register.cnode, prim::kPrimReturn)) {
      continue;
    }
    auto value_node_prim = anf_register.cnode->input(0);
    MS_EXCEPTION_IF_NULL(value_node_prim);
    auto prim = GetValueNode<std::shared_ptr<lite::PrimitiveValue>>(value_node_prim);
    MS_EXCEPTION_IF_NULL(prim);
    auto node_primitive = (lite::Primitive *)(prim->GetPrimitive());
    MS_EXCEPTION_IF_NULL(node_primitive);
    auto ret = node_primitive->InferShape(anf_register.input_tensor, anf_register.output_tensor);
    if (0 != ret) {
      MS_LOG(ERROR) << "InferShape failed, node : " << kernel_name;
      return ret;
    }
    kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, node_primitive->Type()};
    auto *kernel = lite::KernelFactory::GetInstance()->GetKernel(anf_register.input_tensor, anf_register.output_tensor,
                                                                 node_primitive, context_.get(), desc);
    if (nullptr == kernel) {
      MS_LOG(ERROR) << "Create kernel return nullptr, name: " << kernel_name;
      return -1;
    }
    std::shared_ptr<kernel::LiteKernel> kernel_mod(kernel);
    kernel_mod->set_name(anf_register.cnode->fullname_with_scope());
    // kernel->train();
    auto kernel_info = dynamic_cast<device::KernelInfo *>(anf_register.cnode->kernel_info());
    MS_EXCEPTION_IF_NULL(kernel_info);
    kernel_info->set_kernel_mod(kernel_mod);  // XXX TODO -- only derived class KernelInfo has this method
  }
  return 0;
}
} // namespace session
} // namespace mindspore
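A minimal driver for the session above, sketched under the assumption that a FuncGraph has already been produced by the importer; ctx, fg, inputs and outputs are hypothetical stand-ins, not part of this commit:

lite::Context ctx;                                // thread count, allocator and device context
session::TrainANFSession session(&ctx);           // the constructor copies the settings via Init()
GraphId id = session.CompileGraph(NOT_NULL(fg));  // builds the kernel graph and the CPU kernels
session.RunGraph(id, inputs, &outputs);           // executes the kernels in topological order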

View File

@@ -0,0 +1,76 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_TRAIN_TRAIN_SESSION_H_
#define MINDSPORE_LITE_SRC_TRAIN_TRAIN_SESSION_H_
#include <map>
#include <string>
#include <memory>
#include <vector>
#include <unordered_map>
#include "include/context.h"
#include "backend/session/session_basic.h"
#include "backend/session/kernel_graph.h"
#include "mindspore/lite/src/train/lite_kernel_runtime.h"
// #include "backend/session/session_factory.h"
namespace mindspore {
namespace lite::tensor {
class Tensor;
}
namespace session {
struct KernelRelation {
  std::string node_full_name;
  std::vector<lite::tensor::Tensor *> input_tensor;
  std::vector<lite::tensor::Tensor *> output_tensor;
  CNodePtr cnode;
};

class TrainANFSession {
 public:
  explicit TrainANFSession(lite::Context *context) { Init(context); }
  ~TrainANFSession() = default;
  GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph);
  void RunGraph(const GraphId &graph_id, const std::vector<lite::tensor::Tensor *> &inputs,
                std::vector<lite::tensor::Tensor *> *outputs);
  // void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override
  // {};

 protected:
  void Init(lite::Context *context);
  std::shared_ptr<lite::Context> context_ = nullptr;
  std::unordered_map<GraphId, std::shared_ptr<KernelGraph>> graphs_;
  static GraphId graph_sum_;
  KernelGraphPtr NewKernelGraph();

 private:
  // GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override;
  // GraphId CompileGraph(const char *model_buf, size_t size);
  std::shared_ptr<KernelGraph> ConstructKernelGraph(const FuncGraphPtr &func_graph);
  int BuildKernelInputAndOutputFromFuncGraph(const KernelGraphPtr &kernel_graph);
  lite::tensor::Tensor *GetTensorForAnfNode(const AnfNodePtr anf_node);
  void SetKernelInfo(const KernelGraph *kernel_graph);
  int BuildKernel(const KernelGraph *kernel_graph);
  lite::LiteInferKernelRuntime runtime_;
  std::map<std::string, KernelRelation> kernel_relation_infos_;
  FuncGraphPtr func_graph_ = nullptr;
  std::map<AnfNodePtr, lite::tensor::Tensor *> tensors_;
};
} // namespace session
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_TRAIN_TRAIN_SESSION_H_
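For orientation, KernelRelation above is a plain snapshot tying one CNode to the lite tensors that feed and leave it; the session fills one instance per node while walking the graph. A hand-built instance could look like this sketch (the node name, conv_cnode and the shapes are hypothetical):

session::KernelRelation rel;
rel.node_full_name = "Default/Conv2D-op1";  // fullname_with_scope() of the node
rel.cnode = conv_cnode;                     // a CNodePtr taken from an imported graph
rel.input_tensor.push_back(new lite::tensor::Tensor(kNumberTypeFloat32, {1, 28, 28, 3}));
rel.output_tensor.push_back(new lite::tensor::Tensor(kNumberTypeFloat32, {1, 28, 28, 8}));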

View File

@@ -15,6 +15,8 @@
  */
 #include <algorithm>
+#include "include/context.h"
+#include "mindspore/ccsrc/runtime/device/kernel_info.h"
 #include "mindspore/lite/src/train/train_session.h"
 #include "mindspore/lite/src/kernel_factory.h"
@@ -25,6 +27,7 @@
 #include "abstract/abstract_value.h"
 #include "backend/session/anf_runtime_algorithm.h"
 #include "src/ir/primitive_value.h"
+#include "src/train/model_impl.h"
 namespace mindspore {
 namespace session {
@@ -57,16 +60,32 @@ static schema::Format GetAnfNodeFormat(const AnfNodePtr &anfNodePtr) {
 static TypeId GetAnfNodeOutTypeId(const AnfNodePtr &anfNodePtr) {
   auto nodeAbstract = anfNodePtr->abstract();
   if (nodeAbstract != nullptr) {
-    return nodeAbstract->GetTypeTrack()->type_id();
+    return TypeId::kNumberTypeFloat32;  // XXX TODO nodeAbstract->GetTypeTrack()->generic_type_id();
   } else {
     MS_LOG(WARNING) << "abstract is nullptr, return kTypeUnknown";
     return TypeId::kTypeUnknown;
   }
 }
+void TrainSession::Init(lite::Context *context) {
+  MS_EXCEPTION_IF_NULL(context);
+  this->context_ = std::make_shared<lite::Context>(context->thread_num_, context->allocator, context->device_ctx_);
+}
+lite::tensor::Tensor *TrainSession::GetTensorForAnfNode(const AnfNodePtr anf_node) {
+  lite::tensor::Tensor *out_tensor = tensors_[anf_node];
+  if (out_tensor == NULL) {
+    out_tensor = new lite::tensor::Tensor(GetAnfNodeOutTypeId(anf_node),
+                                          GetAnfNodeOutDims(anf_node));  //, schema::NodeType_Parameter);
+    tensors_[anf_node] = out_tensor;
+  }
+  return out_tensor;
+}
 int TrainSession::BuildKernelInputAndOutputFromFuncGraph(const KernelGraphPtr &kernel_graph) {
   auto return_node = kernel_graph->get_return();
   auto node_list = TopoSort(return_node);
+  auto model_imp = std::dynamic_pointer_cast<lite::train::ModelImpl>(func_graph_);
   for (auto &node : node_list) {
     if (!node->isa<CNode>()) {
       continue;
@@ -75,11 +94,16 @@ int TrainSession::BuildKernelInputAndOutputFromFuncGraph(const KernelGraphPtr &k
     auto cnode = node->cast<CNodePtr>();
     kernel_relation.node_full_name = cnode->fullname_with_scope();
     kernel_relation.cnode = cnode;
-    auto *out_tensor =
-      new tensor::Tensor(GetAnfNodeOutTypeId(cnode), GetAnfNodeOutDims(cnode), GetAnfNodeFormat(cnode),
-                         schema::NodeType_Parameter);
-    kernel_relation.output_tensor.push_back(out_tensor);
-    tensor::Tensor *tensor_ptr = nullptr;
+    std::vector<int> *cnode_io_indices = model_imp->GetCNodeInputOutputIndices(cnode->fullname_with_scope());
+    if (cnode_io_indices == NULL) {
+      MS_LOG(WARNING) << "No IO vectors for " << cnode->fullname_with_scope();
+    } else {
+      for (int i = 0; i < cnode_io_indices[1].size(); i++) {
+        AnfNodePtr anf_node = model_imp->GetAnfNode(cnode_io_indices[1].data()[i]);
+        kernel_relation.output_tensor.push_back(GetTensorForAnfNode(anf_node));
+      }
+    }
+    lite::tensor::Tensor *tensor_ptr = nullptr;
     for (size_t index = 1; index < cnode->inputs().size(); ++index) {
       if (cnode->input(index)->isa<CNode>()) {
         auto input_cnode = cnode->input(index)->cast<CNodePtr>();
@@ -90,17 +114,17 @@ int TrainSession::BuildKernelInputAndOutputFromFuncGraph(const KernelGraphPtr &k
         auto input_parameter = cnode->input(index)->cast<ParameterPtr>();
         auto para = input_parameter->default_param();
         auto param_value = std::dynamic_pointer_cast<ParamValueLite>(para);
-        auto dims = param_value->tensor_shape();
-        tensor_ptr = new tensor::Tensor(param_value->tensor_type(), dims, schema::Format_NHWC,
-                                        schema::NodeType_ValueNode);  // XXX TODO -- extract Format from AnfNode
-        if (param_value->tensor_size() != 0) {
+        // auto dims = param_value->tensor_shape();
+        // tensor_ptr = new lite::tensor::Tensor(param_value->tensor_type(), dims);  // schema::NodeType_ValueNode);
+        tensor_ptr = GetTensorForAnfNode(cnode->input(index));
+        if ((param_value != nullptr) && (param_value->tensor_size() != 0)) {
           tensor_ptr->SetData(param_value->tensor_addr());
         }
       } else if (cnode->input(index)->isa<ValueNode>()) {
         auto input_valuenode = cnode->input(index)->cast<ValueNodePtr>();
-        tensor_ptr = new tensor::Tensor(GetAnfNodeOutTypeId(input_valuenode), GetAnfNodeOutDims(input_valuenode),
-                                        schema::Format_NHWC,
-                                        schema::NodeType_Parameter);  // XXX TODO -- extract Format from AnfNode
+        // tensor_ptr = new lite::tensor::Tensor(GetAnfNodeOutTypeId(input_valuenode),
+        //                                       GetAnfNodeOutDims(input_valuenode));  // schema::NodeType_Parameter);
+        tensor_ptr = GetTensorForAnfNode(input_valuenode);
         // todo(yankai)
       } else {
         MS_ASSERT(false);
@@ -111,7 +135,7 @@ int TrainSession::BuildKernelInputAndOutputFromFuncGraph(const KernelGraphPtr &k
   }
   return 0;
 }
+#if 0
 GraphId TrainSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
   auto graph_id = graph_sum_;
   auto graph = SessionBasic::ConstructKernelGraph(lst, outputs);
@@ -124,6 +148,17 @@ GraphId TrainSession::CompileGraph(const AnfNodePtrLi
 }
 GraphId TrainSession::CompileGraph(const char *model_buf, size_t size) { return 0; }
+#else
+GraphId TrainSession::graph_sum_ = 0;
+KernelGraphPtr TrainSession::NewKernelGraph() {
+  auto graph = std::make_shared<KernelGraph>();
+  graph->set_graph_id(graph_sum_);
+  graphs_[graph_sum_++] = graph;
+  return graph;
+}
+#endif
 std::shared_ptr<KernelGraph> TrainSession::ConstructKernelGraph(const FuncGraphPtr &func_graph) {
   MS_EXCEPTION_IF_NULL(func_graph);
@@ -141,14 +176,14 @@ std::shared_ptr<KernelGraph> TrainSession::ConstructKernelGraphP
   graph->set_execution_order(cnode_order);
   return graph;
 }
 GraphId TrainSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
   auto graph = ConstructKernelGraph(func_graph);
+  func_graph_ = func_graph;
   MS_EXCEPTION_IF_NULL(graph);
   MS_LOG(INFO) << "Set kernel info";
   SetKernelInfo(graph.get());
-  (void) BuildKernelInputAndOutputFromFuncGraph(graph);
+  (void)BuildKernelInputAndOutputFromFuncGraph(graph);
   MS_LOG(INFO) << "Build kernel";
   auto ret = BuildKernel(graph.get());
   if (0 != ret) {
@@ -162,18 +197,18 @@ GraphId TrainSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
   return graph_id;
 }
-void TrainSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::Tensor *> &inputs,
-                            std::vector<tensor::Tensor *> &outputs) {
+void TrainSession::RunGraph(const GraphId &graph_id, const std::vector<lite::tensor::Tensor *> &inputs,
+                            std::vector<lite::tensor::Tensor *> *outputs) {
   auto &kernel_graph = graphs_[graph_id];
   MS_EXCEPTION_IF_NULL(kernel_graph);
   MS_LOG(INFO) << "Bind input output address";
-  runtime_.BindInputOutput(kernel_graph.get(), inputs, outputs);
+  // runtime_.BindInputOutput(kernel_graph.get(), inputs, outputs);  -- will be bound in Run
   // auto execution_order = kernel_graph->execution_order();
   // Todo : hangangqiang
   // Reorder(&execution_order);
   // kernel_graph->set_execution_order(execution_order);
   MS_LOG(INFO) << "Run graph start";
-  auto ret = runtime_.Run(kernel_graph.get(), (std::vector<tensor::Tensor *> &) inputs, outputs);
+  auto ret = runtime_.Run(kernel_graph.get(), (std::vector<lite::tensor::Tensor *> &)inputs, outputs);
   if (!ret) {
     MS_LOG(EXCEPTION) << "Run graph failed";
   }
@@ -199,34 +234,34 @@ int TrainSession::BuildKernel(const KernelGraph *kernel_graph) {
     if (IsPrimitiveCNode(anf_register.cnode, prim::kPrimReturn)) {
       continue;
     }
-    lite::Context context;
-    context.deviceCtx.type = lite::DeviceType::DT_CPU;
     auto value_node_prim = anf_register.cnode->input(0);
     MS_EXCEPTION_IF_NULL(value_node_prim);
     auto prim = GetValueNode<std::shared_ptr<lite::PrimitiveValue>>(value_node_prim);
     MS_EXCEPTION_IF_NULL(prim);
-    auto node_primitive = (lite::Primitive *) (prim->GetPrimitive());
+    auto node_primitive = (lite::Primitive *)(prim->GetPrimitive());
     MS_EXCEPTION_IF_NULL(node_primitive);
     auto ret = node_primitive->InferShape(anf_register.input_tensor, anf_register.output_tensor);
     if (0 != ret) {
       MS_LOG(ERROR) << "InferShape failed, node : " << kernel_name;
       return ret;
     }
-    kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, node_primitive->Type()};
+    kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, node_primitive->Type()};
     auto *kernel = lite::KernelFactory::GetInstance()->GetKernel(anf_register.input_tensor, anf_register.output_tensor,
-                                                                 node_primitive, &context, desc);
+                                                                 node_primitive, context_.get(), desc);
     if (nullptr == kernel) {
       MS_LOG(ERROR) << "Create kernel return nullptr, name: " << kernel_name;
       return -1;
     }
-    kernel->train();
-    auto *kernel_info = anf_register.cnode->kernel_info();
     std::shared_ptr<kernel::LiteKernel> kernel_mod(kernel);
-    kernel_info->set_kernel_mod(kernel_mod);
+    kernel_mod->set_name(anf_register.cnode->fullname_with_scope());
+    // kernel->train();
+    auto kernel_info = dynamic_cast<device::KernelInfo *>(anf_register.cnode->kernel_info());
+    MS_EXCEPTION_IF_NULL(kernel_info);
+    kernel_info->set_kernel_mod(kernel_mod);  // XXX TODO -- only derived class KernelInfo has this method
   }
   return 0;
 }
 }  // namespace session
 }  // namespace mindspore
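The substantive change in this file is tensor ownership: instead of new-ing a fresh tensor at each use site, GetTensorForAnfNode memoizes exactly one lite tensor per AnfNode in the tensors_ map, so a node consumed by several kernels resolves to the same buffer. Reduced to its core, the pattern is the usual lazy cache (a generic sketch, not the committed code):

std::map<AnfNodePtr, lite::tensor::Tensor *> cache;

lite::tensor::Tensor *Lookup(const AnfNodePtr &node) {
  auto &slot = cache[node];  // operator[] default-constructs a nullptr entry on first access
  if (slot == nullptr) {
    slot = new lite::tensor::Tensor(GetAnfNodeOutTypeId(node), GetAnfNodeOutDims(node));
  }
  return slot;
}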

View File

@@ -19,10 +19,12 @@
 #include <string>
 #include <memory>
 #include <vector>
+#include <unordered_map>
+#include "include/context.h"
 #include "backend/session/session_basic.h"
 #include "backend/session/kernel_graph.h"
 #include "mindspore/lite/src/train/lite_kernel_runtime.h"
-#include "backend/session/session_factory.h"
+// #include "backend/session/session_factory.h"
 namespace mindspore {
 namespace lite::tensor {
 class Tensor;
@@ -30,36 +32,45 @@ class Tensor;
 namespace session {
 struct KernelRelation {
   std::string node_full_name;
-  std::vector<tensor::Tensor *> input_tensor;
-  std::vector<tensor::Tensor *> output_tensor;
+  std::vector<lite::tensor::Tensor *> input_tensor;
+  std::vector<lite::tensor::Tensor *> output_tensor;
   CNodePtr cnode;
 };
-class TrainSession : public SessionBasic {
+class TrainSession {
  public:
-  TrainSession() : SessionBasic() {}
-  ~TrainSession() override = default;
-  void Init(uint32_t device_id) override {
-    SessionBasic::Init(device_id);
-    context_ = std::make_shared<Context>(kCPUDevice, device_id);
-  }
-  GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph) override;
-  void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override;
+  explicit TrainSession(lite::Context *context) { Init(context); }
+  ~TrainSession() = default;
+  GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph);
+  void RunGraph(const GraphId &graph_id, const std::vector<lite::tensor::Tensor *> &inputs,
+                std::vector<lite::tensor::Tensor *> *outputs);
+  // void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override
+  // {};
+ protected:
+  void Init(lite::Context *context);
+  std::shared_ptr<lite::Context> context_ = nullptr;
+  std::unordered_map<GraphId, std::shared_ptr<KernelGraph>> graphs_;
+  static GraphId graph_sum_;
+  KernelGraphPtr NewKernelGraph();
  private:
-  GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override;
-  GraphId CompileGraph(const char *model_buf, size_t size);
+  // GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override;
+  // GraphId CompileGraph(const char *model_buf, size_t size);
   std::shared_ptr<KernelGraph> ConstructKernelGraph(const FuncGraphPtr &func_graph);
   int BuildKernelInputAndOutputFromFuncGraph(const KernelGraphPtr &kernel_graph);
+  lite::tensor::Tensor *GetTensorForAnfNode(const AnfNodePtr anf_node);
   void SetKernelInfo(const KernelGraph *kernel_graph);
   int BuildKernel(const KernelGraph *kernel_graph);
   lite::LiteInferKernelRuntime runtime_;
   std::map<std::string, KernelRelation> kernel_relation_infos_;
+  FuncGraphPtr func_graph_ = NULL;
+  std::map<AnfNodePtr, lite::tensor::Tensor *> tensors_;
 };
-MS_REG_SESSION(kCPUDevice, TrainSession);
 }  // namespace session
 }  // namespace mindspore
 #endif  // MINDSPORE_LITE_SRC_TRAIN_TRAIN_SESSION_H_
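The header side of the change drops the SessionBasic base class and the MS_REG_SESSION registration: the session is now a plain object built directly from a caller-owned lite::Context, with Init() keeping a private copy of the settings. A hedged sketch of the new construction path:

lite::Context ctx;                    // caller-owned; carries thread_num_, allocator, device_ctx_
session::TrainSession session(&ctx);  // the constructor forwards to Init(), which clones those
                                      // fields into the session's own shared_ptr<lite::Context>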

View File

@@ -87,6 +87,16 @@ file(GLOB KERNEL_OP_SRC
     ${LITE_DIR}/src/runtime/kernel/arm/nnacl/int8/*.cc
     ${LITE_DIR}/src/runtime/kernel/arm/nnacl/quantization/*.cc
     )
+file(GLOB KERNEL_OP_TRAIN_SRC
+    ${LITE_DIR}/src/runtime/kernel/arm/nnacl/fp32_grad/*.cc
+    ${LITE_DIR}/src/runtime/kernel/arm/fp32_grad/*.cc
+    )
+if (SUPPORT_TRAIN)
+    list(APPEND KERNEL_OP_SRC ${KERNEL_OP_TRAIN_SRC})
+endif()
 if (PLATFORM_ARM64)
     # assembly
     file(GLOB TEST_ASSEMBLY_SRC ${LITE_DIR}/src/runtime/kernel/arm/nnacl/assembly/arm64/*.s
@@ -245,12 +255,13 @@ if (SUPPORT_TRAIN)
         # ${SRC_DIR}/device/kernel_info.cc
         # ${SRC_DIR}/device/kernel_runtime.cc
         # ${SRC_DIR}/device/lite/kernel_runtime_extends.cc
-        ${LITE_DIR}/src/common/anf_importer/anf_importer.cc
-        ${LITE_DIR}/src/common/anf_importer/import_from_meta_graph.cc
-        ${LITE_DIR}/src/ir/primitive_value.cc
-        ${LITE_DIR}/src/train/lite_kernel_runtime.cc
-        ${LITE_DIR}/src/train/train_session.cc
-        ${LITE_DIR}/src/train/model_impl.cc
+        # ${LITE_DIR}/src/common/anf_importer/anf_importer.cc
+        # ${LITE_DIR}/src/common/anf_importer/import_from_meta_graph.cc
+        # ${LITE_DIR}/src/ir/primitive_value.cc
+        # ${LITE_DIR}/src/train/lite_kernel_runtime.cc
+        # ${LITE_DIR}/src/train/train_session.cc
+        # ${LITE_DIR}/src/train/model_impl.cc
+        ${LITE_DIR}/src/lite_session.cc  # temporary
         )
 else()
     set(TEST_LITE_SRC
@@ -265,6 +276,10 @@ file(GLOB_RECURSE TEST_CASE_KERNEL_SRC
     ${TEST_DIR}/ut/src/runtime/kernel/arm/int8/*.cc
     )
+file(GLOB_RECURSE TEST_CASE_KERNEL_TRAIN_SRC
+    ${TEST_DIR}/ut/src/runtime/kernel/arm/fp32_grad/*.cc
+    )
 set(TEST_SRC
     ${TEST_LITE_SRC}
     ${TEST_MINDDATA_SRC}
@@ -278,7 +293,9 @@ set(TEST_SRC
 if (SUPPORT_TRAIN)
     set(TEST_SRC
         ${TEST_SRC}
-        ${TEST_DIR}/ut/src/train_test.cc
+        ${TEST_CASE_KERNEL_TRAIN_SRC}
+        # ${TEST_DIR}/ut/src/train_test.cc
+        ${TEST_DIR}/ut/src/infer_test.cc  # temporary
         )
 else()
     set(TEST_SRC

View File

@@ -13,7 +13,6 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-
 #include <iostream>
 #include <memory>
 #include <vector>
@@ -25,7 +24,7 @@
 #include "mindspore/lite/src/kernel_registry.h"
 #include "mindspore/lite/src/ir/tensor.h"
 #include "mindspore/lite/src/lite_kernel.h"
-#include "mindspore/lite/src/runtime/kernel/arm/fp32/activation_grad.h"
+#include "mindspore/lite/src/runtime/kernel/arm/fp32_grad/activation_grad.h"
 namespace mindspore {
 class TestActGradFp32 : public mindspore::Common {

View File

@@ -21,7 +21,7 @@
 #include "src/common/file_utils.h"
 #include "src/common/file_utils_ext.h"
 #include "mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/reduce.h"
-#include "mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_grad.h"
+#include "mindspore/lite/src/runtime/kernel/arm/fp32_grad/arithmetic_grad.h"
 #include "mindspore/lite/src/kernel_registry.h"
 namespace mindspore {

View File

@@ -18,7 +18,7 @@
 #include "utils/log_adapter.h"
 #include "common/common_test.h"
 #include "src/common/file_utils.h"
-#include "mindspore/lite/src/runtime/kernel/arm/fp32/bias_grad.h"
+#include "mindspore/lite/src/runtime/kernel/arm/fp32_grad/bias_grad.h"
 #include "mindspore/lite/src/kernel_registry.h"
 namespace mindspore {

View File

@@ -21,8 +21,8 @@
 #include "common/common_test.h"
 #include "src/common/file_utils.h"
 #include "src/common/file_utils_ext.h"
-#include "mindspore/lite/src/runtime/kernel/arm/fp32/convolution_grad_filter.h"
-#include "mindspore/lite/src/runtime/kernel/arm/fp32/convolution_grad_input.h"
+#include "mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_filter.h"
+#include "mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_input.h"
 #include "mindspore/lite/src/runtime/kernel/arm/nnacl/conv_parameter.h"
 #include "mindspore/lite/src/kernel_registry.h"

View File

@@ -22,8 +22,8 @@
 #include "mindspore/lite/src/kernel_registry.h"
 #include "src/common/utils.h"
 #include "src/common/file_utils.h"
-#include "src/runtime/kernel/arm/fp32/pooling_grad.h"
-#include "src/runtime/kernel/arm/nnacl/fp32/pooling_grad.h"
+#include "src/runtime/kernel/arm/fp32_grad/pooling_grad.h"
+#include "src/runtime/kernel/arm/nnacl/fp32_grad/pooling_grad.h"
 namespace mindspore {
 class TestPoolingGradFp32 : public mindspore::Common {

View File

@@ -0,0 +1,92 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include <memory>
#include "utils/log_adapter.h"
#include "common/common_test.h"
#include "src/common/file_utils.h"
#include "src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.h"
#include "src/kernel_registry.h"
namespace mindspore {
class TestSoftmaxCrossEntropyFp32 : public mindspore::Common {
 public:
  TestSoftmaxCrossEntropyFp32() {}
};
TEST_F(TestSoftmaxCrossEntropyFp32, SoftmaxCrossEntropyFp32) {
  // prepare stage: load the logits y (6 samples x 4 classes) from a golden file
  SoftmaxCrossEntropyParameter *sce_param = new SoftmaxCrossEntropyParameter();
  size_t input_size;
  std::string input_path = "./test_data/operators/sce_fp32_1_y_6_4.bin";
  auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size));
  std::vector<int> dim_y({6, 4});
  lite::tensor::Tensor y_tensor(TypeId::kNumberTypeFloat32, dim_y);
  y_tensor.SetData(input_data);

  // load the int64 class labels and narrow them to int32 for the lite tensor
  std::string label_path = "./test_data/operators/sce_fp32_1_l_6.bin";
  auto ll_labels = reinterpret_cast<int64 *>(mindspore::lite::ReadFile(label_path.c_str(), &input_size));
  auto labels = new int[6];
  for (int i = 0; i < 6; i++) labels[i] = static_cast<int>(ll_labels[i]);
  std::vector<int> dim_l({6});
  lite::tensor::Tensor l_tensor(TypeId::kNumberTypeInt32, dim_l);
  l_tensor.SetData(labels);

  std::vector<lite::tensor::Tensor *> inputs = {&y_tensor, &l_tensor};

  // outputs: one gradient tensor shaped like y, one scalar loss tensor
  auto loss = new float[1];
  std::vector<int> dim_dw({1});
  lite::tensor::Tensor loss_tensor(TypeId::kNumberTypeFloat32, dim_dw);
  loss_tensor.SetData(loss);
  auto grad = new float[24];
  lite::tensor::Tensor grad_tensor(TypeId::kNumberTypeFloat32, dim_y);
  grad_tensor.SetData(grad);
  std::vector<lite::tensor::Tensor *> outputs = {&grad_tensor, &loss_tensor};

  kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_SoftmaxCrossEntropy};
  auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
  auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(sce_param), nullptr, desc, nullptr);
  kernel_obj->Run();

  printf("==================total loss=================\n");
  std::cout << loss[0] << " ," << std::endl;

  printf("==================Testing Grad===============\n");
  std::string output_path = "./test_data/operators/sce_fp32_1_loss_1.bin";
  lite::CompareOutput(loss, output_path);

  // switch the kernel to train mode and run again to produce the gradients
  ((mindspore::kernel::SparseSoftmaxCrossEntropyWithLogitsCPUKernel *)kernel_obj)->train();
  kernel_obj->Run();

  printf("==================output data=================\n");
  for (int i = 0; i < 12; i++) {
    std::cout << grad[i] << " ,";
  }
  std::cout << std::endl;
  std::string grad_path = "./test_data/operators/sce_fp32_1_dy_6_4.bin";
  lite::CompareOutput(grad, grad_path);

  delete sce_param;
  l_tensor.SetData(nullptr);
  y_tensor.SetData(nullptr);
  MS_LOG(INFO) << "SoftmaxCrossEntropyFp32 passed";
}
} // namespace mindspore
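Worth noting is the two-phase pattern the test exercises: the same kernel object first runs in eval mode, producing the scalar loss, then is switched with train() and run again to produce the gradients (output ordering as built in the test above):

kernel_obj->Run();   // eval phase: fills loss_tensor (outputs[1]) with the scalar loss
((mindspore::kernel::SparseSoftmaxCrossEntropyWithLogitsCPUKernel *)kernel_obj)->train();
kernel_obj->Run();   // train phase: fills grad_tensor (outputs[0]) with dL/dlogits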