forked from mindspore-Ecosystem/mindspore
move grad ops to seperate folders
This commit is contained in:
parent
b7ebe2be4b
commit
43d7d3af55
7
build.sh
7
build.sh
|
@ -109,7 +109,7 @@ checkopts()
|
||||||
ENABLE_GPU="off"
|
ENABLE_GPU="off"
|
||||||
|
|
||||||
# Process the options
|
# Process the options
|
||||||
while getopts 'drvj:c:t:hsb:a:g:p:ie:m:l:I:LRP:Q:D:zM:V:K:swB:En' opt
|
while getopts 'drvj:c:t:hsb:a:g:p:ie:m:l:I:LRP:Q:D:zM:V:K:swB:EnT:' opt
|
||||||
do
|
do
|
||||||
OPTARG=$(echo ${OPTARG} | tr '[A-Z]' '[a-z]')
|
OPTARG=$(echo ${OPTARG} | tr '[A-Z]' '[a-z]')
|
||||||
case "${opt}" in
|
case "${opt}" in
|
||||||
|
@ -282,6 +282,11 @@ checkopts()
|
||||||
ENABLE_IBVERBS="on"
|
ENABLE_IBVERBS="on"
|
||||||
echo "enable IBVERBS for parameter server"
|
echo "enable IBVERBS for parameter server"
|
||||||
;;
|
;;
|
||||||
|
T)
|
||||||
|
check_on_off $OPTARG T
|
||||||
|
SUPPORT_TRAIN=$OPTARG
|
||||||
|
echo "support train on device "
|
||||||
|
;;
|
||||||
*)
|
*)
|
||||||
echo "Unknown option ${opt}!"
|
echo "Unknown option ${opt}!"
|
||||||
usage
|
usage
|
||||||
|
|
|
@ -23,6 +23,7 @@ endif()
|
||||||
|
|
||||||
if (SUPPORT_TRAIN)
|
if (SUPPORT_TRAIN)
|
||||||
set(ANF_SRC
|
set(ANF_SRC
|
||||||
|
${ANF_SRC}
|
||||||
# ${CCSRC_DIR}/common/trans.cc
|
# ${CCSRC_DIR}/common/trans.cc
|
||||||
# ${CCSRC_DIR}/utils/lite/base_ref_utils.cc
|
# ${CCSRC_DIR}/utils/lite/base_ref_utils.cc
|
||||||
# ${CCSRC_DIR}/runtime/kernel/kernel_compiler/kernel_build_info.cc
|
# ${CCSRC_DIR}/runtime/kernel/kernel_compiler/kernel_build_info.cc
|
||||||
|
@ -40,14 +41,17 @@ if (SUPPORT_TRAIN)
|
||||||
set(LITE_SRC
|
set(LITE_SRC
|
||||||
${LITE_SRC}
|
${LITE_SRC}
|
||||||
${ANF_SRC}
|
${ANF_SRC}
|
||||||
${PASS_SRC}
|
# ${PASS_SRC}
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/common/anf_importer/anf_importer.cc
|
# ${CMAKE_CURRENT_SOURCE_DIR}/common/anf_importer/anf_importer.cc
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/common/anf_importer/import_from_meta_graph.cc
|
# ${CMAKE_CURRENT_SOURCE_DIR}/common/anf_importer/import_from_meta_graph.cc
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/ir/primitive_value.cc
|
# ${CMAKE_CURRENT_SOURCE_DIR}/ir/primitive_value.cc
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/train/lite_kernel_runtime.cc
|
# ${CMAKE_CURRENT_SOURCE_DIR}/train/lite_kernel_runtime.cc
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/train/train_session.cc
|
# ${CMAKE_CURRENT_SOURCE_DIR}/train/train_session.cc
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/train/model_impl.cc
|
# ${CMAKE_CURRENT_SOURCE_DIR}/train/model_impl.cc
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/lite_session.cc # temporary
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/model_impl.cc # temporary
|
||||||
)
|
)
|
||||||
|
|
||||||
else ()
|
else ()
|
||||||
set(LITE_SRC
|
set(LITE_SRC
|
||||||
${LITE_SRC}
|
${LITE_SRC}
|
||||||
|
|
|
@ -27,26 +27,30 @@
|
||||||
|
|
||||||
namespace mindspore::lite {
|
namespace mindspore::lite {
|
||||||
void AnfImporterFromMetaGraph::ConverterConstTensor() {
|
void AnfImporterFromMetaGraph::ConverterConstTensor() {
|
||||||
MS_EXCEPTION_IF_NULL(model);
|
MS_EXCEPTION_IF_NULL(model_);
|
||||||
auto *meta_graph = model->GetMetaGraph();
|
auto *meta_graph = model_->GetMetaGraph();
|
||||||
MS_EXCEPTION_IF_NULL(meta_graph);
|
MS_EXCEPTION_IF_NULL(meta_graph);
|
||||||
for (size_t i = 0; i < meta_graph->allTensors()->size(); i++) {
|
num_of_tensors_ = meta_graph->allTensors()->size();
|
||||||
|
for (size_t i = 0; i < num_of_tensors_; i++) {
|
||||||
auto *tensor = meta_graph->allTensors()->GetAs<schema::Tensor>(i);
|
auto *tensor = meta_graph->allTensors()->GetAs<schema::Tensor>(i);
|
||||||
MS_EXCEPTION_IF_NULL(tensor);
|
MS_EXCEPTION_IF_NULL(tensor);
|
||||||
if (tensor->nodeType() != schema::NodeType_ValueNode) {
|
if ((tensor->nodeType() != schema::NodeType_ValueNode) && (tensor->nodeType() != schema::NodeType_Parameter)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
MS_ASSERT(tensor->dims() != nullptr);
|
MS_ASSERT(tensor->dims() != nullptr);
|
||||||
auto parameter = model->add_parameter();
|
auto parameter = model_->add_parameter();
|
||||||
std::vector<int> shape;
|
std::vector<int> shape;
|
||||||
for (size_t j = 0; j < tensor->dims()->size(); ++j) {
|
for (size_t j = 0; j < tensor->dims()->size(); ++j) {
|
||||||
shape.push_back(tensor->dims()->data()[j]);
|
shape.push_back(tensor->dims()->data()[j]);
|
||||||
}
|
}
|
||||||
auto type_id = static_cast<TypeId>(tensor->dataType());
|
auto type_id = static_cast<TypeId>(tensor->dataType()); // todo: check error
|
||||||
auto type_ptr = TypeIdToType(type_id);
|
auto type_ptr = TypeIdToType(type_id);
|
||||||
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape);
|
auto abstractBase = std::make_shared<abstract::AbstractTensor>(type_ptr, shape);
|
||||||
parameter->set_abstract(abstract_tensor);
|
// XXX TODO copy format
|
||||||
|
parameter->set_abstract(abstractBase);
|
||||||
|
parameter->set_name(std::string("Parameter"));
|
||||||
|
|
||||||
|
if (tensor->nodeType() == schema::NodeType_ValueNode) {
|
||||||
ParamValueLitePtr param_value = std::make_shared<ParamValueLite>();
|
ParamValueLitePtr param_value = std::make_shared<ParamValueLite>();
|
||||||
MS_EXCEPTION_IF_NULL(param_value);
|
MS_EXCEPTION_IF_NULL(param_value);
|
||||||
param_value->set_tensor_shape(shape);
|
param_value->set_tensor_shape(shape);
|
||||||
|
@ -60,63 +64,106 @@ void AnfImporterFromMetaGraph::ConverterConstTensor() {
|
||||||
param_value->set_tensor_size(size);
|
param_value->set_tensor_size(size);
|
||||||
}
|
}
|
||||||
parameter->set_default_param(param_value);
|
parameter->set_default_param(param_value);
|
||||||
|
}
|
||||||
AddNode(i, parameter);
|
AddNode(i, parameter);
|
||||||
|
model_->AddAnfNode(i, parameter);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int AnfImporterFromMetaGraph::ConverterCNode() {
|
int AnfImporterFromMetaGraph::ConverterCNode() {
|
||||||
MS_EXCEPTION_IF_NULL(model);
|
MS_EXCEPTION_IF_NULL(model_);
|
||||||
auto *meta_graph = model->GetMetaGraph();
|
auto *meta_graph = model_->GetMetaGraph();
|
||||||
MS_EXCEPTION_IF_NULL(meta_graph);
|
MS_EXCEPTION_IF_NULL(meta_graph);
|
||||||
auto cNodes = meta_graph->nodes();
|
|
||||||
for (size_t i = 0; i < cNodes->size(); i++) {
|
|
||||||
auto cNode = cNodes->GetAs<schema::CNode>(i);
|
|
||||||
MS_EXCEPTION_IF_NULL(cNode);
|
|
||||||
auto tensor_id = cNode->outputIndex()->data()[0];
|
|
||||||
if (GetNode(tensor_id)) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto prim = std::make_shared<PrimitiveValue>(model->GetOp(cNode->name()->str()));
|
// Crate CNode -- Order of inputs is as follows
|
||||||
|
// First input should be the Primitive
|
||||||
|
// Then we have CNodes that contribute to this CNode
|
||||||
|
// Finally we Have the parameters
|
||||||
|
|
||||||
|
// first itteration -- create CNode with primitive, create originator map
|
||||||
|
for (size_t i = 0; i < meta_graph->nodes()->size(); i++) {
|
||||||
|
auto cNode = meta_graph->nodes()->GetAs<schema::CNode>(i);
|
||||||
|
MS_EXCEPTION_IF_NULL(cNode);
|
||||||
|
auto prim = std::make_shared<PrimitiveValue>(model_->GetOp(cNode->name()->str()));
|
||||||
if (prim == nullptr) {
|
if (prim == nullptr) {
|
||||||
MS_LOG(ERROR) << "th tensorDef in subGraphDef is nullptr";
|
MS_LOG(ERROR) << "th tensorDef in subGraphDef is nullptr";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
auto value_node = NewValueNode(prim);
|
auto value_node = NewValueNode(prim);
|
||||||
AddNode(tensor_id, value_node);
|
// auto prim_name = std::string("PrimitivePy: ") + std::string(cNode->name()->c_str());
|
||||||
|
// value_node->set_fullname_with_scope(prim_name);
|
||||||
std::vector<AnfNodePtr> op_inputs = {value_node};
|
std::vector<AnfNodePtr> op_inputs = {value_node};
|
||||||
|
|
||||||
|
auto cnode = model_->NewCNode(op_inputs);
|
||||||
|
auto node_name = std::string(cNode->name()->c_str()) + std::to_string(i);
|
||||||
|
cnode->set_fullname_with_scope(node_name);
|
||||||
|
AddNode(num_of_tensors_ + i, cnode);
|
||||||
|
|
||||||
|
for (size_t j = 0; j < cNode->outputIndex()->size(); j++) {
|
||||||
|
int tensor_id = cNode->outputIndex()->data()[j];
|
||||||
|
originator_[tensor_id] = cnode;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// second itteration -- fill in input CNodes and Parameters
|
||||||
|
// populate map
|
||||||
|
for (size_t i = 0; i < meta_graph->nodes()->size(); i++) {
|
||||||
|
std::vector<int> input;
|
||||||
|
std::vector<int> output;
|
||||||
|
int tensor_id;
|
||||||
|
auto cNode = meta_graph->nodes()->GetAs<schema::CNode>(i);
|
||||||
|
MS_EXCEPTION_IF_NULL(cNode);
|
||||||
|
auto cnode = std::dynamic_pointer_cast<CNode>(GetNode(num_of_tensors_ + i));
|
||||||
|
|
||||||
|
for (size_t j = 0; j < cNode->outputIndex()->size(); j++) {
|
||||||
|
tensor_id = cNode->outputIndex()->data()[j];
|
||||||
|
output.push_back(tensor_id);
|
||||||
|
}
|
||||||
|
|
||||||
MS_EXCEPTION_IF_NULL(cNode->inputIndex());
|
MS_EXCEPTION_IF_NULL(cNode->inputIndex());
|
||||||
for (size_t j = 0; j < cNode->inputIndex()->size(); j++) {
|
for (size_t j = 0; j < cNode->inputIndex()->size(); j++) {
|
||||||
auto node = GetNode(*(cNode->inputIndex()->GetAs<uint32_t>(j)));
|
tensor_id = cNode->inputIndex()->data()[j];
|
||||||
if (nullptr == node) {
|
input.push_back(tensor_id);
|
||||||
MS_LOG(ERROR) << "Can't find input node.";
|
auto *tensor = meta_graph->allTensors()->GetAs<schema::Tensor>(tensor_id);
|
||||||
return RET_ERROR;
|
MS_EXCEPTION_IF_NULL(tensor);
|
||||||
|
if ((tensor->nodeType() == schema::NodeType_Parameter) && (originator_[tensor_id] != nullptr)) {
|
||||||
|
cnode->add_input(originator_[tensor_id]);
|
||||||
}
|
}
|
||||||
// todo: CheckInputNodeType, the first node should be op;
|
|
||||||
op_inputs.push_back(node);
|
|
||||||
}
|
}
|
||||||
auto cnode = model->NewCNode(op_inputs);
|
// finally add all the Parameters (which are ValueNodes)
|
||||||
auto node_name = std::string(cNode->name()->c_str());
|
for (size_t j = 0; j < cNode->inputIndex()->size(); j++) {
|
||||||
cnode->set_fullname_with_scope(node_name);
|
tensor_id = cNode->inputIndex()->data()[j];
|
||||||
AddNode(tensor_id, cnode);
|
auto *tensor = meta_graph->allTensors()->GetAs<schema::Tensor>(tensor_id);
|
||||||
|
MS_EXCEPTION_IF_NULL(tensor);
|
||||||
|
if ((tensor->nodeType() == schema::NodeType_ValueNode) && (GetNode(tensor_id) != nullptr)) {
|
||||||
|
cnode->add_input(GetNode(tensor_id));
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
model_->AddCNodeInputOutput(cnode->fullname_with_scope(), input, output);
|
||||||
|
}
|
||||||
|
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
void AnfImporterFromMetaGraph::AddReturnCNode() {
|
void AnfImporterFromMetaGraph::AddReturnCNode() {
|
||||||
MS_EXCEPTION_IF_NULL(model);
|
MS_EXCEPTION_IF_NULL(model_);
|
||||||
auto *meta_graph = model->GetMetaGraph();
|
auto *meta_graph = model_->GetMetaGraph();
|
||||||
MS_EXCEPTION_IF_NULL(meta_graph);
|
MS_EXCEPTION_IF_NULL(meta_graph);
|
||||||
|
std::vector<int> input;
|
||||||
|
std::vector<int> output;
|
||||||
std::vector<AnfNodePtr> op_inputs;
|
std::vector<AnfNodePtr> op_inputs;
|
||||||
auto value_node = NewValueNode(prim::kPrimReturn);
|
auto value_node = NewValueNode(prim::kPrimReturn);
|
||||||
|
// value_node->set_fullname_with_scope("Primitive");
|
||||||
op_inputs.push_back(value_node);
|
op_inputs.push_back(value_node);
|
||||||
auto tensor_id = meta_graph->outputIndex()->data()[0];
|
for (int i = 0; i < meta_graph->outputIndex()->size(); i++) {
|
||||||
op_inputs.push_back(GetNode(tensor_id));
|
auto prev_cnode = originator_[meta_graph->outputIndex()->data()[i]];
|
||||||
auto cnode = model->NewCNode(op_inputs);
|
if (prev_cnode != nullptr) op_inputs.push_back(prev_cnode);
|
||||||
|
input.push_back(meta_graph->outputIndex()->data()[i]);
|
||||||
|
}
|
||||||
|
auto cnode = model_->NewCNode(op_inputs);
|
||||||
cnode->set_fullname_with_scope("return");
|
cnode->set_fullname_with_scope("return");
|
||||||
model->set_return(cnode);
|
model_->set_return(cnode);
|
||||||
|
model_->AddCNodeInputOutput(cnode->fullname_with_scope(), input, output);
|
||||||
}
|
}
|
||||||
FuncGraphPtr AnfImporterFromMetaGraph::GetResult() { return this->model; }
|
FuncGraphPtr AnfImporterFromMetaGraph::GetResult() { return this->model_; }
|
||||||
} // namespace mindspore::lite
|
} // namespace mindspore::lite
|
||||||
|
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
#define MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_META_GRAPH_H_
|
#define MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_META_GRAPH_H_
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <map>
|
||||||
#include "src/train/model_impl.h"
|
#include "src/train/model_impl.h"
|
||||||
#include "schema/model_generated.h"
|
#include "schema/model_generated.h"
|
||||||
#include "src/common/anf_importer/anf_importer.h"
|
#include "src/common/anf_importer/anf_importer.h"
|
||||||
|
@ -25,7 +26,7 @@
|
||||||
namespace mindspore::lite {
|
namespace mindspore::lite {
|
||||||
class AnfImporterFromMetaGraph : public AnfImporter {
|
class AnfImporterFromMetaGraph : public AnfImporter {
|
||||||
public:
|
public:
|
||||||
explicit AnfImporterFromMetaGraph(std::shared_ptr<ModelImpl> model) : model(model) {}
|
explicit AnfImporterFromMetaGraph(std::shared_ptr<ModelImpl> model) : model_(model) {}
|
||||||
|
|
||||||
~AnfImporterFromMetaGraph() override = default;
|
~AnfImporterFromMetaGraph() override = default;
|
||||||
|
|
||||||
|
@ -39,9 +40,10 @@ class AnfImporterFromMetaGraph : public AnfImporter {
|
||||||
void AddReturnCNode() override;
|
void AddReturnCNode() override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::shared_ptr<ModelImpl> model = nullptr;
|
std::shared_ptr<ModelImpl> model_ = nullptr;
|
||||||
|
std::map<int, AnfNodePtr> originator_;
|
||||||
|
int num_of_tensors_ = 0;
|
||||||
};
|
};
|
||||||
} // namespace mindspore::lite
|
} // namespace mindspore::lite
|
||||||
|
|
||||||
#endif // MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_META_GRAPH_H_
|
#endif // MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_META_GRAPH_H_
|
||||||
|
|
||||||
|
|
|
@ -60,7 +60,7 @@ class LiteKernel {
|
||||||
explicit LiteKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
|
explicit LiteKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
|
||||||
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
|
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
|
||||||
const lite::Primitive *primitive)
|
const lite::Primitive *primitive)
|
||||||
: opParameter(parameter), inputs_(inputs), outputs_(outputs), train_mode(false), primitive_(primitive),
|
: opParameter(parameter), inputs_(inputs), outputs_(outputs), primitive_(primitive),
|
||||||
context_(ctx) {
|
context_(ctx) {
|
||||||
this->in_kernel_.clear();
|
this->in_kernel_.clear();
|
||||||
this->out_kernel_.clear();
|
this->out_kernel_.clear();
|
||||||
|
@ -136,7 +136,7 @@ class LiteKernel {
|
||||||
std::vector<lite::tensor::Tensor *> outputs_;
|
std::vector<lite::tensor::Tensor *> outputs_;
|
||||||
std::vector<LiteKernel *> in_kernel_;
|
std::vector<LiteKernel *> in_kernel_;
|
||||||
std::vector<LiteKernel *> out_kernel_;
|
std::vector<LiteKernel *> out_kernel_;
|
||||||
bool train_mode;
|
bool train_mode = false;
|
||||||
bool need_reinit = false;
|
bool need_reinit = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -14,11 +14,11 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifdef SUPPORT_TRAIN
|
// #ifdef SUPPORT_TRAIN
|
||||||
#include "src/train/model_impl.h"
|
// #include "src/train/model_impl.h"
|
||||||
#else
|
// #else
|
||||||
#include "src/model_impl.h"
|
#include "src/model_impl.h"
|
||||||
#endif
|
// #endif
|
||||||
#include "include/model.h"
|
#include "include/model.h"
|
||||||
#include "utils/log_adapter.h"
|
#include "utils/log_adapter.h"
|
||||||
|
|
||||||
|
|
|
@ -10,6 +10,13 @@ file(GLOB KERNEL_SRC
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/int8/*.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/int8/*.cc
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if (SUPPORT_TRAIN)
|
||||||
|
file (GLOB TRAIN_KERNEL_SRC
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/fp32_grad/*.cc
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/nnacl/fp32_grad/*.cc
|
||||||
|
)
|
||||||
|
endif()
|
||||||
|
|
||||||
if (PLATFORM_ARM64)
|
if (PLATFORM_ARM64)
|
||||||
# assembly
|
# assembly
|
||||||
file(GLOB ASSEMBLY_SRC nnacl/assembly/arm64/*.s
|
file(GLOB ASSEMBLY_SRC nnacl/assembly/arm64/*.s
|
||||||
|
@ -27,5 +34,5 @@ if (PLATFORM_ARM32)
|
||||||
set(KERNEL_SRC ${KERNEL_SRC} ${ASSEMBLY_SRC})
|
set(KERNEL_SRC ${KERNEL_SRC} ${ASSEMBLY_SRC})
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
add_library(cpu_kernel_mid_ OBJECT ${KERNEL_SRC})
|
add_library(cpu_kernel_mid_ OBJECT ${KERNEL_SRC} ${TRAIN_KERNEL_SRC})
|
||||||
add_subdirectory(nnacl)
|
add_subdirectory(nnacl)
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "src/runtime/kernel/arm/fp32/activation_grad.h"
|
#include "src/runtime/kernel/arm/fp32_grad/activation_grad.h"
|
||||||
#include "schema/model_generated.h"
|
#include "schema/model_generated.h"
|
||||||
#include "src/kernel_registry.h"
|
#include "src/kernel_registry.h"
|
||||||
#include "src/runtime/runtime_api.h"
|
#include "src/runtime/runtime_api.h"
|
||||||
|
@ -102,6 +102,8 @@ kernel::LiteKernel *CpuActivationGradFp32KernelCreator(const std::vector<lite::t
|
||||||
if (ret != RET_OK) {
|
if (ret != RET_OK) {
|
||||||
MS_LOG(ERROR) << "InferShape kernel failed, name: " << opParameter->name_ << ", type: "
|
MS_LOG(ERROR) << "InferShape kernel failed, name: " << opParameter->name_ << ", type: "
|
||||||
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
|
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
|
||||||
|
delete kernel;
|
||||||
|
return nullptr;
|
||||||
}
|
}
|
||||||
return kernel;
|
return kernel;
|
||||||
}
|
}
|
|
@ -14,8 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ACTIVATION_GRAD_H_
|
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ACTIVATION_GRAD_H_
|
||||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ACTIVATION_GRAD_H_
|
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ACTIVATION_GRAD_H_
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "src/lite_kernel.h"
|
#include "src/lite_kernel.h"
|
||||||
|
@ -48,4 +48,4 @@ class ActivationGradCPUKernel : public LiteKernel {
|
||||||
};
|
};
|
||||||
} // namespace mindspore::kernel
|
} // namespace mindspore::kernel
|
||||||
|
|
||||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ACTIVATION_GRAD_H_
|
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ACTIVATION_GRAD_H_
|
|
@ -16,9 +16,9 @@
|
||||||
|
|
||||||
#include "schema/model_generated.h"
|
#include "schema/model_generated.h"
|
||||||
#include "src/kernel_registry.h"
|
#include "src/kernel_registry.h"
|
||||||
#include "src/runtime/kernel/arm/nnacl/fp32/reduce_grad.h"
|
#include "src/runtime/kernel/arm/nnacl/fp32_grad/reduce_grad.h"
|
||||||
#include "src/runtime/kernel/arm/fp32/arithmetic_grad.h"
|
#include "src/runtime/kernel/arm/nnacl/fp32_grad/arithmetic_grad.h"
|
||||||
#include "src/runtime/kernel/arm/nnacl/fp32/arithmetic_grad.h"
|
#include "src/runtime/kernel/arm/fp32_grad/arithmetic_grad.h"
|
||||||
#include "include/errorcode.h"
|
#include "include/errorcode.h"
|
||||||
|
|
||||||
using mindspore::kernel::KERNEL_ARCH::kCPU;
|
using mindspore::kernel::KERNEL_ARCH::kCPU;
|
|
@ -14,8 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_GRAD_H_
|
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ARITHMETIC_GRAD_H_
|
||||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_GRAD_H_
|
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ARITHMETIC_GRAD_H_
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "src/lite_kernel.h"
|
#include "src/lite_kernel.h"
|
||||||
|
@ -88,4 +88,4 @@ class ArithmeticGradCPUKernel : public LiteKernel {
|
||||||
};
|
};
|
||||||
} // namespace mindspore::kernel
|
} // namespace mindspore::kernel
|
||||||
|
|
||||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_GRAD_H_
|
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ARITHMETIC_GRAD_H_
|
|
@ -15,7 +15,7 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "src/runtime/kernel/arm/fp32/bias_grad.h"
|
#include "src/runtime/kernel/arm/fp32_grad/bias_grad.h"
|
||||||
#include "schema/model_generated.h"
|
#include "schema/model_generated.h"
|
||||||
#include "src/kernel_registry.h"
|
#include "src/kernel_registry.h"
|
||||||
#include "include/errorcode.h"
|
#include "include/errorcode.h"
|
|
@ -18,8 +18,8 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "schema/model_generated.h"
|
#include "schema/model_generated.h"
|
||||||
#include "src/kernel_factory.h"
|
#include "src/kernel_factory.h"
|
||||||
#include "src/runtime/kernel/arm/fp32/bngrad_input.h"
|
#include "src/runtime/kernel/arm/fp32_grad/bn_grad.h"
|
||||||
#include "src/runtime//kernel/arm/nnacl/batch_norm.h"
|
#include "src/runtime/kernel/arm/nnacl/fp32_grad/batch_norm.h"
|
||||||
#include "include/errorcode.h"
|
#include "include/errorcode.h"
|
||||||
|
|
||||||
using mindspore::kernel::KERNEL_ARCH::kCPU;
|
using mindspore::kernel::KERNEL_ARCH::kCPU;
|
||||||
|
@ -54,10 +54,6 @@ int BNGradInputCPUKernel::Init() {
|
||||||
|
|
||||||
int BNGradInputCPUKernel::ReSize() { return RET_OK; }
|
int BNGradInputCPUKernel::ReSize() { return RET_OK; }
|
||||||
|
|
||||||
/*
|
|
||||||
according to https://wiseodd.github.io/techblog/2016/07/04/batchnorm
|
|
||||||
*/
|
|
||||||
|
|
||||||
int BNGradInputCPUKernel::Run() {
|
int BNGradInputCPUKernel::Run() {
|
||||||
// std::cout << "run succ" << std::endl;
|
// std::cout << "run succ" << std::endl;
|
||||||
auto *input_x = inputs_.at(0);
|
auto *input_x = inputs_.at(0);
|
||||||
|
@ -107,6 +103,8 @@ kernel::LiteKernel *CpuBNGradInputFp32KernelCreator(const std::vector<lite::tens
|
||||||
if (RET_OK != ret) {
|
if (RET_OK != ret) {
|
||||||
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
|
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
|
||||||
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
|
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
|
||||||
|
delete kernel;
|
||||||
|
return nullptr;
|
||||||
}
|
}
|
||||||
return kernel;
|
return kernel;
|
||||||
}
|
}
|
|
@ -14,8 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_BNGRAD_INPUT_H_
|
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_BNGRAD_INPUT_H_
|
||||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_BNGRAD_INPUT_H_
|
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_BNGRAD_INPUT_H_
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "src/lite_kernel.h"
|
#include "src/lite_kernel.h"
|
||||||
|
@ -39,4 +39,4 @@ class BNGradInputCPUKernel : public LiteKernel {
|
||||||
int workspace_size;
|
int workspace_size;
|
||||||
};
|
};
|
||||||
} // namespace mindspore::kernel
|
} // namespace mindspore::kernel
|
||||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_BNGRAD_INPUT_H_
|
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_BNGRAD_INPUT_H_
|
|
@ -14,11 +14,11 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "src/runtime/kernel/arm/fp32/convolution_grad_filter.h"
|
#include "src/runtime/kernel/arm/fp32_grad/convolution_grad_filter.h"
|
||||||
#include "src/kernel_registry.h"
|
#include "src/kernel_registry.h"
|
||||||
#include "src/runtime/kernel/arm/nnacl/pack.h"
|
#include "src/runtime/kernel/arm/nnacl/pack.h"
|
||||||
#include "src/runtime/kernel/arm/nnacl/pack_ext.h"
|
#include "src/runtime/kernel/arm/nnacl/fp32_grad/pack_ext.h"
|
||||||
#include "src/runtime/kernel/arm/nnacl/fp32/gemm.h"
|
#include "src/runtime/kernel/arm/nnacl/fp32_grad/gemm.h"
|
||||||
#include "include/errorcode.h"
|
#include "include/errorcode.h"
|
||||||
|
|
||||||
using mindspore::kernel::KERNEL_ARCH::kCPU;
|
using mindspore::kernel::KERNEL_ARCH::kCPU;
|
|
@ -14,8 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_GRAD_FILTER_H_
|
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_CONVOLUTION_GRAD_FILTER_H_
|
||||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_GRAD_FILTER_H_
|
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_CONVOLUTION_GRAD_FILTER_H_
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "src/lite_kernel.h"
|
#include "src/lite_kernel.h"
|
||||||
|
@ -39,4 +39,4 @@ class ConvolutionGradFilterCPUKernel : public LiteKernel {
|
||||||
};
|
};
|
||||||
} // namespace mindspore::kernel
|
} // namespace mindspore::kernel
|
||||||
|
|
||||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_GRAD_FILTER_H_
|
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_CONVOLUTION_GRAD_FILTER_H_
|
|
@ -14,11 +14,11 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "src/runtime/kernel/arm/fp32/convolution_grad_input.h"
|
#include "src/runtime/kernel/arm/fp32_grad/convolution_grad_input.h"
|
||||||
#include "src/kernel_registry.h"
|
#include "src/kernel_registry.h"
|
||||||
#include "src/runtime/kernel/arm/nnacl/pack.h"
|
#include "src/runtime/kernel/arm/nnacl/pack.h"
|
||||||
#include "src/runtime/kernel/arm/nnacl/pack_ext.h"
|
#include "src/runtime/kernel/arm/nnacl/fp32_grad/pack_ext.h"
|
||||||
#include "src/runtime/kernel/arm/nnacl/fp32/gemm.h"
|
#include "src/runtime/kernel/arm/nnacl/fp32_grad/gemm.h"
|
||||||
#include "include/errorcode.h"
|
#include "include/errorcode.h"
|
||||||
|
|
||||||
using mindspore::kernel::KERNEL_ARCH::kCPU;
|
using mindspore::kernel::KERNEL_ARCH::kCPU;
|
|
@ -14,8 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_GRAD_INPUT_H_
|
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_CONVOLUTION_GRAD_INPUT_H_
|
||||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_GRAD_INPUT_H_
|
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_CONVOLUTION_GRAD_INPUT_H_
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "src/lite_kernel.h"
|
#include "src/lite_kernel.h"
|
||||||
|
@ -39,4 +39,4 @@ class ConvolutionGradInputCPUKernel : public LiteKernel {
|
||||||
};
|
};
|
||||||
} // namespace mindspore::kernel
|
} // namespace mindspore::kernel
|
||||||
|
|
||||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_GRAD_INPUT_H_
|
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_CONVOLUTION_GRAD_INPUT_H
|
|
@ -17,7 +17,7 @@
|
||||||
|
|
||||||
#include "schema/model_generated.h"
|
#include "schema/model_generated.h"
|
||||||
#include "src/kernel_registry.h"
|
#include "src/kernel_registry.h"
|
||||||
#include "src/runtime/kernel/arm/fp32/opt_momentum.h"
|
#include "src/runtime/kernel/arm/fp32_grad/opt_momentum.h"
|
||||||
#include "include/errorcode.h"
|
#include "include/errorcode.h"
|
||||||
|
|
||||||
using mindspore::kernel::KERNEL_ARCH::kCPU;
|
using mindspore::kernel::KERNEL_ARCH::kCPU;
|
|
@ -14,11 +14,11 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "src/runtime/kernel/arm/fp32/pooling_grad.h"
|
#include "src/runtime/kernel/arm/fp32_grad/pooling_grad.h"
|
||||||
#include "schema/model_generated.h"
|
#include "schema/model_generated.h"
|
||||||
#include "src/kernel_registry.h"
|
#include "src/kernel_registry.h"
|
||||||
#include "src/runtime/kernel/arm/nnacl/fp32/pooling.h"
|
#include "src/runtime/kernel/arm/nnacl/fp32/pooling.h"
|
||||||
#include "src/runtime/kernel/arm/nnacl/fp32/pooling_grad.h"
|
#include "src/runtime/kernel/arm/nnacl/fp32_grad/pooling_grad.h"
|
||||||
#include "include/errorcode.h"
|
#include "include/errorcode.h"
|
||||||
|
|
||||||
using mindspore::kernel::KERNEL_ARCH::kCPU;
|
using mindspore::kernel::KERNEL_ARCH::kCPU;
|
|
@ -14,8 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_POOLING_GRAD_H_
|
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_POOLING_GRAD_H_
|
||||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_POOLING_GRAD_H_
|
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_POOLING_GRAD_H_
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "src/lite_kernel.h"
|
#include "src/lite_kernel.h"
|
||||||
|
@ -48,4 +48,4 @@ class PoolingGradCPUKernel : public LiteKernel {
|
||||||
};
|
};
|
||||||
} // namespace mindspore::kernel
|
} // namespace mindspore::kernel
|
||||||
|
|
||||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_POOLING_GRAD_H_
|
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_POOLING_GRAD_H_
|
|
@ -14,7 +14,7 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "src/runtime/kernel/arm/fp32/power_grad.h"
|
#include "src/runtime/kernel/arm/fp32_grad/power_grad.h"
|
||||||
#include "schema/model_generated.h"
|
#include "schema/model_generated.h"
|
||||||
#include "src/kernel_registry.h"
|
#include "src/kernel_registry.h"
|
||||||
#include "include/errorcode.h"
|
#include "include/errorcode.h"
|
|
@ -14,8 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_POWER_GRAD_H_
|
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_POWER_GRAD_H_
|
||||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_POWER_GRAD_H_
|
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_POWER_GRAD_H_
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "src/lite_kernel.h"
|
#include "src/lite_kernel.h"
|
||||||
|
@ -47,4 +47,4 @@ class PowerGradCPUKernel : public LiteKernel {
|
||||||
};
|
};
|
||||||
} // namespace mindspore::kernel
|
} // namespace mindspore::kernel
|
||||||
|
|
||||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_POWER_GRAD_H_
|
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_POWER_GRAD_H_
|
|
@ -14,13 +14,12 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "src/runtime/kernel/arm/fp32/sparse_softmax_cross_entropy_with_logits.h"
|
|
||||||
#include "src/runtime/kernel/arm/nnacl/fp32/softmax.h"
|
|
||||||
#include "schema/model_generated.h"
|
|
||||||
#include "src/kernel_registry.h"
|
#include "src/kernel_registry.h"
|
||||||
|
#include "src/runtime/kernel/arm/nnacl/softmax_parameter.h"
|
||||||
|
#include "src/runtime/kernel/arm/nnacl/fp32/softmax.h"
|
||||||
|
#include "src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.h"
|
||||||
#include "include/errorcode.h"
|
#include "include/errorcode.h"
|
||||||
|
|
||||||
using mindspore::kernel::KERNEL_ARCH::kCPU;
|
|
||||||
using mindspore::lite::KernelRegistrar;
|
using mindspore::lite::KernelRegistrar;
|
||||||
using mindspore::lite::RET_ERROR;
|
using mindspore::lite::RET_ERROR;
|
||||||
using mindspore::lite::RET_OK;
|
using mindspore::lite::RET_OK;
|
||||||
|
@ -73,7 +72,7 @@ void SparseSoftmaxCrossEntropyWithLogitsCPUKernel::GradPostExecute(const int *la
|
||||||
int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Run() {
|
int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Run() {
|
||||||
auto ins = reinterpret_cast<float *>(inputs_.at(0)->Data());
|
auto ins = reinterpret_cast<float *>(inputs_.at(0)->Data());
|
||||||
auto labels = reinterpret_cast<int *>(inputs_.at(1)->Data());
|
auto labels = reinterpret_cast<int *>(inputs_.at(1)->Data());
|
||||||
auto out = reinterpret_cast<float *>(outputs_.at(0)->Data());
|
auto out = reinterpret_cast<float *>(outputs_.at(1)->Data());
|
||||||
float *grads = NULL;
|
float *grads = NULL;
|
||||||
if (is_train()) { // outputs_.size() > 1)
|
if (is_train()) { // outputs_.size() > 1)
|
||||||
grads = reinterpret_cast<float *>(outputs_.at(0)->Data());
|
grads = reinterpret_cast<float *>(outputs_.at(0)->Data());
|
||||||
|
@ -90,10 +89,11 @@ int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Run() {
|
||||||
SoftmaxParameter sm_params;
|
SoftmaxParameter sm_params;
|
||||||
sm_params.n_dim_ = param->n_dim_;
|
sm_params.n_dim_ = param->n_dim_;
|
||||||
sm_params.element_size_ = data_size;
|
sm_params.element_size_ = data_size;
|
||||||
sm_params.axis_ = 1;
|
sm_params.axis_ = 0;
|
||||||
for (int i = 0; i < 4; i++) // softmax has only 4 params in shape
|
for (int i = 0; i < 4; i++) // softmax has only 4 params in shape
|
||||||
sm_params.input_shape_[i] = param->input_shape_[i];
|
sm_params.input_shape_[i] = param->input_shape_[i];
|
||||||
float sum_data[sm_params.input_shape_[sm_params.axis_]];
|
float sum_data[sm_params.input_shape_[sm_params.axis_]] = {0};
|
||||||
|
std::fill(sum_data, sum_data + sm_params.input_shape_[sm_params.axis_], 0);
|
||||||
Softmax(ins, losses, sum_data, &sm_params);
|
Softmax(ins, losses, sum_data, &sm_params);
|
||||||
|
|
||||||
if (is_train()) {
|
if (is_train()) {
|
|
@ -20,7 +20,7 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "src/lite_kernel.h"
|
#include "src/lite_kernel.h"
|
||||||
#include "ir/anf.h"
|
#include "ir/anf.h"
|
||||||
#include "src/runtime/kernel/arm/nnacl/fp32/softmax_grad.h"
|
#include "src/runtime/kernel/arm/nnacl/fp32_grad/softmax_grad.h"
|
||||||
#include "src/runtime/kernel/arm/nnacl/fp32/arithmetic.h"
|
#include "src/runtime/kernel/arm/nnacl/fp32/arithmetic.h"
|
||||||
|
|
||||||
namespace mindspore::kernel {
|
namespace mindspore::kernel {
|
||||||
|
@ -30,8 +30,7 @@ class SparseSoftmaxCrossEntropyWithLogitsCPUKernel : public LiteKernel {
|
||||||
explicit SparseSoftmaxCrossEntropyWithLogitsCPUKernel(OpParameter *parameter,
|
explicit SparseSoftmaxCrossEntropyWithLogitsCPUKernel(OpParameter *parameter,
|
||||||
const std::vector<lite::tensor::Tensor *> &inputs,
|
const std::vector<lite::tensor::Tensor *> &inputs,
|
||||||
const std::vector<lite::tensor::Tensor *> &outputs,
|
const std::vector<lite::tensor::Tensor *> &outputs,
|
||||||
const lite::Context *ctx,
|
const lite::Context *ctx, const lite::Primitive *primitive)
|
||||||
const lite::Primitive *primitive)
|
|
||||||
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {
|
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {
|
||||||
param = reinterpret_cast<SoftmaxCrossEntropyParameter *>(parameter);
|
param = reinterpret_cast<SoftmaxCrossEntropyParameter *>(parameter);
|
||||||
}
|
}
|
|
@ -0,0 +1,88 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_GRAD_ACTIVATION_GRAD_H_
|
||||||
|
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_GRAD_ACTIVATION_GRAD_H_
|
||||||
|
|
||||||
|
#include <math.h>
|
||||||
|
#include "src/runtime/kernel/arm/opclib/op_base.h"
|
||||||
|
#include "src/runtime/kernel/arm/opclib/fp32/arithmetic.h"
|
||||||
|
#include "src/runtime/kernel/arm/opclib/errorcode.h"
|
||||||
|
|
||||||
|
struct ActivationGradParameter {
|
||||||
|
OpParameter op_parameter{};
|
||||||
|
int type_;
|
||||||
|
float alpha_{0.01};
|
||||||
|
};
|
||||||
|
|
||||||
|
inline int ReluGrad(float *src0, float *src1, int length, float *dst) {
|
||||||
|
for (int i = 0; i < length; ++i) {
|
||||||
|
dst[i] = src1[i] > 0 ? 1.0f : 0.0f;
|
||||||
|
}
|
||||||
|
ElementMul(src0, dst, dst, length);
|
||||||
|
return OPCLIB_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline int Relu6Grad(float *src0, float *src1, int length, float *dst) {
|
||||||
|
for (int i = 0; i < length; ++i) {
|
||||||
|
if (src1[i] < 0) {
|
||||||
|
dst[i] = 0;
|
||||||
|
} else {
|
||||||
|
dst[i] = src1[i] > 6.0f ? 0.0f : 1.0f;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ElementMul(src0, dst, dst, length);
|
||||||
|
return OPCLIB_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline int LReluGrad(float *src0, float *src1, int length, float *dst, float alpha) {
|
||||||
|
for (int i = 0; i < length; ++i) {
|
||||||
|
dst[i] = src1[i] > 0.0f ? 1.0f : alpha;
|
||||||
|
}
|
||||||
|
ElementMul(src0, dst, dst, length);
|
||||||
|
return OPCLIB_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline int SigmoidGrad(float *src0, float *src1, int length, float *dst) {
|
||||||
|
for (int i = 0; i < length; ++i) {
|
||||||
|
dst[i] = src0[i] * (src1[i] * (1.0f - src1[i]));
|
||||||
|
}
|
||||||
|
return OPCLIB_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline int TanhGrad(float *src0, float *src1, int length, float *dst) {
|
||||||
|
for (int i = 0; i < length; ++i) {
|
||||||
|
dst[i] = (1.0f - (src1[i] * src1[i])) * src0[i];
|
||||||
|
}
|
||||||
|
return OPCLIB_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline int HSwishGrad(float *src0, float *src1, int length, float *dst) {
|
||||||
|
for (int i = 0; i < length; ++i) {
|
||||||
|
float tmp = (src1[i] > 3.0f ? 1.0f : (src1[i] < -3.0f ? 0.0f : (2.0f * src1[i] + 3.0f) / 6.0f));
|
||||||
|
dst[i] = tmp * src0[i];
|
||||||
|
}
|
||||||
|
return OPCLIB_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline int HSigmoidGrad(float *src0, float *src1, int length, float *dst) {
|
||||||
|
for (int i = 0; i < length; ++i) {
|
||||||
|
float tmp = (src1[i] > 3.0f ? 1.0f : (src1[i] < -3.0f ? 0.0f : 1.0f / 6.0f));
|
||||||
|
dst[i] = tmp * src0[i];
|
||||||
|
}
|
||||||
|
return OPCLIB_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_GRAD_ACTIVATION_GRAD_H_
|
|
@ -14,7 +14,7 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "src/runtime/kernel/arm/nnacl/fp32/arithmetic_grad.h"
|
#include "src/runtime/kernel/arm/nnacl/fp32_grad/arithmetic_grad.h"
|
||||||
|
|
||||||
void ElementDivNegSquare(const float *nom, const float *denom, float *output, int element_size) {
|
void ElementDivNegSquare(const float *nom, const float *denom, float *output, int element_size) {
|
||||||
for (int i = 0; i < element_size; i++) {
|
for (int i = 0; i < element_size; i++) {
|
|
@ -13,8 +13,8 @@
|
||||||
* See the License for the specific language governing permissions and
|
* See the License for the specific language governing permissions and
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_ARITHMETIC_GRAD_H_
|
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_GRAD_ARITHMETIC_GRAD_H_
|
||||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_ARITHMETIC_GRAD_H_
|
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_GRAD_ARITHMETIC_GRAD_H_
|
||||||
|
|
||||||
void ElementDivNegSquare(const float *nom, const float *denom, float *output, int element_size);
|
void ElementDivNegSquare(const float *nom, const float *denom, float *output, int element_size);
|
||||||
void ElementMulAndDivNegSquare(const float *a, const float *b, const float *denom, float *output, int element_size);
|
void ElementMulAndDivNegSquare(const float *a, const float *b, const float *denom, float *output, int element_size);
|
|
@ -15,7 +15,7 @@
|
||||||
*/
|
*/
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include "src/runtime/kernel/arm/nnacl/batch_norm.h"
|
#include "src/runtime/kernel/arm/nnacl/fp32_grad/batch_norm.h"
|
||||||
|
|
||||||
static void sumSpatialBatch(const float *in, int size, int ch, float *out) {
|
static void sumSpatialBatch(const float *in, int size, int ch, float *out) {
|
||||||
std::fill(out, out + ch, 0.f);
|
std::fill(out, out + ch, 0.f);
|
|
@ -14,8 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_LITE_SRC_BACKEND_ARM_BATCH_NORM_H_
|
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_BATCH_NORM_H_
|
||||||
#define MINDSPORE_LITE_SRC_BACKEND_ARM_BATCH_NORM_H_
|
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_BATCH_NORM_H_
|
||||||
|
|
||||||
struct bnParameter {
|
struct bnParameter {
|
||||||
int batch;
|
int batch;
|
||||||
|
@ -36,4 +36,5 @@ void meanAdd(const float *x, const float *mean, const float *variance_delta, int
|
||||||
void NormalizeDelta(const float *x, const float *mean, const float *variance, const float *mean_delta,
|
void NormalizeDelta(const float *x, const float *mean, const float *variance, const float *mean_delta,
|
||||||
const float *variance_delta, int batch, int filters, int spatial, float eps, float *delta);
|
const float *variance_delta, int batch, int filters, int spatial, float eps, float *delta);
|
||||||
|
|
||||||
#endif
|
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_BATCH_NORM_H_
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "src/runtime/kernel/arm/nnacl/fp32/gemm.h"
|
#include "src/runtime/kernel/arm/nnacl/fp32_grad/gemm.h"
|
||||||
|
|
||||||
static void gemm_nn(int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_B, int ldb, float *mat_c,
|
static void gemm_nn(int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_B, int ldb, float *mat_c,
|
||||||
int ldc) {
|
int ldc) {
|
|
@ -14,10 +14,10 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_GEMM_H_
|
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_GRAD_GEMM_H_
|
||||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_GEMM_H_
|
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_GRAD_GEMM_H_
|
||||||
|
|
||||||
void gemm(int transpose_a, int transpose_b, int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_b,
|
void gemm(int transpose_a, int transpose_b, int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_b,
|
||||||
int ldb, float beta, float *mat_c, int ldc);
|
int ldb, float beta, float *mat_c, int ldc);
|
||||||
|
|
||||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_GEMM_H_
|
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_GRAD_GEMM_H_
|
|
@ -15,7 +15,7 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <string.h>
|
#include <string.h>
|
||||||
#include "src/runtime/kernel/arm/nnacl/pack_ext.h"
|
#include "src/runtime/kernel/arm/nnacl/fp32_grad/pack_ext.h"
|
||||||
|
|
||||||
static int is_a_ge_zero_and_a_lt_b(int a, int b) { return (unsigned)(a) < (unsigned)(b); }
|
static int is_a_ge_zero_and_a_lt_b(int a, int b) { return (unsigned)(a) < (unsigned)(b); }
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include "src/runtime/kernel/arm/nnacl/fp32/pooling_grad.h"
|
#include "src/runtime/kernel/arm/nnacl/fp32_grad/pooling_grad.h"
|
||||||
|
|
||||||
void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param) {
|
void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param) {
|
||||||
int stride_w = pooling_param->stride_w_;
|
int stride_w = pooling_param->stride_w_;
|
|
@ -14,12 +14,12 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_POOLING_GRAD_H_
|
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_GRAD_POOLING_GRAD_H_
|
||||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_POOLING_GRAD_H_
|
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_GRAD_POOLING_GRAD_H_
|
||||||
|
|
||||||
#include "src/runtime/kernel/arm/nnacl/fp32/pooling.h"
|
#include "src/runtime/kernel/arm/nnacl/fp32/pooling.h"
|
||||||
|
|
||||||
void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param);
|
void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param);
|
||||||
void MaxPoolingGrad(const float *dy, const int *indices_ptr, float *output_ptr, PoolingParameter *pooling_param);
|
void MaxPoolingGrad(const float *dy, const int *indices_ptr, float *output_ptr, PoolingParameter *pooling_param);
|
||||||
|
|
||||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_POOLING_GRAD_H_
|
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_GRAD_POOLING_GRAD_H_
|
|
@ -15,7 +15,7 @@
|
||||||
*/
|
*/
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include "mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/reduce_grad.h"
|
#include "mindspore/lite/src/runtime/kernel/arm/nnacl/fp32_grad/reduce_grad.h"
|
||||||
|
|
||||||
static inline bool NextIndex(const int num_dims, const int *dims, int *current) {
|
static inline bool NextIndex(const int num_dims, const int *dims, int *current) {
|
||||||
int carry = 1;
|
int carry = 1;
|
|
@ -57,4 +57,3 @@ std::vector<std::vector<std::shared_ptr<tensor::MSTensor>>> TransformVectorRefTo
|
||||||
return multiTensor;
|
return multiTensor;
|
||||||
}
|
}
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -16,16 +16,15 @@
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include "base/base_ref.h"
|
#include "utils/base_ref.h"
|
||||||
#include "include/ms_tensor.h"
|
#include "include/ms_tensor.h"
|
||||||
|
|
||||||
#ifndef MINDSPORE_CCSRC_UTILS_BASE_REF_UTILS_H
|
#ifndef MINDSPORE_LITE_SRC_TRAIN_BASE_REF_UTILS_H_
|
||||||
#define MINDSPORE_CCSRC_UTILS_BASE_REF_UTILS_H
|
#define MINDSPORE_LITE_SRC_TRAIN_BASE_REF_UTILS_H_
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
std::vector<std::shared_ptr<tensor::MSTensor>> TransformBaseRefToMSTensor(const BaseRef &base_ref);
|
std::vector<std::shared_ptr<tensor::MSTensor>> TransformBaseRefToMSTensor(const BaseRef &base_ref);
|
||||||
|
|
||||||
std::vector<std::vector<std::shared_ptr<tensor::MSTensor>>> TransformVectorRefToMultiTensor(
|
std::vector<std::vector<std::shared_ptr<tensor::MSTensor>>> TransformVectorRefToMultiTensor(
|
||||||
const VectorRef &vector_ref);
|
const VectorRef &vector_ref);
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_UTILS_BASE_REF_UTILS_H
|
#endif // MINDSPORE_LITE_SRC_TRAIN_BASE_REF_UTILS_H_
|
||||||
|
|
||||||
|
|
|
@ -14,7 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "mindspore/lite/src/train/lite_kernel_runtime.h"
|
#include "src/train/lite_kernel_runtime.h"
|
||||||
|
#include "backend/session/anf_runtime_algorithm.h"
|
||||||
namespace mindspore::lite {
|
namespace mindspore::lite {
|
||||||
std::vector<CNodePtr> LiteInferKernelRuntime::GetGraphInputs(const std::vector<CNodePtr> &execution_order) {
|
std::vector<CNodePtr> LiteInferKernelRuntime::GetGraphInputs(const std::vector<CNodePtr> &execution_order) {
|
||||||
std::vector<CNodePtr> graph_inputs;
|
std::vector<CNodePtr> graph_inputs;
|
||||||
|
@ -34,7 +35,8 @@ std::vector<CNodePtr> LiteInferKernelRuntime::GetGraphInputs(const std::vector<C
|
||||||
}
|
}
|
||||||
|
|
||||||
void LiteInferKernelRuntime::BindInputOutput(const session::KernelGraph *graph,
|
void LiteInferKernelRuntime::BindInputOutput(const session::KernelGraph *graph,
|
||||||
const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
|
const std::vector<tensor::Tensor *> &inputs,
|
||||||
|
std::vector<tensor::Tensor *> *outputs) {
|
||||||
MS_EXCEPTION_IF_NULL(graph);
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
auto execution_order = graph->execution_order();
|
auto execution_order = graph->execution_order();
|
||||||
auto graph_inputs = GetGraphInputs(execution_order);
|
auto graph_inputs = GetGraphInputs(execution_order);
|
||||||
|
@ -56,15 +58,17 @@ void LiteInferKernelRuntime::BindInputOutput(const session::KernelGraph *graph,
|
||||||
auto liteKernel = dynamic_cast<kernel::LiteKernel *>(AnfAlgo::GetKernelMod(return_input));
|
auto liteKernel = dynamic_cast<kernel::LiteKernel *>(AnfAlgo::GetKernelMod(return_input));
|
||||||
auto output_tensors = liteKernel->GetOutputs();
|
auto output_tensors = liteKernel->GetOutputs();
|
||||||
for (auto output_tensor : output_tensors) {
|
for (auto output_tensor : output_tensors) {
|
||||||
tensor::TensorPtr output_tensor_ptr(output_tensor);
|
// tensor::TensorPtr output_tensor_ptr(output_tensor);
|
||||||
outputs->push_back(output_tensor_ptr);
|
outputs->push_back(output_tensor);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool LiteInferKernelRuntime::Run(session::KernelGraph *graph) {
|
bool LiteInferKernelRuntime::Run(session::KernelGraph *graph, const std::vector<tensor::Tensor *> &inputs,
|
||||||
|
std::vector<tensor::Tensor *> *outputs) {
|
||||||
MS_EXCEPTION_IF_NULL(graph);
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
|
BindInputOutput(graph, inputs, *outputs);
|
||||||
std::vector<kernel::LiteKernel *> kernels;
|
std::vector<kernel::LiteKernel *> kernels;
|
||||||
auto nodes = graph->execution_order();
|
auto nodes = graph->execution_order();
|
||||||
for (const auto &node : nodes) {
|
for (const auto &node : nodes) {
|
||||||
|
@ -76,8 +80,7 @@ bool LiteInferKernelRuntime::Run(session::KernelGraph *graph) {
|
||||||
}
|
}
|
||||||
kernel::LiteKernelUtil::TopologicalSortKernels(kernels);
|
kernel::LiteKernelUtil::TopologicalSortKernels(kernels);
|
||||||
Executor executor;
|
Executor executor;
|
||||||
auto ret = executor.Run(kernels);
|
auto ret = executor.Run(inputs, *outputs, kernels);
|
||||||
return 0 == ret;
|
return 0 == ret;
|
||||||
}
|
}
|
||||||
} // namespace mindspore::lite
|
} // namespace mindspore::lite
|
||||||
|
|
||||||
|
|
|
@ -23,35 +23,28 @@
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include "src/runtime/allocator.h"
|
#include "src/runtime/allocator.h"
|
||||||
#include "src/executor.h"
|
#include "src/executor.h"
|
||||||
#include "runtime/device/kernel_runtime.h"
|
// #include "runtime/device/kernel_runtime.h"
|
||||||
#include "runtime/device/device_address.h"
|
#include "runtime/device/device_address.h"
|
||||||
#include "src/lite_kernel.h"
|
#include "src/lite_kernel.h"
|
||||||
#include "backend/session/kernel_graph.h"
|
#include "backend/session/kernel_graph.h"
|
||||||
namespace mindspore::lite {
|
namespace mindspore::lite {
|
||||||
class LiteInferKernelRuntime : public device::KernelRuntime {
|
class LiteInferKernelRuntime {
|
||||||
public:
|
public:
|
||||||
LiteInferKernelRuntime() = default;
|
LiteInferKernelRuntime() = default;
|
||||||
~LiteInferKernelRuntime() override = default;
|
~LiteInferKernelRuntime() = default;
|
||||||
|
|
||||||
bool Init() override { return true; }
|
bool Run(session::KernelGraph *graph, const std::vector<tensor::Tensor *> &inputs,
|
||||||
|
std::vector<tensor::Tensor *> *outputs);
|
||||||
void BindInputOutput(const session::KernelGraph *graph, const std::vector<tensor::TensorPtr> &inputs,
|
|
||||||
VectorRef *outputs);
|
|
||||||
|
|
||||||
bool Run(session::KernelGraph *graph);
|
|
||||||
|
|
||||||
void AssignKernelAddress(session::KernelGraph *graph) {}
|
void AssignKernelAddress(session::KernelGraph *graph) {}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
void BindInputOutput(const session::KernelGraph *graph, const std::vector<tensor::Tensor *> &inputs,
|
||||||
|
std::vector<tensor::Tensor *> *outputs);
|
||||||
|
|
||||||
std::vector<CNodePtr> GetGraphInputs(const std::vector<CNodePtr> &execution_order);
|
std::vector<CNodePtr> GetGraphInputs(const std::vector<CNodePtr> &execution_order);
|
||||||
bool SyncStream() override { return true; };
|
|
||||||
device::DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
|
|
||||||
TypeId type_id) override {
|
|
||||||
return nullptr;
|
|
||||||
};
|
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace mindspore::lite
|
} // namespace mindspore::lite
|
||||||
|
|
||||||
#endif // MINDSPORE_LITE_SRC_TRAIN_LITE_KERNEL_RUNTIME_H_
|
#endif // MINDSPORE_LITE_SRC_TRAIN_LITE_KERNEL_RUNTIME_H_
|
||||||
|
|
||||||
|
|
|
@ -16,11 +16,34 @@
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include "src/train/model_impl.h"
|
#include "src/train/model_impl.h"
|
||||||
#include "schema/model_generated.h"
|
|
||||||
#include "ir/func_graph.h"
|
#include "ir/func_graph.h"
|
||||||
|
#include "schema/model_generated.h"
|
||||||
|
#include "src/common/anf_importer/import_from_meta_graph.h"
|
||||||
|
|
||||||
namespace mindspore::lite::train {
|
namespace mindspore::lite::train {
|
||||||
|
|
||||||
|
std::shared_ptr<ModelImpl> ModelImpl::Import(const char *model_buf, size_t size) {
|
||||||
|
MS_EXCEPTION_IF_NULL(model_buf);
|
||||||
|
flatbuffers::Verifier verify((const uint8_t *)model_buf, size);
|
||||||
|
if (!schema::VerifyMetaGraphBuffer(verify)) {
|
||||||
|
MS_LOG(ERROR) << "The buffer is invalid and fail to create graph.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
// todo hangangqiang remove when copy primitive done
|
||||||
|
auto *inner_buf = new char[size];
|
||||||
|
memcpy(inner_buf, model_buf, size);
|
||||||
|
auto meta_graph = schema::GetMetaGraph(inner_buf);
|
||||||
|
auto func_graph_model = std::make_shared<ModelImpl>(meta_graph);
|
||||||
|
auto ret = func_graph_model->BuildOps();
|
||||||
|
if (0 != ret) {
|
||||||
|
MS_LOG(ERROR) << "BuildOps failed";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
AnfImporterFromMetaGraph anfImporter(func_graph_model);
|
||||||
|
anfImporter.Import();
|
||||||
|
return func_graph_model;
|
||||||
|
}
|
||||||
|
|
||||||
const lite::Primitive *ModelImpl::GetOp(const std::string &name) const {
|
const lite::Primitive *ModelImpl::GetOp(const std::string &name) const {
|
||||||
auto iter = ops.find(name);
|
auto iter = ops.find(name);
|
||||||
if (iter == ops.end()) {
|
if (iter == ops.end()) {
|
||||||
|
@ -98,6 +121,8 @@ lite::Primitive *ModelImpl::CopyPrimitive(const schema::Primitive *srcPrim) {
|
||||||
return new lite::Nchw2Nhwc(const_cast<schema::Primitive *>(srcPrim));
|
return new lite::Nchw2Nhwc(const_cast<schema::Primitive *>(srcPrim));
|
||||||
case schema::PrimitiveType_Nhwc2Nchw:
|
case schema::PrimitiveType_Nhwc2Nchw:
|
||||||
return new lite::Nhwc2Nchw(const_cast<schema::Primitive *>(srcPrim));
|
return new lite::Nhwc2Nchw(const_cast<schema::Primitive *>(srcPrim));
|
||||||
|
case schema::PrimitiveType_MatMul:
|
||||||
|
return new lite::MatMul(const_cast<schema::Primitive *>(srcPrim));
|
||||||
default:
|
default:
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
@ -115,5 +140,6 @@ int ModelImpl::BuildOps() {
|
||||||
auto srcPrim = cNode->primitive();
|
auto srcPrim = cNode->primitive();
|
||||||
this->ops[name] = CopyPrimitive(srcPrim);
|
this->ops[name] = CopyPrimitive(srcPrim);
|
||||||
}
|
}
|
||||||
|
return 0;
|
||||||
}
|
}
|
||||||
} // namespace mindspore::lite::train
|
} // namespace mindspore::lite::train
|
||||||
|
|
|
@ -15,11 +15,12 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_LITE_SRC_TRAIN_MODEL_IMPL_H_
|
#ifndef MINDSPORE_LITE_SRC_TRAIN_MODEL_IMPL_H_
|
||||||
#define MINDSPORE_LITE_SRC_TRAIN_MODEL_IMPL_H
|
#define MINDSPORE_LITE_SRC_TRAIN_MODEL_IMPL_H_
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
#include "schema/model_generated.h"
|
#include "schema/model_generated.h"
|
||||||
#include "src/ops/ops.h"
|
#include "src/ops/ops.h"
|
||||||
#include "ir/func_graph.h"
|
#include "ir/func_graph.h"
|
||||||
|
@ -28,7 +29,7 @@ namespace mindspore::lite {
|
||||||
namespace train {
|
namespace train {
|
||||||
class ModelImpl : public FuncGraph {
|
class ModelImpl : public FuncGraph {
|
||||||
public:
|
public:
|
||||||
static std::shared_ptr<ModelImpl> Import(const char *model_buf, size_t size);
|
static std::shared_ptr<ModelImpl> Import(const char *model_buf, size_t size); // { return NULL; };
|
||||||
ModelImpl() = default;
|
ModelImpl() = default;
|
||||||
explicit ModelImpl(const schema::MetaGraph *graph) : meta_graph(graph) {}
|
explicit ModelImpl(const schema::MetaGraph *graph) : meta_graph(graph) {}
|
||||||
~ModelImpl() override = default;
|
~ModelImpl() override = default;
|
||||||
|
@ -37,16 +38,27 @@ class ModelImpl : public FuncGraph {
|
||||||
void FreeMetaGraph();
|
void FreeMetaGraph();
|
||||||
int BuildOps();
|
int BuildOps();
|
||||||
|
|
||||||
|
void AddCNodeInputOutput(std::string name, const std::vector<int> &input, const std::vector<int> &output) {
|
||||||
|
std::vector<int> *tuple = new std::vector<int>[2];
|
||||||
|
tuple[0] = input;
|
||||||
|
tuple[1] = output;
|
||||||
|
connectivity_[name] = tuple;
|
||||||
|
}
|
||||||
|
std::vector<int> *GetCNodeInputOutputIndices(std::string name) { return connectivity_[name]; }
|
||||||
|
void AddAnfNode(int id, AnfNodePtr anf_ptr) { tensors_[id] = anf_ptr; }
|
||||||
|
AnfNodePtr GetAnfNode(int id) { return tensors_[id]; }
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
lite::Primitive *CopyPrimitive(const schema::Primitive *srcPrim);
|
lite::Primitive *CopyPrimitive(const schema::Primitive *srcPrim);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
const schema::MetaGraph *meta_graph = nullptr;
|
const schema::MetaGraph *meta_graph = nullptr;
|
||||||
|
std::map<int, AnfNodePtr> tensors_;
|
||||||
|
std::map<std::string, std::vector<int> *> connectivity_;
|
||||||
std::map<std::string, lite::Primitive *> ops;
|
std::map<std::string, lite::Primitive *> ops;
|
||||||
};
|
};
|
||||||
} // namespace train
|
} // namespace train
|
||||||
using ModelImpl = mindspore::lite::train::ModelImpl;
|
using ModelImpl = mindspore::lite::train::ModelImpl;
|
||||||
} // namespace mindspore::lite
|
} // namespace mindspore::lite
|
||||||
|
|
||||||
#endif // MINDSPORE_LITE_INCLUDE_MODEL_H
|
#endif // MINDSPORE_LITE_SRC_TRAIN_MODEL_IMPL_H_
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,253 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include "src/train/train_anf_session.h"
|
||||||
|
#include "include/context.h"
|
||||||
|
#include "mindspore/ccsrc/runtime/device/kernel_info.h"
|
||||||
|
#include "mindspore/lite/src/train/train_session.h"
|
||||||
|
#include "mindspore/lite/src/kernel_factory.h"
|
||||||
|
#include "mindspore/lite/src/param_value_lite.h"
|
||||||
|
#include "common/utils.h"
|
||||||
|
#include "mindspore/lite/src/ops/ops.h"
|
||||||
|
#include "ir/anf.h"
|
||||||
|
#include "mindspore/lite/src/ir/tensor.h"
|
||||||
|
#include "abstract/abstract_value.h"
|
||||||
|
#include "backend/session/anf_runtime_algorithm.h"
|
||||||
|
#include "src/ir/primitive_value.h"
|
||||||
|
#include "src/train/model_impl.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace session {
|
||||||
|
static std::vector<int> GetAnfNodeOutDims(const AnfNodePtr &anfNodePtr) {
|
||||||
|
auto nodeAbstract = anfNodePtr->abstract();
|
||||||
|
if (nodeAbstract != nullptr) {
|
||||||
|
auto shape = nodeAbstract->GetShapeTrack();
|
||||||
|
if (!shape->isa<abstract::Shape>()) {
|
||||||
|
MS_LOG(EXCEPTION) << "Not a Shape";
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
auto dims = dyn_cast<abstract::Shape>(shape)->shape();
|
||||||
|
return dims;
|
||||||
|
} else {
|
||||||
|
MS_LOG(WARNING) << "abstract is nullptr, return empty dims";
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static schema::Format GetAnfNodeFormat(const AnfNodePtr &anfNodePtr) {
|
||||||
|
auto nodeAbstract = anfNodePtr->abstract();
|
||||||
|
if (nodeAbstract != nullptr) {
|
||||||
|
return schema::Format_NHWC; // XXX TODO -- extract Format from AnfNode
|
||||||
|
} else {
|
||||||
|
MS_LOG(WARNING) << "abstract is nullptr, return schema::Format_NHWC";
|
||||||
|
return schema::Format_NHWC;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static TypeId GetAnfNodeOutTypeId(const AnfNodePtr &anfNodePtr) {
|
||||||
|
auto nodeAbstract = anfNodePtr->abstract();
|
||||||
|
if (nodeAbstract != nullptr) {
|
||||||
|
return TypeId::kNumberTypeFloat32; // XXX TODO nodeAbstract->GetTypeTrack()->generic_type_id();
|
||||||
|
} else {
|
||||||
|
MS_LOG(WARNING) << "abstract is nullptr, return kTypeUnknown";
|
||||||
|
return TypeId::kTypeUnknown;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void TrainANFSession::Init(lite::Context *context) {
|
||||||
|
MS_EXCEPTION_IF_NULL(context);
|
||||||
|
this->context_ = std::make_shared<lite::Context>(context->thread_num_, context->allocator, context->device_ctx_);
|
||||||
|
}
|
||||||
|
|
||||||
|
lite::tensor::Tensor *TrainANFSession::GetTensorForAnfNode(const AnfNodePtr anf_node) {
|
||||||
|
lite::tensor::Tensor *out_tensor = tensors_[anf_node];
|
||||||
|
if (out_tensor == NULL) {
|
||||||
|
out_tensor = new lite::tensor::Tensor(GetAnfNodeOutTypeId(anf_node),
|
||||||
|
GetAnfNodeOutDims(anf_node)); //, schema::NodeType_Parameter);
|
||||||
|
tensors_[anf_node] = out_tensor;
|
||||||
|
}
|
||||||
|
return out_tensor;
|
||||||
|
}
|
||||||
|
|
||||||
|
int TrainANFSession::BuildKernelInputAndOutputFromFuncGraph(const KernelGraphPtr &kernel_graph) {
|
||||||
|
auto return_node = kernel_graph->get_return();
|
||||||
|
auto node_list = TopoSort(return_node);
|
||||||
|
auto model_imp = std::dynamic_pointer_cast<lite::train::ModelImpl>(func_graph_);
|
||||||
|
for (auto &node : node_list) {
|
||||||
|
if (!node->isa<CNode>()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
KernelRelation kernel_relation;
|
||||||
|
auto cnode = node->cast<CNodePtr>();
|
||||||
|
kernel_relation.node_full_name = cnode->fullname_with_scope();
|
||||||
|
kernel_relation.cnode = cnode;
|
||||||
|
std::vector<int> *cnode_io_indices = model_imp->GetCNodeInputOutputIndices(cnode->fullname_with_scope());
|
||||||
|
if (cnode_io_indices == NULL) {
|
||||||
|
MS_LOG(WARNING) << "No IO vectors for " << cnode->fullname_with_scope();
|
||||||
|
} else {
|
||||||
|
for (int i = 0; i < cnode_io_indices[1].size(); i++) {
|
||||||
|
AnfNodePtr anf_node = model_imp->GetAnfNode(cnode_io_indices[1].data()[i]);
|
||||||
|
kernel_relation.output_tensor.push_back(GetTensorForAnfNode(anf_node));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
lite::tensor::Tensor *tensor_ptr = nullptr;
|
||||||
|
for (size_t index = 1; index < cnode->inputs().size(); ++index) {
|
||||||
|
if (cnode->input(index)->isa<CNode>()) {
|
||||||
|
auto input_cnode = cnode->input(index)->cast<CNodePtr>();
|
||||||
|
auto input_kernel_relation = kernel_relation_infos_[input_cnode->fullname_with_scope()];
|
||||||
|
// todo not support multi-outputs kernel sudo as spilt
|
||||||
|
tensor_ptr = input_kernel_relation.output_tensor.front();
|
||||||
|
} else if (cnode->input(index)->isa<Parameter>()) {
|
||||||
|
auto input_parameter = cnode->input(index)->cast<ParameterPtr>();
|
||||||
|
auto para = input_parameter->default_param();
|
||||||
|
auto param_value = std::dynamic_pointer_cast<ParamValueLite>(para);
|
||||||
|
// auto dims = param_value->tensor_shape();
|
||||||
|
// tensor_ptr = new lite::tensor::Tensor(param_value->tensor_type(), dims); // schema::NodeType_ValueNode);
|
||||||
|
tensor_ptr = GetTensorForAnfNode(cnode->input(index));
|
||||||
|
if ((param_value != nullptr) && (param_value->tensor_size() != 0)) {
|
||||||
|
tensor_ptr->SetData(param_value->tensor_addr());
|
||||||
|
}
|
||||||
|
} else if (cnode->input(index)->isa<ValueNode>()) {
|
||||||
|
auto input_valuenode = cnode->input(index)->cast<ValueNodePtr>();
|
||||||
|
// tensor_ptr = new lite::tensor::Tensor(GetAnfNodeOutTypeId(input_valuenode),
|
||||||
|
// GetAnfNodeOutDims(input_valuenode)); // schema::NodeType_Parameter);
|
||||||
|
tensor_ptr = GetTensorForAnfNode(input_valuenode);
|
||||||
|
// todo(yankai)
|
||||||
|
} else {
|
||||||
|
MS_ASSERT(false);
|
||||||
|
}
|
||||||
|
kernel_relation.input_tensor.push_back(tensor_ptr);
|
||||||
|
}
|
||||||
|
kernel_relation_infos_[cnode->fullname_with_scope()] = kernel_relation;
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
GraphId TrainANFSession::graph_sum_ = 0;
|
||||||
|
|
||||||
|
KernelGraphPtr TrainANFSession::NewKernelGraph() {
|
||||||
|
auto graph = std::make_shared<KernelGraph>();
|
||||||
|
graph->set_graph_id(graph_sum_);
|
||||||
|
graphs_[graph_sum_++] = graph;
|
||||||
|
return graph;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<KernelGraph> TrainANFSession::ConstructKernelGraph(const FuncGraphPtr &func_graph) {
|
||||||
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
|
auto graph = NewKernelGraph();
|
||||||
|
graph->set_return(func_graph->get_return());
|
||||||
|
auto node_list = TopoSort(func_graph->get_return());
|
||||||
|
std::vector<CNodePtr> cnode_order;
|
||||||
|
for (const auto &node : node_list) {
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
if (node->isa<CNode>()) {
|
||||||
|
auto cn_node = node->cast<CNodePtr>();
|
||||||
|
cnode_order.push_back(cn_node);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
graph->set_execution_order(cnode_order);
|
||||||
|
return graph;
|
||||||
|
}
|
||||||
|
GraphId TrainANFSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
|
||||||
|
auto graph = ConstructKernelGraph(func_graph);
|
||||||
|
func_graph_ = func_graph;
|
||||||
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
|
MS_LOG(INFO) << "Set kernel info";
|
||||||
|
SetKernelInfo(graph.get());
|
||||||
|
|
||||||
|
(void)BuildKernelInputAndOutputFromFuncGraph(graph);
|
||||||
|
MS_LOG(INFO) << "Build kernel";
|
||||||
|
auto ret = BuildKernel(graph.get());
|
||||||
|
if (0 != ret) {
|
||||||
|
MS_LOG(EXCEPTION) << "BuildKernel failed";
|
||||||
|
}
|
||||||
|
|
||||||
|
// return the graph id to backend
|
||||||
|
auto graph_id = graph->graph_id();
|
||||||
|
graphs_[graph_id] = graph;
|
||||||
|
MS_LOG(INFO) << "Compile graph " << graph_id << " success";
|
||||||
|
return graph_id;
|
||||||
|
}
|
||||||
|
|
||||||
|
void TrainANFSession::RunGraph(const GraphId &graph_id, const std::vector<lite::tensor::Tensor *> &inputs,
|
||||||
|
std::vector<lite::tensor::Tensor *> *outputs) {
|
||||||
|
auto &kernel_graph = graphs_[graph_id];
|
||||||
|
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||||
|
MS_LOG(INFO) << "Bind input output address";
|
||||||
|
// runtime_.BindInputOutput(kernel_graph.get(), inputs, outputs); -- will be bound in Run
|
||||||
|
// auto execution_order = kernel_graph->execution_order();
|
||||||
|
// Todo : hangangqiang
|
||||||
|
// Reorder(&execution_order);
|
||||||
|
// kernel_graph->set_execution_order(execution_order);
|
||||||
|
MS_LOG(INFO) << "Run graph start";
|
||||||
|
auto ret = runtime_.Run(kernel_graph.get(), (std::vector<lite::tensor::Tensor *> &)inputs, *outputs);
|
||||||
|
if (!ret) {
|
||||||
|
MS_LOG(EXCEPTION) << "Run graph failed";
|
||||||
|
}
|
||||||
|
MS_LOG(INFO) << "Run graph end";
|
||||||
|
}
|
||||||
|
|
||||||
|
void TrainANFSession::SetKernelInfo(const KernelGraph *kernel_graph) {
|
||||||
|
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||||
|
auto &kernel_nodes = kernel_graph->execution_order();
|
||||||
|
for (const auto &kernel_node : kernel_nodes) {
|
||||||
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||||
|
auto kernel_info = std::make_shared<device::KernelInfo>();
|
||||||
|
kernel_node->set_kernel_info(kernel_info);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int TrainANFSession::BuildKernel(const KernelGraph *kernel_graph) {
|
||||||
|
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||||
|
for (auto iter = kernel_relation_infos_.begin(); iter != kernel_relation_infos_.end(); ++iter) {
|
||||||
|
std::string kernel_name = iter->first;
|
||||||
|
KernelRelation anf_register = iter->second;
|
||||||
|
MS_EXCEPTION_IF_NULL(anf_register.cnode);
|
||||||
|
if (IsPrimitiveCNode(anf_register.cnode, prim::kPrimReturn)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto value_node_prim = anf_register.cnode->input(0);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_node_prim);
|
||||||
|
auto prim = GetValueNode<std::shared_ptr<lite::PrimitiveValue>>(value_node_prim);
|
||||||
|
MS_EXCEPTION_IF_NULL(prim);
|
||||||
|
auto node_primitive = (lite::Primitive *)(prim->GetPrimitive());
|
||||||
|
MS_EXCEPTION_IF_NULL(node_primitive);
|
||||||
|
auto ret = node_primitive->InferShape(anf_register.input_tensor, anf_register.output_tensor);
|
||||||
|
if (0 != ret) {
|
||||||
|
MS_LOG(ERROR) << "InferShape failed, node : " << kernel_name;
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, node_primitive->Type()};
|
||||||
|
|
||||||
|
auto *kernel = lite::KernelFactory::GetInstance()->GetKernel(anf_register.input_tensor, anf_register.output_tensor,
|
||||||
|
node_primitive, context_.get(), desc);
|
||||||
|
if (nullptr == kernel) {
|
||||||
|
MS_LOG(ERROR) << "Create kernel return nullptr, name: " << kernel_name;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
std::shared_ptr<kernel::LiteKernel> kernel_mod(kernel);
|
||||||
|
kernel_mod->set_name(anf_register.cnode->fullname_with_scope());
|
||||||
|
|
||||||
|
// kernel->train();
|
||||||
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(anf_register.cnode->kernel_info());
|
||||||
|
MS_EXCEPTION_IF_NULL(kernel_info);
|
||||||
|
kernel_info->set_kernel_mod(kernel_mod); // XXX TODO -- only derived class KernelInfo has this method
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
} // namespace session
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,76 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_LITE_SRC_TRAIN_TRAIN_SESSION_H_
|
||||||
|
#define MINDSPORE_LITE_SRC_TRAIN_TRAIN_SESSION_H_
|
||||||
|
#include <map>
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include "include/context.h"
|
||||||
|
#include "backend/session/session_basic.h"
|
||||||
|
#include "backend/session/kernel_graph.h"
|
||||||
|
#include "mindspore/lite/src/train/lite_kernel_runtime.h"
|
||||||
|
// #include "backend/session/session_factory.h"
|
||||||
|
namespace mindspore {
|
||||||
|
namespace lite::tensor {
|
||||||
|
class Tensor;
|
||||||
|
}
|
||||||
|
namespace session {
|
||||||
|
struct KernelRelation {
|
||||||
|
std::string node_full_name;
|
||||||
|
std::vector<lite::tensor::Tensor *> input_tensor;
|
||||||
|
std::vector<lite::tensor::Tensor *> output_tensor;
|
||||||
|
CNodePtr cnode;
|
||||||
|
};
|
||||||
|
|
||||||
|
class TrainANFSession {
|
||||||
|
public:
|
||||||
|
explicit TrainANFSession(lite::Context *context) { Init(context); }
|
||||||
|
~TrainANFSession() = default;
|
||||||
|
|
||||||
|
GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph);
|
||||||
|
|
||||||
|
void RunGraph(const GraphId &graph_id, const std::vector<lite::tensor::Tensor *> &inputs,
|
||||||
|
std::vector<lite::tensor::Tensor *> *outputs);
|
||||||
|
|
||||||
|
// void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override
|
||||||
|
// {};
|
||||||
|
protected:
|
||||||
|
void Init(lite::Context *context);
|
||||||
|
std::shared_ptr<lite::Context> context_ = nullptr;
|
||||||
|
std::unordered_map<GraphId, std::shared_ptr<KernelGraph>> graphs_;
|
||||||
|
static GraphId graph_sum_;
|
||||||
|
KernelGraphPtr NewKernelGraph();
|
||||||
|
|
||||||
|
private:
|
||||||
|
// GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override;
|
||||||
|
// GraphId CompileGraph(const char *model_buf, size_t size);
|
||||||
|
std::shared_ptr<KernelGraph> ConstructKernelGraph(const FuncGraphPtr &func_graph);
|
||||||
|
int BuildKernelInputAndOutputFromFuncGraph(const KernelGraphPtr &kernel_graph);
|
||||||
|
|
||||||
|
lite::tensor::Tensor *GetTensorForAnfNode(const AnfNodePtr anf_node);
|
||||||
|
|
||||||
|
void SetKernelInfo(const KernelGraph *kernel_graph);
|
||||||
|
int BuildKernel(const KernelGraph *kernel_graph);
|
||||||
|
lite::LiteInferKernelRuntime runtime_;
|
||||||
|
std::map<std::string, KernelRelation> kernel_relation_infos_;
|
||||||
|
FuncGraphPtr func_graph_ = NULL;
|
||||||
|
std::map<AnfNodePtr, lite::tensor::Tensor *> tensors_;
|
||||||
|
};
|
||||||
|
} // namespace session
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_LITE_SRC_TRAIN_TRAIN_SESSION_H_
|
|
@ -15,6 +15,8 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
#include "include/context.h"
|
||||||
|
#include "mindspore/ccsrc/runtime/device/kernel_info.h"
|
||||||
#include "mindspore/lite/src/train/train_session.h"
|
#include "mindspore/lite/src/train/train_session.h"
|
||||||
#include "mindspore/lite/src/kernel_factory.h"
|
#include "mindspore/lite/src/kernel_factory.h"
|
||||||
#include "mindspore/lite/src/param_value_lite.h"
|
#include "mindspore/lite/src/param_value_lite.h"
|
||||||
|
@ -25,6 +27,7 @@
|
||||||
#include "abstract/abstract_value.h"
|
#include "abstract/abstract_value.h"
|
||||||
#include "backend/session/anf_runtime_algorithm.h"
|
#include "backend/session/anf_runtime_algorithm.h"
|
||||||
#include "src/ir/primitive_value.h"
|
#include "src/ir/primitive_value.h"
|
||||||
|
#include "src/train/model_impl.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace session {
|
namespace session {
|
||||||
|
@ -57,16 +60,32 @@ static schema::Format GetAnfNodeFormat(const AnfNodePtr &anfNodePtr) {
|
||||||
static TypeId GetAnfNodeOutTypeId(const AnfNodePtr &anfNodePtr) {
|
static TypeId GetAnfNodeOutTypeId(const AnfNodePtr &anfNodePtr) {
|
||||||
auto nodeAbstract = anfNodePtr->abstract();
|
auto nodeAbstract = anfNodePtr->abstract();
|
||||||
if (nodeAbstract != nullptr) {
|
if (nodeAbstract != nullptr) {
|
||||||
return nodeAbstract->GetTypeTrack()->type_id();
|
return TypeId::kNumberTypeFloat32; // XXX TODO nodeAbstract->GetTypeTrack()->generic_type_id();
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(WARNING) << "abstract is nullptr, return kTypeUnknown";
|
MS_LOG(WARNING) << "abstract is nullptr, return kTypeUnknown";
|
||||||
return TypeId::kTypeUnknown;
|
return TypeId::kTypeUnknown;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void TrainSession::Init(lite::Context *context) {
|
||||||
|
MS_EXCEPTION_IF_NULL(context);
|
||||||
|
this->context_ = std::make_shared<lite::Context>(context->thread_num_, context->allocator, context->device_ctx_);
|
||||||
|
}
|
||||||
|
|
||||||
|
lite::tensor::Tensor *TrainSession::GetTensorForAnfNode(const AnfNodePtr anf_node) {
|
||||||
|
lite::tensor::Tensor *out_tensor = tensors_[anf_node];
|
||||||
|
if (out_tensor == NULL) {
|
||||||
|
out_tensor = new lite::tensor::Tensor(GetAnfNodeOutTypeId(anf_node),
|
||||||
|
GetAnfNodeOutDims(anf_node)); //, schema::NodeType_Parameter);
|
||||||
|
tensors_[anf_node] = out_tensor;
|
||||||
|
}
|
||||||
|
return out_tensor;
|
||||||
|
}
|
||||||
|
|
||||||
int TrainSession::BuildKernelInputAndOutputFromFuncGraph(const KernelGraphPtr &kernel_graph) {
|
int TrainSession::BuildKernelInputAndOutputFromFuncGraph(const KernelGraphPtr &kernel_graph) {
|
||||||
auto return_node = kernel_graph->get_return();
|
auto return_node = kernel_graph->get_return();
|
||||||
auto node_list = TopoSort(return_node);
|
auto node_list = TopoSort(return_node);
|
||||||
|
auto model_imp = std::dynamic_pointer_cast<lite::train::ModelImpl>(func_graph_);
|
||||||
for (auto &node : node_list) {
|
for (auto &node : node_list) {
|
||||||
if (!node->isa<CNode>()) {
|
if (!node->isa<CNode>()) {
|
||||||
continue;
|
continue;
|
||||||
|
@ -75,11 +94,16 @@ int TrainSession::BuildKernelInputAndOutputFromFuncGraph(const KernelGraphPtr &k
|
||||||
auto cnode = node->cast<CNodePtr>();
|
auto cnode = node->cast<CNodePtr>();
|
||||||
kernel_relation.node_full_name = cnode->fullname_with_scope();
|
kernel_relation.node_full_name = cnode->fullname_with_scope();
|
||||||
kernel_relation.cnode = cnode;
|
kernel_relation.cnode = cnode;
|
||||||
auto *out_tensor =
|
std::vector<int> *cnode_io_indices = model_imp->GetCNodeInputOutputIndices(cnode->fullname_with_scope());
|
||||||
new tensor::Tensor(GetAnfNodeOutTypeId(cnode), GetAnfNodeOutDims(cnode), GetAnfNodeFormat(cnode),
|
if (cnode_io_indices == NULL) {
|
||||||
schema::NodeType_Parameter);
|
MS_LOG(WARNING) << "No IO vectors for " << cnode->fullname_with_scope();
|
||||||
kernel_relation.output_tensor.push_back(out_tensor);
|
} else {
|
||||||
tensor::Tensor *tensor_ptr = nullptr;
|
for (int i = 0; i < cnode_io_indices[1].size(); i++) {
|
||||||
|
AnfNodePtr anf_node = model_imp->GetAnfNode(cnode_io_indices[1].data()[i]);
|
||||||
|
kernel_relation.output_tensor.push_back(GetTensorForAnfNode(anf_node));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
lite::tensor::Tensor *tensor_ptr = nullptr;
|
||||||
for (size_t index = 1; index < cnode->inputs().size(); ++index) {
|
for (size_t index = 1; index < cnode->inputs().size(); ++index) {
|
||||||
if (cnode->input(index)->isa<CNode>()) {
|
if (cnode->input(index)->isa<CNode>()) {
|
||||||
auto input_cnode = cnode->input(index)->cast<CNodePtr>();
|
auto input_cnode = cnode->input(index)->cast<CNodePtr>();
|
||||||
|
@ -90,17 +114,17 @@ int TrainSession::BuildKernelInputAndOutputFromFuncGraph(const KernelGraphPtr &k
|
||||||
auto input_parameter = cnode->input(index)->cast<ParameterPtr>();
|
auto input_parameter = cnode->input(index)->cast<ParameterPtr>();
|
||||||
auto para = input_parameter->default_param();
|
auto para = input_parameter->default_param();
|
||||||
auto param_value = std::dynamic_pointer_cast<ParamValueLite>(para);
|
auto param_value = std::dynamic_pointer_cast<ParamValueLite>(para);
|
||||||
auto dims = param_value->tensor_shape();
|
// auto dims = param_value->tensor_shape();
|
||||||
tensor_ptr = new tensor::Tensor(param_value->tensor_type(), dims, schema::Format_NHWC,
|
// tensor_ptr = new lite::tensor::Tensor(param_value->tensor_type(), dims); // schema::NodeType_ValueNode);
|
||||||
schema::NodeType_ValueNode); // XXX TODO -- extract Format from AnfNode
|
tensor_ptr = GetTensorForAnfNode(cnode->input(index));
|
||||||
if (param_value->tensor_size() != 0) {
|
if ((param_value != nullptr) && (param_value->tensor_size() != 0)) {
|
||||||
tensor_ptr->SetData(param_value->tensor_addr());
|
tensor_ptr->SetData(param_value->tensor_addr());
|
||||||
}
|
}
|
||||||
} else if (cnode->input(index)->isa<ValueNode>()) {
|
} else if (cnode->input(index)->isa<ValueNode>()) {
|
||||||
auto input_valuenode = cnode->input(index)->cast<ValueNodePtr>();
|
auto input_valuenode = cnode->input(index)->cast<ValueNodePtr>();
|
||||||
tensor_ptr = new tensor::Tensor(GetAnfNodeOutTypeId(input_valuenode), GetAnfNodeOutDims(input_valuenode),
|
// tensor_ptr = new lite::tensor::Tensor(GetAnfNodeOutTypeId(input_valuenode),
|
||||||
schema::Format_NHWC,
|
// GetAnfNodeOutDims(input_valuenode)); // schema::NodeType_Parameter);
|
||||||
schema::NodeType_Parameter); // XXX TODO -- extract Format from AnfNode
|
tensor_ptr = GetTensorForAnfNode(input_valuenode);
|
||||||
// todo(yankai)
|
// todo(yankai)
|
||||||
} else {
|
} else {
|
||||||
MS_ASSERT(false);
|
MS_ASSERT(false);
|
||||||
|
@ -111,7 +135,7 @@ int TrainSession::BuildKernelInputAndOutputFromFuncGraph(const KernelGraphPtr &k
|
||||||
}
|
}
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
#if 0
|
||||||
GraphId TrainSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
|
GraphId TrainSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
|
||||||
auto graph_id = graph_sum_;
|
auto graph_id = graph_sum_;
|
||||||
auto graph = SessionBasic::ConstructKernelGraph(lst, outputs);
|
auto graph = SessionBasic::ConstructKernelGraph(lst, outputs);
|
||||||
|
@ -124,6 +148,17 @@ GraphId TrainSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrLi
|
||||||
}
|
}
|
||||||
|
|
||||||
GraphId TrainSession::CompileGraph(const char *model_buf, size_t size) { return 0; }
|
GraphId TrainSession::CompileGraph(const char *model_buf, size_t size) { return 0; }
|
||||||
|
#else
|
||||||
|
GraphId TrainSession::graph_sum_ = 0;
|
||||||
|
|
||||||
|
KernelGraphPtr TrainSession::NewKernelGraph() {
|
||||||
|
auto graph = std::make_shared<KernelGraph>();
|
||||||
|
graph->set_graph_id(graph_sum_);
|
||||||
|
graphs_[graph_sum_++] = graph;
|
||||||
|
return graph;
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
std::shared_ptr<KernelGraph> TrainSession::ConstructKernelGraph(const FuncGraphPtr &func_graph) {
|
std::shared_ptr<KernelGraph> TrainSession::ConstructKernelGraph(const FuncGraphPtr &func_graph) {
|
||||||
MS_EXCEPTION_IF_NULL(func_graph);
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
|
@ -141,14 +176,14 @@ std::shared_ptr<KernelGraph> TrainSession::ConstructKernelGraph(const FuncGraphP
|
||||||
graph->set_execution_order(cnode_order);
|
graph->set_execution_order(cnode_order);
|
||||||
return graph;
|
return graph;
|
||||||
}
|
}
|
||||||
|
|
||||||
GraphId TrainSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
|
GraphId TrainSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
|
||||||
auto graph = ConstructKernelGraph(func_graph);
|
auto graph = ConstructKernelGraph(func_graph);
|
||||||
|
func_graph_ = func_graph;
|
||||||
MS_EXCEPTION_IF_NULL(graph);
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
MS_LOG(INFO) << "Set kernel info";
|
MS_LOG(INFO) << "Set kernel info";
|
||||||
SetKernelInfo(graph.get());
|
SetKernelInfo(graph.get());
|
||||||
|
|
||||||
(void) BuildKernelInputAndOutputFromFuncGraph(graph);
|
(void)BuildKernelInputAndOutputFromFuncGraph(graph);
|
||||||
MS_LOG(INFO) << "Build kernel";
|
MS_LOG(INFO) << "Build kernel";
|
||||||
auto ret = BuildKernel(graph.get());
|
auto ret = BuildKernel(graph.get());
|
||||||
if (0 != ret) {
|
if (0 != ret) {
|
||||||
|
@ -162,18 +197,18 @@ GraphId TrainSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
|
||||||
return graph_id;
|
return graph_id;
|
||||||
}
|
}
|
||||||
|
|
||||||
void TrainSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::Tensor *> &inputs,
|
void TrainSession::RunGraph(const GraphId &graph_id, const std::vector<lite::tensor::Tensor *> &inputs,
|
||||||
std::vector<tensor::Tensor *> &outputs) {
|
std::vector<lite::tensor::Tensor *> *outputs) {
|
||||||
auto &kernel_graph = graphs_[graph_id];
|
auto &kernel_graph = graphs_[graph_id];
|
||||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||||
MS_LOG(INFO) << "Bind input output address";
|
MS_LOG(INFO) << "Bind input output address";
|
||||||
runtime_.BindInputOutput(kernel_graph.get(), inputs, outputs);
|
// runtime_.BindInputOutput(kernel_graph.get(), inputs, outputs); -- will be bound in Run
|
||||||
// auto execution_order = kernel_graph->execution_order();
|
// auto execution_order = kernel_graph->execution_order();
|
||||||
// Todo : hangangqiang
|
// Todo : hangangqiang
|
||||||
// Reorder(&execution_order);
|
// Reorder(&execution_order);
|
||||||
// kernel_graph->set_execution_order(execution_order);
|
// kernel_graph->set_execution_order(execution_order);
|
||||||
MS_LOG(INFO) << "Run graph start";
|
MS_LOG(INFO) << "Run graph start";
|
||||||
auto ret = runtime_.Run(kernel_graph.get(), (std::vector<tensor::Tensor *> &) inputs, outputs);
|
auto ret = runtime_.Run(kernel_graph.get(), (std::vector<lite::tensor::Tensor *> &)inputs, outputs);
|
||||||
if (!ret) {
|
if (!ret) {
|
||||||
MS_LOG(EXCEPTION) << "Run graph failed";
|
MS_LOG(EXCEPTION) << "Run graph failed";
|
||||||
}
|
}
|
||||||
|
@ -199,34 +234,34 @@ int TrainSession::BuildKernel(const KernelGraph *kernel_graph) {
|
||||||
if (IsPrimitiveCNode(anf_register.cnode, prim::kPrimReturn)) {
|
if (IsPrimitiveCNode(anf_register.cnode, prim::kPrimReturn)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
lite::Context context;
|
|
||||||
context.deviceCtx.type = lite::DeviceType::DT_CPU;
|
|
||||||
auto value_node_prim = anf_register.cnode->input(0);
|
auto value_node_prim = anf_register.cnode->input(0);
|
||||||
MS_EXCEPTION_IF_NULL(value_node_prim);
|
MS_EXCEPTION_IF_NULL(value_node_prim);
|
||||||
auto prim = GetValueNode<std::shared_ptr<lite::PrimitiveValue>>(value_node_prim);
|
auto prim = GetValueNode<std::shared_ptr<lite::PrimitiveValue>>(value_node_prim);
|
||||||
MS_EXCEPTION_IF_NULL(prim);
|
MS_EXCEPTION_IF_NULL(prim);
|
||||||
auto node_primitive = (lite::Primitive *) (prim->GetPrimitive());
|
auto node_primitive = (lite::Primitive *)(prim->GetPrimitive());
|
||||||
MS_EXCEPTION_IF_NULL(node_primitive);
|
MS_EXCEPTION_IF_NULL(node_primitive);
|
||||||
auto ret = node_primitive->InferShape(anf_register.input_tensor, anf_register.output_tensor);
|
auto ret = node_primitive->InferShape(anf_register.input_tensor, anf_register.output_tensor);
|
||||||
if (0 != ret) {
|
if (0 != ret) {
|
||||||
MS_LOG(ERROR) << "InferShape failed, node : " << kernel_name;
|
MS_LOG(ERROR) << "InferShape failed, node : " << kernel_name;
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, node_primitive->Type()};
|
kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, node_primitive->Type()};
|
||||||
|
|
||||||
auto *kernel = lite::KernelFactory::GetInstance()->GetKernel(anf_register.input_tensor, anf_register.output_tensor,
|
auto *kernel = lite::KernelFactory::GetInstance()->GetKernel(anf_register.input_tensor, anf_register.output_tensor,
|
||||||
node_primitive, &context, desc);
|
node_primitive, context_.get(), desc);
|
||||||
if (nullptr == kernel) {
|
if (nullptr == kernel) {
|
||||||
MS_LOG(ERROR) << "Create kernel return nullptr, name: " << kernel_name;
|
MS_LOG(ERROR) << "Create kernel return nullptr, name: " << kernel_name;
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
kernel->train();
|
|
||||||
auto *kernel_info = anf_register.cnode->kernel_info();
|
|
||||||
std::shared_ptr<kernel::LiteKernel> kernel_mod(kernel);
|
std::shared_ptr<kernel::LiteKernel> kernel_mod(kernel);
|
||||||
kernel_info->set_kernel_mod(kernel_mod);
|
kernel_mod->set_name(anf_register.cnode->fullname_with_scope());
|
||||||
|
|
||||||
|
// kernel->train();
|
||||||
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(anf_register.cnode->kernel_info());
|
||||||
|
MS_EXCEPTION_IF_NULL(kernel_info);
|
||||||
|
kernel_info->set_kernel_mod(kernel_mod); // XXX TODO -- only derived class KernelInfo has this method
|
||||||
}
|
}
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
} // namespace session
|
} // namespace session
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -19,10 +19,12 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include "include/context.h"
|
||||||
#include "backend/session/session_basic.h"
|
#include "backend/session/session_basic.h"
|
||||||
#include "backend/session/kernel_graph.h"
|
#include "backend/session/kernel_graph.h"
|
||||||
#include "mindspore/lite/src/train/lite_kernel_runtime.h"
|
#include "mindspore/lite/src/train/lite_kernel_runtime.h"
|
||||||
#include "backend/session/session_factory.h"
|
// #include "backend/session/session_factory.h"
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace lite::tensor {
|
namespace lite::tensor {
|
||||||
class Tensor;
|
class Tensor;
|
||||||
|
@ -30,36 +32,45 @@ class Tensor;
|
||||||
namespace session {
|
namespace session {
|
||||||
struct KernelRelation {
|
struct KernelRelation {
|
||||||
std::string node_full_name;
|
std::string node_full_name;
|
||||||
std::vector<tensor::Tensor *> input_tensor;
|
std::vector<lite::tensor::Tensor *> input_tensor;
|
||||||
std::vector<tensor::Tensor *> output_tensor;
|
std::vector<lite::tensor::Tensor *> output_tensor;
|
||||||
CNodePtr cnode;
|
CNodePtr cnode;
|
||||||
};
|
};
|
||||||
|
|
||||||
class TrainSession : public SessionBasic {
|
class TrainSession {
|
||||||
public:
|
public:
|
||||||
TrainSession() : SessionBasic() {}
|
explicit TrainSession(lite::Context * context) { Init(context); }
|
||||||
~TrainSession() override = default;
|
~TrainSession() = default;
|
||||||
void Init(uint32_t device_id) override {
|
|
||||||
SessionBasic::Init(device_id);
|
|
||||||
context_ = std::make_shared<Context>(kCPUDevice, device_id);
|
|
||||||
}
|
|
||||||
|
|
||||||
GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph) override;
|
GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph);
|
||||||
|
|
||||||
void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override;
|
void RunGraph(const GraphId &graph_id, const std::vector<lite::tensor::Tensor *> &inputs,
|
||||||
|
std::vector<lite::tensor::Tensor *> *outputs);
|
||||||
|
|
||||||
|
// void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override
|
||||||
|
// {};
|
||||||
|
protected:
|
||||||
|
void Init(lite::Context *context);
|
||||||
|
std::shared_ptr<lite::Context> context_ = nullptr;
|
||||||
|
std::unordered_map<GraphId, std::shared_ptr<KernelGraph>> graphs_;
|
||||||
|
static GraphId graph_sum_;
|
||||||
|
KernelGraphPtr NewKernelGraph();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override;
|
// GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override;
|
||||||
GraphId CompileGraph(const char *model_buf, size_t size);
|
// GraphId CompileGraph(const char *model_buf, size_t size);
|
||||||
std::shared_ptr<KernelGraph> ConstructKernelGraph(const FuncGraphPtr &func_graph);
|
std::shared_ptr<KernelGraph> ConstructKernelGraph(const FuncGraphPtr &func_graph);
|
||||||
int BuildKernelInputAndOutputFromFuncGraph(const KernelGraphPtr &kernel_graph);
|
int BuildKernelInputAndOutputFromFuncGraph(const KernelGraphPtr &kernel_graph);
|
||||||
|
|
||||||
|
lite::tensor::Tensor *GetTensorForAnfNode(const AnfNodePtr anf_node);
|
||||||
|
|
||||||
void SetKernelInfo(const KernelGraph *kernel_graph);
|
void SetKernelInfo(const KernelGraph *kernel_graph);
|
||||||
int BuildKernel(const KernelGraph *kernel_graph);
|
int BuildKernel(const KernelGraph *kernel_graph);
|
||||||
lite::LiteInferKernelRuntime runtime_;
|
lite::LiteInferKernelRuntime runtime_;
|
||||||
std::map<std::string, KernelRelation> kernel_relation_infos_;
|
std::map<std::string, KernelRelation> kernel_relation_infos_;
|
||||||
|
FuncGraphPtr func_graph_ = NULL;
|
||||||
|
std::map<AnfNodePtr, lite::tensor::Tensor *> tensors_;
|
||||||
};
|
};
|
||||||
MS_REG_SESSION(kCPUDevice, TrainSession);
|
|
||||||
} // namespace session
|
} // namespace session
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_LITE_SRC_TRAIN_TRAIN_SESSION_H_
|
#endif // MINDSPORE_LITE_SRC_TRAIN_TRAIN_SESSION_H_
|
||||||
|
|
||||||
|
|
|
@ -87,6 +87,16 @@ file(GLOB KERNEL_OP_SRC
|
||||||
${LITE_DIR}/src/runtime/kernel/arm/nnacl/int8/*.cc
|
${LITE_DIR}/src/runtime/kernel/arm/nnacl/int8/*.cc
|
||||||
${LITE_DIR}/src/runtime/kernel/arm/nnacl/quantization/*.cc
|
${LITE_DIR}/src/runtime/kernel/arm/nnacl/quantization/*.cc
|
||||||
)
|
)
|
||||||
|
|
||||||
|
file(GLOB KERNEL_OP_TRAIN_SRC
|
||||||
|
${LITE_DIR}/src/runtime/kernel/arm/nnacl/fp32_grad/*.cc
|
||||||
|
${LITE_DIR}/src/runtime/kernel/arm/fp32_grad/*.cc
|
||||||
|
)
|
||||||
|
|
||||||
|
if (SUPPORT_TRAIN)
|
||||||
|
list(APPEND KERNEL_OP_SRC ${KERNEL_OP_TRAIN_SRC})
|
||||||
|
endif()
|
||||||
|
|
||||||
if (PLATFORM_ARM64)
|
if (PLATFORM_ARM64)
|
||||||
# assembly
|
# assembly
|
||||||
file(GLOB TEST_ASSEMBLY_SRC ${LITE_DIR}/src/runtime/kernel/arm/nnacl/assembly/arm64/*.s
|
file(GLOB TEST_ASSEMBLY_SRC ${LITE_DIR}/src/runtime/kernel/arm/nnacl/assembly/arm64/*.s
|
||||||
|
@ -245,12 +255,13 @@ if (SUPPORT_TRAIN)
|
||||||
# ${SRC_DIR}/device/kernel_info.cc
|
# ${SRC_DIR}/device/kernel_info.cc
|
||||||
# ${SRC_DIR}/device/kernel_runtime.cc
|
# ${SRC_DIR}/device/kernel_runtime.cc
|
||||||
# ${SRC_DIR}/device/lite/kernel_runtime_extends.cc
|
# ${SRC_DIR}/device/lite/kernel_runtime_extends.cc
|
||||||
${LITE_DIR}/src/common/anf_importer/anf_importer.cc
|
# ${LITE_DIR}/src/common/anf_importer/anf_importer.cc
|
||||||
${LITE_DIR}/src/common/anf_importer/import_from_meta_graph.cc
|
# ${LITE_DIR}/src/common/anf_importer/import_from_meta_graph.cc
|
||||||
${LITE_DIR}/src/ir/primitive_value.cc
|
# ${LITE_DIR}/src/ir/primitive_value.cc
|
||||||
${LITE_DIR}/src/train/lite_kernel_runtime.cc
|
# ${LITE_DIR}/src/train/lite_kernel_runtime.cc
|
||||||
${LITE_DIR}/src/train/train_session.cc
|
# ${LITE_DIR}/src/train/train_session.cc
|
||||||
${LITE_DIR}/src/train/model_impl.cc
|
# ${LITE_DIR}/src/train/model_impl.cc
|
||||||
|
${LITE_DIR}/src/lite_session.cc # temporary
|
||||||
)
|
)
|
||||||
else()
|
else()
|
||||||
set(TEST_LITE_SRC
|
set(TEST_LITE_SRC
|
||||||
|
@ -265,6 +276,10 @@ file(GLOB_RECURSE TEST_CASE_KERNEL_SRC
|
||||||
${TEST_DIR}/ut/src/runtime/kernel/arm/int8/*.cc
|
${TEST_DIR}/ut/src/runtime/kernel/arm/int8/*.cc
|
||||||
)
|
)
|
||||||
|
|
||||||
|
file(GLOB_RECURSE TEST_CASE_KERNEL_TRAIN_SRC
|
||||||
|
${TEST_DIR}/ut/src/runtime/kernel/arm/fp32_grad/*.cc
|
||||||
|
)
|
||||||
|
|
||||||
set(TEST_SRC
|
set(TEST_SRC
|
||||||
${TEST_LITE_SRC}
|
${TEST_LITE_SRC}
|
||||||
${TEST_MINDDATA_SRC}
|
${TEST_MINDDATA_SRC}
|
||||||
|
@ -278,7 +293,9 @@ set(TEST_SRC
|
||||||
if (SUPPORT_TRAIN)
|
if (SUPPORT_TRAIN)
|
||||||
set(TEST_SRC
|
set(TEST_SRC
|
||||||
${TEST_SRC}
|
${TEST_SRC}
|
||||||
${TEST_DIR}/ut/src/train_test.cc
|
${TEST_CASE_KERNEL_TRAIN_SRC}
|
||||||
|
# ${TEST_DIR}/ut/src/train_test.cc
|
||||||
|
${TEST_DIR}/ut/src/infer_test.cc # temporary
|
||||||
)
|
)
|
||||||
else()
|
else()
|
||||||
set(TEST_SRC
|
set(TEST_SRC
|
||||||
|
|
|
@ -13,7 +13,6 @@
|
||||||
* See the License for the specific language governing permissions and
|
* See the License for the specific language governing permissions and
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
@ -25,7 +24,7 @@
|
||||||
#include "mindspore/lite/src/kernel_registry.h"
|
#include "mindspore/lite/src/kernel_registry.h"
|
||||||
#include "mindspore/lite/src/ir/tensor.h"
|
#include "mindspore/lite/src/ir/tensor.h"
|
||||||
#include "mindspore/lite/src/lite_kernel.h"
|
#include "mindspore/lite/src/lite_kernel.h"
|
||||||
#include "mindspore/lite/src/runtime/kernel/arm/fp32/activation_grad.h"
|
#include "mindspore/lite/src/runtime/kernel/arm/fp32_grad/activation_grad.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
class TestActGradFp32 : public mindspore::Common {
|
class TestActGradFp32 : public mindspore::Common {
|
|
@ -21,7 +21,7 @@
|
||||||
#include "src/common/file_utils.h"
|
#include "src/common/file_utils.h"
|
||||||
#include "src/common/file_utils_ext.h"
|
#include "src/common/file_utils_ext.h"
|
||||||
#include "mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/reduce.h"
|
#include "mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/reduce.h"
|
||||||
#include "mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_grad.h"
|
#include "mindspore/lite/src/runtime/kernel/arm/fp32_grad/arithmetic_grad.h"
|
||||||
#include "mindspore/lite/src/kernel_registry.h"
|
#include "mindspore/lite/src/kernel_registry.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
|
@ -18,7 +18,7 @@
|
||||||
#include "utils/log_adapter.h"
|
#include "utils/log_adapter.h"
|
||||||
#include "common/common_test.h"
|
#include "common/common_test.h"
|
||||||
#include "src/common/file_utils.h"
|
#include "src/common/file_utils.h"
|
||||||
#include "mindspore/lite/src/runtime/kernel/arm/fp32/bias_grad.h"
|
#include "mindspore/lite/src/runtime/kernel/arm/fp32_grad/bias_grad.h"
|
||||||
#include "mindspore/lite/src/kernel_registry.h"
|
#include "mindspore/lite/src/kernel_registry.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
|
@ -21,8 +21,8 @@
|
||||||
#include "common/common_test.h"
|
#include "common/common_test.h"
|
||||||
#include "src/common/file_utils.h"
|
#include "src/common/file_utils.h"
|
||||||
#include "src/common/file_utils_ext.h"
|
#include "src/common/file_utils_ext.h"
|
||||||
#include "mindspore/lite/src/runtime/kernel/arm/fp32/convolution_grad_filter.h"
|
#include "mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_filter.h"
|
||||||
#include "mindspore/lite/src/runtime/kernel/arm/fp32/convolution_grad_input.h"
|
#include "mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_input.h"
|
||||||
#include "mindspore/lite/src/runtime/kernel/arm/nnacl/conv_parameter.h"
|
#include "mindspore/lite/src/runtime/kernel/arm/nnacl/conv_parameter.h"
|
||||||
#include "mindspore/lite/src/kernel_registry.h"
|
#include "mindspore/lite/src/kernel_registry.h"
|
||||||
|
|
|
@ -22,8 +22,8 @@
|
||||||
#include "mindspore/lite/src/kernel_registry.h"
|
#include "mindspore/lite/src/kernel_registry.h"
|
||||||
#include "src/common/utils.h"
|
#include "src/common/utils.h"
|
||||||
#include "src/common/file_utils.h"
|
#include "src/common/file_utils.h"
|
||||||
#include "src/runtime/kernel/arm/fp32/pooling_grad.h"
|
#include "src/runtime/kernel/arm/fp32_grad/pooling_grad.h"
|
||||||
#include "src/runtime/kernel/arm/nnacl/fp32/pooling_grad.h"
|
#include "src/runtime/kernel/arm/nnacl/fp32_grad/pooling_grad.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
class TestPoolingGradFp32 : public mindspore::Common {
|
class TestPoolingGradFp32 : public mindspore::Common {
|
|
@ -0,0 +1,92 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include <iostream>
|
||||||
|
#include <memory>
|
||||||
|
#include "utils/log_adapter.h"
|
||||||
|
#include "common/common_test.h"
|
||||||
|
#include "src/common/file_utils.h"
|
||||||
|
#include "src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.h"
|
||||||
|
#include "src/kernel_registry.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
|
||||||
|
class TestSoftmaxCrossEntropyFp32 : public mindspore::Common {
|
||||||
|
public:
|
||||||
|
TestSoftmaxCrossEntropyFp32() {}
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(TestSoftmaxCrossEntropyFp32, SoftmaxCrossEntropyFp32) {
|
||||||
|
// prepare stage
|
||||||
|
SoftmaxCrossEntropyParameter *sce_param = new SoftmaxCrossEntropyParameter();
|
||||||
|
size_t input_size;
|
||||||
|
|
||||||
|
std::string input_path = "./test_data/operators/sce_fp32_1_y_6_4.bin";
|
||||||
|
auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size));
|
||||||
|
std::vector<int> dim_y({6, 4});
|
||||||
|
lite::tensor::Tensor y_tensor(TypeId::kNumberTypeFloat32, dim_y);
|
||||||
|
y_tensor.SetData(input_data);
|
||||||
|
|
||||||
|
std::string label_path = "./test_data/operators/sce_fp32_1_l_6.bin";
|
||||||
|
auto ll_labels = reinterpret_cast<int64 *>(mindspore::lite::ReadFile(label_path.c_str(), &input_size));
|
||||||
|
auto labels = new int[6];
|
||||||
|
for (int i = 0; i < 6; i++) labels[i] = static_cast<int>(ll_labels[i]);
|
||||||
|
|
||||||
|
std::vector<int> dim_l({6});
|
||||||
|
lite::tensor::Tensor l_tensor(TypeId::kNumberTypeInt32, dim_l);
|
||||||
|
l_tensor.SetData(labels);
|
||||||
|
|
||||||
|
std::vector<lite::tensor::Tensor *> inputs = {&y_tensor, &l_tensor};
|
||||||
|
|
||||||
|
auto loss = new float[1];
|
||||||
|
std::vector<int> dim_dw({1});
|
||||||
|
lite::tensor::Tensor loss_tensor(TypeId::kNumberTypeFloat32, dim_dw);
|
||||||
|
loss_tensor.SetData(loss);
|
||||||
|
auto grad = new float[24];
|
||||||
|
lite::tensor::Tensor grad_tensor(TypeId::kNumberTypeFloat32, dim_y);
|
||||||
|
grad_tensor.SetData(grad);
|
||||||
|
std::vector<lite::tensor::Tensor *> outputs = {&grad_tensor, &loss_tensor};
|
||||||
|
|
||||||
|
kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_SoftmaxCrossEntropy};
|
||||||
|
auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
|
||||||
|
auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(sce_param), NULL, desc, nullptr);
|
||||||
|
kernel_obj->Run();
|
||||||
|
|
||||||
|
printf("==================total loss=================\n");
|
||||||
|
std::cout << loss[0] << " ," << std::endl;
|
||||||
|
|
||||||
|
printf("==================Testing Grad===============\n");
|
||||||
|
|
||||||
|
std::string output_path = "./test_data/operators/sce_fp32_1_loss_1.bin";
|
||||||
|
lite::CompareOutput(loss, output_path);
|
||||||
|
|
||||||
|
((mindspore::kernel::SparseSoftmaxCrossEntropyWithLogitsCPUKernel *)kernel_obj)->train();
|
||||||
|
kernel_obj->Run();
|
||||||
|
|
||||||
|
printf("==================output data=================\n");
|
||||||
|
for (int i = 0; i < 12; i++) {
|
||||||
|
std::cout << grad[i] << " ,";
|
||||||
|
}
|
||||||
|
std::cout << std::endl;
|
||||||
|
std::string grad_path = "./test_data/operators/sce_fp32_1_dy_6_4.bin";
|
||||||
|
lite::CompareOutput(grad, grad_path);
|
||||||
|
|
||||||
|
delete sce_param;
|
||||||
|
l_tensor.SetData(NULL);
|
||||||
|
y_tensor.SetData(NULL);
|
||||||
|
MS_LOG(INFO) << "SoftmaxCrossEntropyFp32 passed";
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mindspore
|
Loading…
Reference in New Issue